!40691 Make log of NLLLossGrad matched api.

Merge pull request !40691 from TronZhang/change_nlllossgrad_log
This commit is contained in:
i-robot 2022-08-26 02:07:35 +00:00 committed by Gitee
commit 5be9a63386
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 12 additions and 12 deletions

View File

@ -49,15 +49,15 @@ void CheckNLLLossGradShapeValid(const std::string &prim_name, const ShapeVector
return;
}
CheckValueIn("x rank", x_shape.size(), {1, 2}, prim_name);
(void)CheckAndConvertUtils::CheckInteger("target rank", SizeToLong(t_shape.size()), kEqual, 1, prim_name);
CheckValueIn("logits rank", x_shape.size(), {1, 2}, prim_name);
(void)CheckAndConvertUtils::CheckInteger("labels rank", SizeToLong(t_shape.size()), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, 1, prim_name);
if (x_shape.size() == 1) {
CheckAndConvertUtils::Check("target_shape", 1, kEqual, t_shape[0], prim_name);
CheckAndConvertUtils::Check("weight_shape", x_shape[0], kEqual, w_shape[0], prim_name);
CheckAndConvertUtils::Check("labels shape", t_shape[0], kEqual, 1, prim_name);
CheckAndConvertUtils::Check("weight shape", w_shape[0], kEqual, x_shape[0], prim_name);
} else {
CheckAndConvertUtils::Check("target_shape", x_shape[0], kEqual, t_shape[0], prim_name);
CheckAndConvertUtils::Check("weight_shape", x_shape[1], kEqual, w_shape[0], prim_name);
CheckAndConvertUtils::Check("labels shape", t_shape[0], kEqual, x_shape[0], prim_name);
CheckAndConvertUtils::Check("weight shape", w_shape[0], kEqual, x_shape[1], prim_name);
}
}
} // namespace
@ -119,12 +119,12 @@ TypePtr NLLLossGradInferType(const PrimitivePtr &primitive, const std::vector<Ab
auto t_dtype = input_args[kInputIndex2]->BuildType();
auto w_dtype = input_args[kInputIndex3]->BuildType();
auto tw_dtype = input_args[kInputIndex4]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("y_grad_dtype", y_grad_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("t_dtype", t_dtype, {kInt32}, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("w_dtype", w_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("tw_dtype", tw_dtype, valid_types, prim_name);
CheckAndConvertUtils::Check("w_dtype", std::vector<TypeId>{tw_dtype->type_id()}, kEqual,
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits dtype", x_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("loss's grad dtype", y_grad_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels dtype", t_dtype, {kInt32}, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight dtype", w_dtype, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("total_weight dtype", tw_dtype, valid_types, prim_name);
CheckAndConvertUtils::Check("weight dtype", std::vector<TypeId>{tw_dtype->type_id()}, kEqual,
std::vector<TypeId>{w_dtype->type_id()}, prim_name);
return x_dtype;
}