diff --git a/mindspore/core/ops/mul.cc b/mindspore/core/ops/mul.cc index 888ea785f44..f781fd4b693 100644 --- a/mindspore/core/ops/mul.cc +++ b/mindspore/core/ops/mul.cc @@ -49,7 +49,8 @@ TypePtr MulInferType(const PrimitivePtr &prim, const std::vector types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return input_args[0]->BuildType(); } } // namespace diff --git a/mindspore/core/ops/not_equal.cc b/mindspore/core/ops/not_equal.cc index 220a5aa9187..27b7a52ceb3 100644 --- a/mindspore/core/ops/not_equal.cc +++ b/mindspore/core/ops/not_equal.cc @@ -33,7 +33,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector & std::map types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name); + (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name); + return std::make_shared(kBool); } } // namespace @@ -47,8 +48,8 @@ AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const Primiti for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - (void)InferType(primitive, input_args); - return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), kBool); + auto type = InferType(primitive, input_args); + return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), type); } REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual); } // namespace ops diff --git a/mindspore/core/ops/real_div.cc b/mindspore/core/ops/real_div.cc index c21118fb3bf..417e8136e83 100644 --- a/mindspore/core/ops/real_div.cc +++ b/mindspore/core/ops/real_div.cc @@ -49,7 +49,8 @@ TypePtr RealDivInferType(const PrimitivePtr &prim, const std::vector types; (void)types.emplace("x", input_args[0]->BuildType()); (void)types.emplace("y", input_args[1]->BuildType()); - return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return input_args[0]->BuildType(); } } // namespace