nllloss bugfix for null input

This commit is contained in:
zhangyanhui 2023-02-12 10:20:18 +08:00
parent c29a919b72
commit c70beba908
1 changed files with 11 additions and 0 deletions

View File

@ -32,11 +32,18 @@ class NLLLossInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
const auto prim_name = primitive->name();
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num,
prim_name);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
auto logits_shape_ptr = input_args[kInputIndex0]->BuildShape();
auto logits_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(logits_shape_ptr)[kShape];
MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
auto target_shape_ptr = input_args[kInputIndex1]->BuildShape();
auto target_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(target_shape_ptr)[kShape];
MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
auto weight_shape_ptr = input_args[kInputIndex2]->BuildShape();
auto weight_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(weight_shape_ptr)[kShape];
@ -83,6 +90,10 @@ class NLLLossInfer : public abstract::OpInferBase {
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
const std::set valid_types = {kFloat16, kFloat32};
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t input_num = 3;
CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
auto logits_data_type = input_args[kIndex0]->BuildType();
auto target_type = input_args[kIndex1]->BuildType();
auto weight_data_type = input_args[kIndex2]->BuildType();