diff --git a/mindspore/core/ops/one_hot.cc b/mindspore/core/ops/one_hot.cc index 1e17437d66f..327ba44f160 100644 --- a/mindspore/core/ops/one_hot.cc +++ b/mindspore/core/ops/one_hot.cc @@ -73,7 +73,9 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector &input_args) { - return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args)); + auto infer_type = OneHotInferType(primitive, input_args); + auto infer_shape = OneHotInferShape(primitive, input_args); + return abstract::MakeAbstract(infer_shape, infer_type); } REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true); } // namespace ops