From 68f3b5784f7fa4b7945730b92493b5b977fe3e85 Mon Sep 17 00:00:00 2001
From: lianliguang
Date: Mon, 31 May 2021 14:32:41 +0800
Subject: [PATCH] add a process to make type and shape abstract

---
 .../pipeline/jit/static_analysis/prim.cc   | 36 +++++---
 mindspore/ccsrc/utils/convert_utils_py.cc  | 36 +++-----
 mindspore/ccsrc/utils/convert_utils_py.h   |  4 +-
 mindspore/core/abstract/utils.cc           | 92 +++++++++++++++++++
 mindspore/core/abstract/utils.h            |  4 +-
 mindspore/core/ops/add.cc                  |  3 +-
 mindspore/core/ops/addn.cc                 |  9 +-
 mindspore/core/ops/avg_pool_3d.cc          |  3 +-
 mindspore/core/ops/batch_matmul.cc         |  5 +-
 mindspore/core/ops/broadcast_to.cc         |  4 +-
 mindspore/core/ops/conv2d.cc               |  3 +-
 mindspore/core/ops/ctcloss.cc              | 20 ++--
 mindspore/core/ops/dropout_do_mask.cc      |  3 +-
 mindspore/core/ops/dropout_gen_mask.cc     |  3 +-
 mindspore/core/ops/dtype.cc                |  3 +-
 mindspore/core/ops/gather.cc               |  6 +-
 mindspore/core/ops/gather_d.cc             |  4 +-
 mindspore/core/ops/gelu.cc                 |  3 +-
 mindspore/core/ops/grad/bias_add_grad.cc   |  3 +-
 mindspore/core/ops/grad/layer_norm_grad.cc |  9 +-
 mindspore/core/ops/grad/relu_grad.cc       |  3 +-
 mindspore/core/ops/grad/relu_grad_v2.cc    |  3 +-
 mindspore/core/ops/layer_norm.cc           | 17 ++--
 mindspore/core/ops/log1p.cc                |  8 +-
 mindspore/core/ops/log_softmax.cc          |  4 +-
 mindspore/core/ops/mat_mul.cc              |  3 +-
 mindspore/core/ops/mul.cc                  |  3 +-
 mindspore/core/ops/one_hot.cc              |  3 +-
 mindspore/core/ops/ones_like.cc            |  3 +-
 mindspore/core/ops/real_div.cc             |  3 +-
 mindspore/core/ops/relu.cc                 |  5 +-
 mindspore/core/ops/relu6.cc                |  3 +-
 mindspore/core/ops/reluv2.cc               | 15 ++-
 mindspore/core/ops/scalar_summary.cc       |  2 +-
 mindspore/core/ops/softmax.cc              |  3 +-
 mindspore/core/ops/softplus.cc             |  4 +-
 mindspore/core/ops/sub.cc                  |  3 +-
 mindspore/core/ops/tensor_summary.cc       |  2 +-
 mindspore/core/ops/tile.cc                 |  3 +-
 mindspore/core/ops/zeros.cc                |  4 +-
 mindspore/core/ops/zeros_like.cc           |  6 +-
 41 files changed, 205 insertions(+), 148 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index bc1593e16e5..9321316e74c 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -505,12 +505,33 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
   return py_args;
 }
 
+void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
+  const string kOutputNum = "output_num";
+  if (prim->IsCustomPrim()) {
+    // Raise an error if output_num does not match the infer result.
+    auto output_num_value = prim->GetAttr(kOutputNum);
+    if (output_num_value == nullptr) {
+      MS_LOG(DEBUG) << "The output num may not need to be checked";
+      return;
+    }
+    int64_t output_num = GetValue<int64_t>(output_num_value);
+    if (res_spec->isa<AbstractTensor>() && output_num != 1) {
+      MS_LOG(EXCEPTION) << "Custom primitive " << prim->ToString() << " output_num " << output_num
+                        << " does not match the infer result.";
+    } else if (res_spec->isa<AbstractTuple>() &&
+               (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
+      MS_LOG(EXCEPTION) << "Custom primitive " << prim->ToString() << " output_num " << output_num
+                        << " does not match the infer result.";
+    }
+  }
+}
+
 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
   // Convert to AbstractValue based on type and shape
   auto out_dtype = output[ATTR_DTYPE];
   if (output[ATTR_VALUE].is_none()) {
     auto out_shape = output[ATTR_SHAPE];
-    return PyListDtype2AbstractTensor(out_shape, out_dtype, output);
+    return MakePyInferRes2Abstract(out_shape, out_dtype, output);
   }
   // Convert pyobject to Value, then to AbstractValue
   ValuePtr converted_ret = nullptr;
@@ -527,18 +548,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
     res_tensor->set_value(converted_ret);
     SetValueRange(res_tensor, output);
   }
-  if (prim_py->IsCustomPrim()) {
-    // Raise error if output_num is not match the infer result.
-    int64_t output_num = GetValue<int64_t>(prim_py->GetAttr("output_num"));
-    if (res_spec->isa<AbstractTensor>() && output_num != 1) {
-      MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
-                        << " not matches the infer result.";
-    } else if (res_spec->isa<AbstractTuple>() &&
-               (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
-      MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
-                        << " not matches the infer result.";
-    }
-  }
+  CheckCustomPrimOutputInferResult(prim_py, res_spec);
   return res_spec;
 }
 } // end anonymous namespace
diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc
index d59ea447e96..3ac445048fd 100644
--- a/mindspore/ccsrc/utils/convert_utils_py.cc
+++ b/mindspore/ccsrc/utils/convert_utils_py.cc
@@ -25,6 +25,7 @@
 #include
 
 #include "abstract/abstract_value.h"
+#include "abstract/utils.h"
 #include "pipeline/jit/parse/parse.h"
 #include "pipeline/jit/parse/parse_base.h"
 #include "ir/value.h"
@@ -357,9 +358,8 @@ void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
   }
 }
 
-AbstractBasePtr PyList2DynamicShapeTensor(const py::object &shape_obj, const py::object &type_obj,
-                                          const py::object &output) {
-  AbstractBasePtr tensor = nullptr;
+AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+                                              const py::object &output) {
   auto ret_vec = shape_obj.cast<ShapeVector>();
   auto ret_dtype = type_obj.cast<TypePtr>();
   ShapeVector min_shape_vec;
@@ -379,15 +379,7 @@ AbstractBasePtr PyList2DynamicShapeTensor(const py::object &shape_obj, const py:
   }
 
   auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
-  if (ret_dtype->isa<TensorType>()) {
-    auto tensor_type = type_obj.cast<TensorTypePtr>();
-    MS_EXCEPTION_IF_NULL(tensor_type);
-    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
-    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
-  } else {
-    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
-    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
-  }
+  AbstractBasePtr tensor = MakeAbstractTensor(ret_shape, ret_dtype);
 
   SetValueRange(tensor, output);
   return tensor;
@@ -404,18 +396,16 @@ static bool IsMonadType(const py::object &type_obj) {
 static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
   if (py::isinstance<Type>(type_obj)) {
     auto type = type_obj.cast<TypePtr>();
-    if (type->isa<UMonadType>()) {
-      return kUMonad->ToAbstract();
-    }
-    if (type->isa<IOMonadType>()) {
-      return kIOMonad->ToAbstract();
+    if (!type->isa<MonadType>()) {
+      MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
     }
+    return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
   }
-  MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
+  MS_LOG(EXCEPTION) << "Not a type object: " << py::str(type_obj);
 }
 
-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
-                                           const py::object &output) {
+AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
+                                        const py::object &output) {
   if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
     auto ret_vec = shape_obj.cast<ShapeVector>();
     auto ret_dtype = type_obj.cast<TypePtr>();
@@ -425,13 +415,13 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
       abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
       return abs_scalar;
     }
-    return PyList2DynamicShapeTensor(shape_obj, type_obj, output);
+    return MakePyInferRes2AbstractTensor(shape_obj, type_obj, output);
   } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
     auto shape_tuple = shape_obj.cast<py::tuple>();
     auto typeid_tuple = type_obj.cast<py::tuple>();
     AbstractBasePtrList ptr_list;
     for (size_t it = 0; it < shape_tuple.size(); ++it) {
-      auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
+      auto tensor_it = MakePyInferRes2Abstract(shape_tuple[it], typeid_tuple[it]);
       ptr_list.push_back(tensor_it);
     }
     auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
@@ -441,7 +431,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
     auto typeid_list = type_obj.cast<py::list>();
     AbstractBasePtrList ptr_list;
     for (size_t it = 0; it < shape_list.size(); ++it) {
-      auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]);
+      auto tensor_it = MakePyInferRes2Abstract(shape_list[it], typeid_list[it]);
       ptr_list.push_back(tensor_it);
     }
     auto list = std::make_shared<abstract::AbstractList>(ptr_list);
diff --git a/mindspore/ccsrc/utils/convert_utils_py.h b/mindspore/ccsrc/utils/convert_utils_py.h
index 703d5cb6053..d4748e0b244 100644
--- a/mindspore/ccsrc/utils/convert_utils_py.h
+++ b/mindspore/ccsrc/utils/convert_utils_py.h
@@ -36,8 +36,8 @@ py::object ValuePtrToPyData(const ValuePtr &value);
 bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                        const std::shared_ptr<py::object> &ret_val);
 
-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
-                                           const py::object &output = py::none());
+AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
+                                        const py::object &output = py::none());
 
 void SetValueRange(const AbstractBasePtr &tensor, const py::object &output);
 } // namespace mindspore
diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc
index 404b06ae03d..ef7af7f9dd4 100644
--- a/mindspore/core/abstract/utils.cc
+++ b/mindspore/core/abstract/utils.cc
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include "utils/ms_context.h"
 #include "utils/symbolic.h"
 #include "abstract/param_validator.h"
 
@@ -348,5 +349,96 @@ int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list,
   }
   return num_segments_value;
 }
+
+AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
+  MS_EXCEPTION_IF_NULL(shape);
+  MS_EXCEPTION_IF_NULL(type);
+  AbstractBasePtr tensor = nullptr;
+  auto ret_vec = shape->shape();
+  ShapeVector min_shape_vec;
+  ShapeVector max_shape_vec;
+
+  if (shape->IsDynamic()) {
+    if (!shape->min_shape().empty()) {
+      min_shape_vec = shape->min_shape();
+    }
+    if (!shape->max_shape().empty()) {
+      max_shape_vec = shape->max_shape();
+    }
+  }
+
+  auto ret_shape = std::make_shared<Shape>(ret_vec, min_shape_vec, max_shape_vec);
+  if (type->isa<TensorType>()) {
+    auto tensor_type = type->cast<TensorTypePtr>();
+    MS_EXCEPTION_IF_NULL(tensor_type);
+    auto element = std::make_shared<AbstractScalar>(kAnyValue, tensor_type->element());
+    tensor = std::make_shared<AbstractTensor>(element, ret_shape);
+  } else {
+    auto element = std::make_shared<AbstractScalar>(kAnyValue, type);
+    tensor = std::make_shared<AbstractTensor>(element, ret_shape);
+  }
+  return tensor;
+}
+
+AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type) {
+  if (type->isa<UMonadType>()) {
+    return kUMonad->ToAbstract();
+  } else if (type->isa<IOMonadType>()) {
+    return kIOMonad->ToAbstract();
+  }
+  MS_EXCEPTION(UnknownError) << "Unsupported to convert type " << type->ToString() << " to monad abstract";
+}
+
+AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type) {
+  MS_EXCEPTION_IF_NULL(base_shape);
+  MS_EXCEPTION_IF_NULL(type);
+  if ((base_shape->isa<Shape>())) {
+    auto shape = base_shape->cast<ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape);
+    auto shape_vec = shape->shape();
+    // if the size of the shape list is empty, return a scalar abstract
+    if (shape_vec.empty() && (!type->isa<TensorType>())) {
+      abstract::AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(kAnyValue, type);
+      return abs_scalar;
+    }
+    return MakeAbstractTensor(shape, type);
+  } else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
+    auto shape_tuple = base_shape->cast<TupleShapePtr>();
+    auto type_tuple = type->cast<TuplePtr>();
+    AbstractBasePtrList ptr_list;
+    for (size_t it = 0; it < shape_tuple->size(); ++it) {
+      auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);
+      ptr_list.push_back(tensor_it);
+    }
+    auto tuple = std::make_shared<AbstractTuple>(ptr_list);
+    return tuple;
+  } else if (base_shape->isa<ListShape>() && type->isa<List>()) {
+    auto shape_list = base_shape->cast<ListShapePtr>();
+    auto type_list = type->cast<ListPtr>();
+    AbstractBasePtrList ptr_list;
+    for (size_t it = 0; it < shape_list->size(); ++it) {
+      auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);
+      ptr_list.push_back(tensor_it);
+    }
+    auto list = std::make_shared<AbstractList>(ptr_list);
+    return list;
+  } else if (base_shape->isa<NoShape>() && type->isa<TypeNone>()) {
+    // AbstractNone indicates there is no output for this CNode node.
+    auto abstract_none = std::make_shared<AbstractNone>();
+    return abstract_none;
+  } else if (type->isa<MonadType>()) {
+    // Return a monad abstract if it is a monad type.
+    return MakeMonadAbstract(type->cast<MonadTypePtr>());
+  } else {
+    // When sparse is enabled, the undetermined might be raised and eliminated in opt passes
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
+    if (enable_sparse) {
+      return std::make_shared<AbstractUndetermined>();
+    }
+    MS_LOG(EXCEPTION) << "Evaluator returned an invalid shape " << base_shape->ToString() << " or type "
+                      << type->ToString();
+  }
+}
 } // namespace abstract
 } // namespace mindspore
diff --git a/mindspore/core/abstract/utils.h b/mindspore/core/abstract/utils.h
index 8220294f7ef..340278a6a89 100644
--- a/mindspore/core/abstract/utils.h
+++ b/mindspore/core/abstract/utils.h
@@ -62,7 +62,9 @@ void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVec
 
 // Get 3rd argument for UnsortedSegmentOps' inferImpl function
 int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name);
-
+AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
+AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
+AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
 } // namespace abstract
 } // namespace mindspore
 #endif // MINDSPORE_CORE_ABSTRACT_UTILS_H_
diff --git a/mindspore/core/ops/add.cc b/mindspore/core/ops/add.cc
index 6ebf970fbe0..c33f1d66f43 100644
--- a/mindspore/core/ops/add.cc
+++ b/mindspore/core/ops/add.cc
@@ -50,8 +50,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 
 AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameAdd, Add);
 } // namespace ops
diff --git a/mindspore/core/ops/addn.cc b/mindspore/core/ops/addn.cc
index f9d9c3683d8..c5176371c78 100644
--- a/mindspore/core/ops/addn.cc
+++ b/mindspore/core/ops/addn.cc
@@ -72,7 +72,8 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   std::set<TypePtr> valid_types = common_valid_types;
   valid_types.insert(kBool);
-  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
+  CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
+  return elements[0]->BuildType();
 }
 } // namespace
 AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -83,11 +84,7 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  auto abs_type = AddNInferType(primitive, input_args);
-  if (abs_type->type_id() == kObjectTypeUndeterminedType) {
-    return std::make_shared<abstract::AbstractUndetermined>();
-  }
-  return std::make_shared<abstract::AbstractTensor>(abs_type, AddNInferShape(primitive, input_args));
+  return abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/avg_pool_3d.cc b/mindspore/core/ops/avg_pool_3d.cc
index 5a4edccf568..96eec734eab 100644
--- a/mindspore/core/ops/avg_pool_3d.cc
+++ b/mindspore/core/ops/avg_pool_3d.cc
@@ -172,8 +172,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
 
 AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 
 REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3D, prim::kPrimAvgPool3D, AvgPool3DInfer, nullptr, true);
diff --git a/mindspore/core/ops/batch_matmul.cc b/mindspore/core/ops/batch_matmul.cc
index 8a4c520c2d0..d9798f13048 100644
--- a/mindspore/core/ops/batch_matmul.cc
+++ b/mindspore/core/ops/batch_matmul.cc
@@ -145,9 +145,8 @@ AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const Prim
                                  const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   CheckAndConvertUtils::CheckInteger("BatchMatmul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
-  auto abs = std::make_shared<abstract::AbstractTensor>(BatchMatmulInferType(primitive, input_args),
-                                                        BatchMatmulInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(BatchMatmulInferShape(primitive, input_args),
+                                BatchMatmulInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(BatchMatmul, prim::kPrimBatchMatMul, BatchMatmulInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/broadcast_to.cc b/mindspore/core/ops/broadcast_to.cc
index 52b202b2ff9..353bd78711c 100644
--- a/mindspore/core/ops/broadcast_to.cc
+++ b/mindspore/core/ops/broadcast_to.cc
@@ -84,8 +84,8 @@ std::vector<int64_t> BroadcastTo::get_shape() const {
 }
 AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
-                                                    BroadcastToInferShape(primitive, input_args));
+  return abstract::MakeAbstract(BroadcastToInferShape(primitive, input_args),
+                                BroadcastToInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc
index 737cf2c9af9..93f32942ce3 100644
--- a/mindspore/core/ops/conv2d.cc
+++ b/mindspore/core/ops/conv2d.cc
@@ -313,8 +313,7 @@ Format Conv2D::get_format() const {
 AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
   CheckAndConvertUtils::CheckInteger("Conv2d infer", input_args.size(), kGreaterEqual, 2, primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
-                                                    Conv2dInferShape(primitive, input_args));
+  return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/ctcloss.cc b/mindspore/core/ops/ctcloss.cc
index 919b0a13625..fe0b3621f30 100644
--- a/mindspore/core/ops/ctcloss.cc
+++ b/mindspore/core/ops/ctcloss.cc
@@ -56,8 +56,7 @@ void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const st
   }
 }
 
-std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
-                                           const std::vector<AbstractBasePtr> &input_args) {
+abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto op_name = primitive->name();
   CheckCTCLossInputs(input_args, op_name);
   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
@@ -71,36 +70,33 @@ std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
   if (min_shape.empty() || max_shape.empty()) {
     loss_shape = std::make_shared<abstract::Shape>(batch);
     gradient_shape = std::make_shared<abstract::Shape>(shape);
-    return std::vector<abstract::ShapePtr>{loss_shape, gradient_shape};
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{loss_shape, gradient_shape});
   }
   ShapeVector batch_min = {min_shape[1]};
   ShapeVector batch_max = {max_shape[1]};
   loss_shape = std::make_shared<abstract::Shape>(batch, batch_min, batch_max);
   gradient_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
-  return std::vector<abstract::ShapePtr>{loss_shape, gradient_shape};
+  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{loss_shape, gradient_shape});
 }
 
-TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto op_name = primitive->name();
   CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[1]->BuildType(), {kInt64}, op_name);
   CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[2]->BuildType(), {kInt32}, op_name);
   CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[3]->BuildType(), {kInt32}, op_name);
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
-  return CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
+  auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
+  return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
 }
-
 } // namespace
 AbstractBasePtr CTCLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto type = InferType(primitive, input_args);
+  auto types = InferType(primitive, input_args);
   auto shapes = InferShape(primitive, input_args);
-  auto loss = std::make_shared<abstract::AbstractTensor>(type, shapes[0]);
-  auto gradient = std::make_shared<abstract::AbstractTensor>(type, shapes[1]);
-  AbstractBasePtrList outputs = {loss, gradient};
-  return std::make_shared<abstract::AbstractTuple>(outputs);
+  return abstract::MakeAbstract(shapes, types);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(CTCLoss, prim::kPrimCTCLoss, CTCLossInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/dropout_do_mask.cc b/mindspore/core/ops/dropout_do_mask.cc
index 1678afcb81b..2548c92e948 100644
--- a/mindspore/core/ops/dropout_do_mask.cc
+++ b/mindspore/core/ops/dropout_do_mask.cc
@@ -116,8 +116,7 @@ AbstractBasePtr DropoutDoMaskInfer(const abstract::AnalysisEnginePtr &, const Pr
                                    const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 3, primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(DropoutDoMask, prim::kPrimDropoutDoMask, DropoutDoMaskInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/dropout_gen_mask.cc b/mindspore/core/ops/dropout_gen_mask.cc
index 6fdf27c191c..c51d304e1e9 100644
--- a/mindspore/core/ops/dropout_gen_mask.cc
+++ b/mindspore/core/ops/dropout_gen_mask.cc
@@ -147,8 +147,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
 
 AbstractBasePtr DropoutGenMaskInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGenMask, prim::kPrimDropoutGenMask, DropoutGenMaskInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/dtype.cc b/mindspore/core/ops/dtype.cc
index f24c607fd4a..f8f598b182e 100644
--- a/mindspore/core/ops/dtype.cc
+++ b/mindspore/core/ops/dtype.cc
@@ -44,8 +44,7 @@ AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
   MS_EXCEPTION_IF_NULL(value);
   auto type = value->cast<TypePtr>();
   MS_EXCEPTION_IF_NULL(type);
-  auto abstract = std::make_shared<abstract::AbstractScalar>(type);
-  return abstract;
+  return abstract::MakeAbstract(std::make_shared<abstract::NoShape>(), type);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false);
diff --git a/mindspore/core/ops/gather.cc b/mindspore/core/ops/gather.cc
index c6d9e5993db..a2f42e335e6 100644
--- a/mindspore/core/ops/gather.cc
+++ b/mindspore/core/ops/gather.cc
@@ -84,10 +84,10 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
   if (ind_dyn || param_dyn) {
     ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
     ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
-    return std::make_shared<abstract::AbstractTensor>(
-      params->element(), std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
+    return abstract::MakeAbstract(std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape),
+                                  params->BuildType());
   }
-  return std::make_shared<abstract::AbstractTensor>(params->element(), std::make_shared<abstract::Shape>(out_shape));
+  return abstract::MakeAbstract(std::make_shared<abstract::Shape>(out_shape), params->BuildType());
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Gather, prim::kPrimGather, GatherInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc
index b035f7660e3..7151a74e605 100644
--- a/mindspore/core/ops/gather_d.cc
+++ b/mindspore/core/ops/gather_d.cc
@@ -68,9 +68,7 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
   std::set<TypePtr> valid_types = {kInt32, kInt64};
   CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
   CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
-  auto abs = std::make_shared<abstract::AbstractTensor>(GatherDInferType(primitive, input_args),
-                                                        GatherDInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/gelu.cc b/mindspore/core/ops/gelu.cc
index d8cba7a140b..2d6426a959c 100644
--- a/mindspore/core/ops/gelu.cc
+++ b/mindspore/core/ops/gelu.cc
@@ -54,8 +54,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 } // namespace
 AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 
 REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true);
diff --git a/mindspore/core/ops/grad/bias_add_grad.cc b/mindspore/core/ops/grad/bias_add_grad.cc
index 5d8f269a14e..165c7e30b9c 100644
--- a/mindspore/core/ops/grad/bias_add_grad.cc
+++ b/mindspore/core/ops/grad/bias_add_grad.cc
@@ -72,8 +72,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 } // namespace
 AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 
 REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer, nullptr, true);
diff --git a/mindspore/core/ops/grad/layer_norm_grad.cc b/mindspore/core/ops/grad/layer_norm_grad.cc
index cc74304f21a..a5699e95530 100644
--- a/mindspore/core/ops/grad/layer_norm_grad.cc
+++ b/mindspore/core/ops/grad/layer_norm_grad.cc
@@ -30,11 +30,12 @@ AbstractBasePtr LayerNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr
   auto x_backprob = input_args[0]->Broaden();
   auto gamma_backprob = input_args[4]->Broaden();
   auto beta_backprob = input_args[4]->Broaden();
-
-  AbstractBasePtrList args_list({x_backprob, gamma_backprob, beta_backprob});
-  return std::make_shared<abstract::AbstractTuple>(args_list);
+  auto shapes = std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
+    x_backprob->BuildShape(), gamma_backprob->BuildShape(), beta_backprob->BuildShape()});
+  auto types = std::make_shared<Tuple>(
+    std::vector<TypePtr>{x_backprob->BuildType(), gamma_backprob->BuildType(), beta_backprob->BuildType()});
+  return abstract::MakeAbstract(shapes, types);
 }
-
 void LayerNormGrad::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis) {
   this->set_begin_norm_axis(begin_norm_axis);
   this->set_begin_params_axis(begin_params_axis);
diff --git a/mindspore/core/ops/grad/relu_grad.cc b/mindspore/core/ops/grad/relu_grad.cc
index e848a288d90..25c78a4d73f 100644
--- a/mindspore/core/ops/grad/relu_grad.cc
+++ b/mindspore/core/ops/grad/relu_grad.cc
@@ -62,8 +62,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 } // namespace
 AbstractBasePtr ReLUGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGrad, prim::kPrimReluGrad, ReLUGradInfer, nullptr, true);
diff --git a/mindspore/core/ops/grad/relu_grad_v2.cc b/mindspore/core/ops/grad/relu_grad_v2.cc
index 243be948e9e..8c23cf868e1 100644
--- a/mindspore/core/ops/grad/relu_grad_v2.cc
+++ b/mindspore/core/ops/grad/relu_grad_v2.cc
@@ -55,8 +55,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 } // namespace
 AbstractBasePtr ReLUGradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGradV2, prim::kPrimReluGradV2, ReLUGradV2Infer, nullptr, true);
diff --git a/mindspore/core/ops/layer_norm.cc b/mindspore/core/ops/layer_norm.cc
index dd6c00630ac..98ee4f96baa 100644
--- a/mindspore/core/ops/layer_norm.cc
+++ b/mindspore/core/ops/layer_norm.cc
@@ -93,23 +93,22 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
     }
   }
 
-  auto mean = input_x->Broaden();
-  auto var = input_x->Broaden();
-
+  std::vector<abstract::BaseShapePtr> shapes_list = {input_x->BuildShape()};
+  std::vector<TypePtr> types_list = {input_x->BuildType(), input_x->BuildType(), input_x->BuildType()};
   auto mean_var_shape = CalLayerNormMeanAndVarShape(begin_norm_axis, input_shape->shape());
   auto input_min_shape = input_shape->min_shape();
   auto input_max_shape = input_shape->max_shape();
   if (input_min_shape.empty() || input_max_shape.empty()) {
-    mean->set_shape(std::make_shared<abstract::Shape>(mean_var_shape));
-    var->set_shape(std::make_shared<abstract::Shape>(mean_var_shape));
+    shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape));
+    shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape));
   } else {
     auto mean_var_shape_min = CalLayerNormMeanAndVarShape(begin_norm_axis, input_min_shape);
     auto mean_var_shape_max = CalLayerNormMeanAndVarShape(begin_norm_axis, input_min_shape);
-    mean->set_shape(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
-    var->set_shape(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
+    shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
+    shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
   }
-  AbstractBasePtrList args_list({input_x->Broaden(), mean, var});
-  return std::make_shared<abstract::AbstractTuple>(args_list);
+  return abstract::MakeAbstract(std::make_shared<abstract::TupleShape>(shapes_list),
+                                std::make_shared<Tuple>(types_list));
 }
 
 void LayerNorm::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis, const float epsilon) {
diff --git a/mindspore/core/ops/log1p.cc b/mindspore/core/ops/log1p.cc
index babb09ad0d4..709214f9ab3 100644
--- a/mindspore/core/ops/log1p.cc
+++ b/mindspore/core/ops/log1p.cc
@@ -40,8 +40,8 @@ TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   auto prim_name = prim->name();
   // check
   std::set<TypePtr> valid_index_types = {kFloat16, kFloat32};
-  auto x_type =
-    CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
+  auto x_type = input_args[0]->BuildType();
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
   return x_type;
 }
 } // namespace
@@ -49,9 +49,7 @@ AbstractBasePtr Log1pInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto abs = std::make_shared<abstract::AbstractTensor>(Log1pInferType(primitive, input_args),
-                                                        Log1pInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(Log1pInferShape(primitive, input_args), Log1pInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Log1p, prim::kPrimLog1p, Log1pInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/log_softmax.cc b/mindspore/core/ops/log_softmax.cc
index cb32039e2e4..41784ea3a88 100644
--- a/mindspore/core/ops/log_softmax.cc
+++ b/mindspore/core/ops/log_softmax.cc
@@ -65,8 +65,8 @@ TypePtr LogSoftmaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
 AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(LogSoftmaxInferType(primitive, input_args),
-                                                    LogSoftmaxInferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(LogSoftmaxInferShape(primitive, input_args),
+                                LogSoftmaxInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(LogSoftmax, prim::kPrimLogSoftmax, LogSoftmaxInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/mat_mul.cc b/mindspore/core/ops/mat_mul.cc
index fcf4bae9845..4ae0776f461 100644
--- a/mindspore/core/ops/mat_mul.cc
+++ b/mindspore/core/ops/mat_mul.cc
@@ -126,8 +126,7 @@ bool MatMul::get_transpose_b() const {
 AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
   CheckAndConvertUtils::CheckInteger("MatMul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(MatMulInferType(primitive, input_args),
-                                                    MatMulInferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(MatMulInferShape(primitive, input_args), MatMulInferType(primitive, input_args));
 }
 // Add
 REGISTER_PRIMITIVE_C(kNameMatMul, MatMul);
diff --git a/mindspore/core/ops/mul.cc b/mindspore/core/ops/mul.cc
index f30be39be1a..738a0392570 100644
--- a/mindspore/core/ops/mul.cc
+++ b/mindspore/core/ops/mul.cc
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 AbstractBasePtr MulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameMul, Mul);
 } // namespace ops
diff --git a/mindspore/core/ops/one_hot.cc b/mindspore/core/ops/one_hot.cc
index f4aa32036a4..1e17437d66f 100644
--- a/mindspore/core/ops/one_hot.cc
+++ b/mindspore/core/ops/one_hot.cc
@@ -73,8 +73,7 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
 AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(OneHotInferType(primitive, input_args),
-                                                    OneHotInferShape(primitive, input_args));
+  return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/ones_like.cc b/mindspore/core/ops/ones_like.cc
index b536f1deb6b..0d3c3e1b63f 100644
--- a/mindspore/core/ops/ones_like.cc
+++ b/mindspore/core/ops/ones_like.cc
@@ -45,8 +45,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
 AbstractBasePtr OnesLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(OnesLike, prim::kPrimOnesLike, OnesLikeInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/real_div.cc b/mindspore/core/ops/real_div.cc
index 28ca3635073..44a86e96bbd 100644
--- a/mindspore/core/ops/real_div.cc
+++ b/mindspore/core/ops/real_div.cc
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 AbstractBasePtr RealDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameRealDiv, RealDiv);
 } // namespace ops
diff --git a/mindspore/core/ops/relu.cc b/mindspore/core/ops/relu.cc
index e93e65b22ec..a24ae7c7615 100644
--- a/mindspore/core/ops/relu.cc
+++ b/mindspore/core/ops/relu.cc
@@ -44,14 +44,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   CheckAndConvertUtils::CheckInteger("ReLU infer", input_args.size(), kEqual, 1, prim_name);
   MS_EXCEPTION_IF_NULL(input_args[0]);
   auto x_type = input_args[0]->BuildType();
-  return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, common_valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, common_valid_types, prim_name);
+  return x_type;
 }
 } // namespace
 AbstractBasePtr ReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
   auto type = InferType(primitive, input_args);
   auto shape = InferShape(primitive, input_args);
-  return std::make_shared<abstract::AbstractTensor>(type, shape);
+  return abstract::MakeAbstract(shape, type);
 }
 
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLU, prim::kPrimRelu, ReLUInfer, nullptr, true);
diff --git a/mindspore/core/ops/relu6.cc b/mindspore/core/ops/relu6.cc
index a823ebd2e38..e1b4e1b177d 100644
--- a/mindspore/core/ops/relu6.cc
+++ b/mindspore/core/ops/relu6.cc
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 } // namespace
 AbstractBasePtr ReLU6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLU6, prim::kPrimRelu6, ReLU6Infer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/reluv2.cc b/mindspore/core/ops/reluv2.cc
index fb981e2067a..657c0ec2cc9 100644
--- a/mindspore/core/ops/reluv2.cc
+++ b/mindspore/core/ops/reluv2.cc
@@ -51,8 +51,7 @@ std::vector<int64_t> GetOutputMaskShape(const std::vector<int64_t> &input_shape,
   }
   return mask_shape;
 }
-std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
-                                           const std::vector<AbstractBasePtr> &input_args) {
+abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
@@ -74,13 +73,13 @@ std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
   if (min_shape.empty() || max_shape.empty()) {
     inputs_shape = std::make_shared<abstract::Shape>(input_shape);
     masks_shape = std::make_shared<abstract::Shape>(mask_shape);
-    return std::vector<abstract::ShapePtr>{inputs_shape, masks_shape};
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{inputs_shape, masks_shape});
   }
   auto min_mask_shape = GetOutputMaskShape(min_shape, x_dtype);
   auto max_mask_shape = GetOutputMaskShape(max_shape, x_dtype);
   inputs_shape = std::make_shared<abstract::Shape>(input_shape, min_shape, max_shape);
   masks_shape = std::make_shared<abstract::Shape>(mask_shape, min_mask_shape, max_mask_shape);
-  return std::vector<abstract::ShapePtr>{inputs_shape, masks_shape};
+  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{inputs_shape, masks_shape});
 }
 
 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@@ -92,7 +91,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   MS_EXCEPTION_IF_NULL(x_type);
   std::set<TypePtr> valid_x_type = {kTensorType};
   CheckAndConvertUtils::CheckSubClass("input_x", x_type, valid_x_type, prim_name);
-  return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
+  auto type = CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
+  return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
 }
 } // namespace
 AbstractBasePtr ReLUV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -100,10 +100,7 @@ AbstractBasePtr ReLUV2Infer(const abstract::AnalysisEnginePtr &, const Primitive
                             const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto types = InferType(primitive, input_args);
   auto shapes = InferShape(primitive, input_args);
-  auto input_dtype = std::make_shared<abstract::AbstractTensor>(types, shapes[0]);
-  auto mask_dtype = std::make_shared<abstract::AbstractTensor>(kUInt8, shapes[1]);
-  AbstractBasePtrList outputs = {input_dtype, mask_dtype};
-  return std::make_shared<abstract::AbstractTuple>(outputs);
+  return abstract::MakeAbstract(shapes, types);
 }
 
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLUV2, prim::kPrimReluV2, ReLUV2Infer, nullptr, true);
diff --git a/mindspore/core/ops/scalar_summary.cc b/mindspore/core/ops/scalar_summary.cc
index 43f14f83a9d..2b061dd0a15 100644
--- a/mindspore/core/ops/scalar_summary.cc
+++ b/mindspore/core/ops/scalar_summary.cc
@@ -48,7 +48,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
   MS_EXCEPTION_IF_NULL(primitive);
   // check
   CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(kInt32, ScalarSummaryInferShape(primitive, input_args));
+  return abstract::MakeAbstract(ScalarSummaryInferShape(primitive, input_args), kInt32);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/softmax.cc b/mindspore/core/ops/softmax.cc
index b4e71a3c9e2..746fa3007ac 100644
--- a/mindspore/core/ops/softmax.cc
+++ b/mindspore/core/ops/softmax.cc
@@ -79,8 +79,7 @@ TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
 AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(SoftMaxInferType(primitive, input_args),
-                                                    SoftMaxInferShape(primitive, input_args));
+  return abstract::MakeAbstract(SoftMaxInferShape(primitive, input_args), SoftMaxInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/softplus.cc b/mindspore/core/ops/softplus.cc
index f76964b7290..06350adbd8b 100644
--- a/mindspore/core/ops/softplus.cc
+++ b/mindspore/core/ops/softplus.cc
@@ -49,9 +49,7 @@ TypePtr SoftplusInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
 AbstractBasePtr SoftplusInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto abs = std::make_shared<abstract::AbstractTensor>(SoftplusInferType(primitive, input_args),
-                                                        SoftplusInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(SoftplusInferShape(primitive, input_args), SoftplusInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Softplus, prim::kPrimSoftplus, SoftplusInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/sub.cc b/mindspore/core/ops/sub.cc
index 22c0df31b10..ae9e4286247 100644
--- a/mindspore/core/ops/sub.cc
+++ b/mindspore/core/ops/sub.cc
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 AbstractBasePtr SubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameSub, Sub);
 } // namespace ops
diff --git a/mindspore/core/ops/tensor_summary.cc b/mindspore/core/ops/tensor_summary.cc
index e7b0469cb49..802b7966515 100644
--- a/mindspore/core/ops/tensor_summary.cc
+++ b/mindspore/core/ops/tensor_summary.cc
@@ -48,7 +48,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
   MS_EXCEPTION_IF_NULL(primitive);
   // check
   CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(kInt32, TensorSummaryInferShape(primitive, input_args));
+  return abstract::MakeAbstract(TensorSummaryInferShape(primitive, input_args), kInt32);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
 } // namespace ops
diff --git a/mindspore/core/ops/tile.cc b/mindspore/core/ops/tile.cc
index 3f84ae6317c..95d9040cdcd 100644
--- a/mindspore/core/ops/tile.cc
+++ b/mindspore/core/ops/tile.cc
@@ -91,8 +91,7 @@ TypePtr TileInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
 AbstractBasePtr TileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(TileInferType(primitive, input_args),
-                                                    TileInferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(TileInferShape(primitive, input_args), TileInferType(primitive, input_args));
 }
 
 REGISTER_PRIMITIVE_C(kNameTile, Tile);
diff --git a/mindspore/core/ops/zeros.cc b/mindspore/core/ops/zeros.cc
index bd8f11f4cdd..ad46d38b5aa 100644
--- a/mindspore/core/ops/zeros.cc
+++ b/mindspore/core/ops/zeros.cc
@@ -52,9 +52,7 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
 AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args),
-                                                        ZerosInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
 }
 
 ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
diff --git a/mindspore/core/ops/zeros_like.cc b/mindspore/core/ops/zeros_like.cc
index ca7d655d342..df237bca126 100644
--- a/mindspore/core/ops/zeros_like.cc
+++ b/mindspore/core/ops/zeros_like.cc
@@ -38,15 +38,15 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto infer_type = input_args[0]->BuildType();
   auto valid_type = common_valid_types;
   valid_type.insert(kBool);
-  return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, op_name);
+  return infer_type;
 }
 } // namespace
 AbstractBasePtr ZerosLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ZerosLike, prim::kPrimZerosLike, ZerosLikeInfer, nullptr, true);
 } // namespace ops
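
Usage sketch (illustrative, not part of the patch): after this change, an op's infer
function pairs its shape-infer result with its type-infer result through
abstract::MakeAbstract instead of constructing AbstractTensor/AbstractTuple by hand;
multi-output ops pair an abstract::TupleShape with a Tuple type, as CTCLoss and ReLUV2
do above. The ExampleSquare* names below are hypothetical placeholders following the
pattern of the files in mindspore/core/ops, not functions added by this patch:

    // Hypothetical elementwise op: output shares shape and dtype with input 0.
    abstract::ShapePtr ExampleSquareInferShape(const PrimitivePtr &primitive,
                                               const std::vector<AbstractBasePtr> &input_args) {
      auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
      return std::make_shared<abstract::Shape>(shape_map[kShape]);
    }

    TypePtr ExampleSquareInferType(const PrimitivePtr &primitive,
                                   const std::vector<AbstractBasePtr> &input_args) {
      return input_args[0]->BuildType();
    }

    AbstractBasePtr ExampleSquareInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const std::vector<AbstractBasePtr> &input_args) {
      MS_EXCEPTION_IF_NULL(primitive);
      // Shape first, then type: MakeAbstract(BaseShapePtr, TypePtr) picks the right
      // AbstractBase subclass (scalar, tensor, tuple, list, none, or monad).
      return abstract::MakeAbstract(ExampleSquareInferShape(primitive, input_args),
                                    ExampleSquareInferType(primitive, input_args));
    }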