forked from mindspore-Ecosystem/mindspore
!40691 Make log of NLLLossGrad matched api.
Merge pull request !40691 from TronZhang/change_nlllossgrad_log
This commit is contained in:
commit
5be9a63386
|
@ -49,15 +49,15 @@ void CheckNLLLossGradShapeValid(const std::string &prim_name, const ShapeVector
|
|||
return;
|
||||
}
|
||||
|
||||
CheckValueIn("x rank", x_shape.size(), {1, 2}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("target rank", SizeToLong(t_shape.size()), kEqual, 1, prim_name);
|
||||
CheckValueIn("logits rank", x_shape.size(), {1, 2}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("labels rank", SizeToLong(t_shape.size()), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("weight rank", SizeToLong(w_shape.size()), kEqual, 1, prim_name);
|
||||
if (x_shape.size() == 1) {
|
||||
CheckAndConvertUtils::Check("target_shape", 1, kEqual, t_shape[0], prim_name);
|
||||
CheckAndConvertUtils::Check("weight_shape", x_shape[0], kEqual, w_shape[0], prim_name);
|
||||
CheckAndConvertUtils::Check("labels shape", t_shape[0], kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::Check("weight shape", w_shape[0], kEqual, x_shape[0], prim_name);
|
||||
} else {
|
||||
CheckAndConvertUtils::Check("target_shape", x_shape[0], kEqual, t_shape[0], prim_name);
|
||||
CheckAndConvertUtils::Check("weight_shape", x_shape[1], kEqual, w_shape[0], prim_name);
|
||||
CheckAndConvertUtils::Check("labels shape", t_shape[0], kEqual, x_shape[0], prim_name);
|
||||
CheckAndConvertUtils::Check("weight shape", w_shape[0], kEqual, x_shape[1], prim_name);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
@ -119,12 +119,12 @@ TypePtr NLLLossGradInferType(const PrimitivePtr &primitive, const std::vector<Ab
|
|||
auto t_dtype = input_args[kInputIndex2]->BuildType();
|
||||
auto w_dtype = input_args[kInputIndex3]->BuildType();
|
||||
auto tw_dtype = input_args[kInputIndex4]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("y_grad_dtype", y_grad_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("t_dtype", t_dtype, {kInt32}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("w_dtype", w_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("tw_dtype", tw_dtype, valid_types, prim_name);
|
||||
CheckAndConvertUtils::Check("w_dtype", std::vector<TypeId>{tw_dtype->type_id()}, kEqual,
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits dtype", x_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("loss's grad dtype", y_grad_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("labels dtype", t_dtype, {kInt32}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight dtype", w_dtype, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("total_weight dtype", tw_dtype, valid_types, prim_name);
|
||||
CheckAndConvertUtils::Check("weight dtype", std::vector<TypeId>{tw_dtype->type_id()}, kEqual,
|
||||
std::vector<TypeId>{w_dtype->type_id()}, prim_name);
|
||||
return x_dtype;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue