From bdc21bcf3eef4e2b77efcace8421705a1558d84b Mon Sep 17 00:00:00 2001 From: LianLiguang Date: Wed, 7 Apr 2021 11:42:30 +0800 Subject: [PATCH] change infer value --- .../ccsrc/backend/optimizer/common/helper.cc | 6 +- .../operator/ops_front_infer_function.cc | 42 ++++++------- .../operator/ops_front_infer_function.h | 4 +- .../pipeline/jit/static_analysis/prim.cc | 59 ++++++++++++------- .../ccsrc/pipeline/jit/static_analysis/prim.h | 7 ++- .../jit/static_analysis/static_analysis.cc | 4 +- .../core/abstract/primitive_infer_map.cc | 6 +- mindspore/core/abstract/primitive_infer_map.h | 30 ++++++---- mindspore/core/ops/gather_d.cc | 1 - mindspore/core/ops/primitive_c.cc | 2 +- mindspore/core/ops/scalar_summary.cc | 1 - mindspore/core/ops/softmax.cc | 1 - mindspore/core/ops/tensor_summary.cc | 1 - mindspore/core/ops/zeros.cc | 2 - mindspore/core/utils/tensor_construct_utils.h | 6 +- tests/ut/cpp/abstract/abstract_test.cc | 6 +- .../static_analysis/static_analysis_test.cc | 11 ++-- .../restore_abs_input_in_backed_infer_test.cc | 8 +++ 18 files changed, 113 insertions(+), 84 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index e734230a5ee..79d63fd2cb3 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -923,14 +923,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis auto ret = prim_eval_implement_map.find(prim); if (ret != prim_eval_implement_map.end()) { // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr + MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_); auto infer_spec_list = RectifyAbstract(prim, args_spec_list); - return ret->second.impl_(nullptr, prim, infer_spec_list); + return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list); } else { // if the infer function has been not founded in the front infer map find it in the backend infer map instead auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap(); auto ret_backend = prim_backend_eval_impl_map.find(prim); if (ret_backend != prim_backend_eval_impl_map.end()) { - return ret_backend->second.impl_(nullptr, prim, args_spec_list); + MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_); + return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list); } } MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc index a463fd01958..ebaa674ecbf 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc @@ -639,26 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt return std::make_shared(cls->tag(), abs_attributes, cls->methods()); } -REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false); -REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs, - nullptr, false); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr); +REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, + InferImplBroadcastGradientArgs, nullptr); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h index 33ef55ef490..0c488b927ac 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.h @@ -59,7 +59,9 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +#define REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl) \ + static auto helper_##name = \ + abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, false); } // namespace abstract } // namespace mindspore - #endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 3debf5309c0..f8b7699204a 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -530,28 +530,17 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic } } // end anonymous namespace -EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { +EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base, + const AbstractBasePtrList &args) { auto prim_py = dyn_cast(prim_); if (prim_py == nullptr) { MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; } - - // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' + // Call checking method 'infer_value' for python primitive MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); auto py_args = PreparePyInputs(prim_py, args); - prim_py->RunCheck(py_args); - - prim_->BeginRecordAddAttr(); - AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); - prim_->EndRecordAddAttr(); - auto added_attrs = prim_->evaluate_added_attrs(); - - if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) { - return std::make_shared(abs_base, std::make_shared(added_attrs)); - } - - // Call method 'infer_value' for primitive with this method for constant propagation py::tuple py_vals(py_args.size()); + auto added_attrs = prim_->evaluate_added_attrs(); for (size_t i = 0; i < py_args.size(); ++i) { py_vals[i] = py_args[i][ATTR_VALUE]; } @@ -559,7 +548,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en if (py::isinstance(py_ret)) { return std::make_shared(abs_base, std::make_shared(added_attrs)); } - // Convert pyobject to Value, then to AbstractValue ValuePtr converted_ret = nullptr; TypePtr dtype = abs_base->BuildType(); @@ -577,6 +565,28 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en return std::make_shared(res_spec, std::make_shared(added_attrs)); } +EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { + auto prim_py = dyn_cast(prim_); + if (prim_py == nullptr) { + MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."; + } + // Call checking method '__check__' for subclass of 'PrimitiveWithCheck' + MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString(); + auto py_args = PreparePyInputs(prim_py, args); + prim_py->RunCheck(py_args); + + prim_->BeginRecordAddAttr(); + AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); + prim_->EndRecordAddAttr(); + auto added_attrs = prim_->evaluate_added_attrs(); + + if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) { + return std::make_shared(abs_base, std::make_shared(added_attrs)); + } + // Call method 'infer_value' for primitive with this method for constant propagation + return RunPyInferValue(engine, abs_base, args); +} + EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) { auto ret_abstract = AbstractEval(args); @@ -589,11 +599,19 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) { return EvalPyCheckPrim(engine, args); } - + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode); prim_->BeginRecordAddAttr(); - AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); + AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args); prim_->EndRecordAddAttr(); auto added_attrs = prim_->evaluate_added_attrs(); + if (need_infer_value) { + if (eval_impl_.infer_value_func_ != nullptr) { + auto value = eval_impl_.infer_value_func_(prim_, args, abs_base); + abs_base->set_value(value); + } + } auto eval_result = std::make_shared(abs_base, std::make_shared(added_attrs)); return eval_result; } @@ -617,7 +635,6 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs auto added_attrs = prim_py_->evaluate_added_attrs(); MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); auto res_spec = PyInferRes2Abstract(prim_py_, output); - MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; auto infer_result = std::make_shared(res_spec, std::make_shared(added_attrs)); (*evaluator_cache_map_)[args] = infer_result; @@ -689,7 +706,7 @@ ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { // Primitive implementation // static function start namespace { -EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) { +EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) { EvaluatorPtr prim_evaluator = std::make_shared(primitive, eval_impl); return prim_evaluator; } @@ -1279,7 +1296,7 @@ void InitPrimEvaluatorConstructors() { PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; for (const auto &iter : GetPrimitiveToEvalImplMap()) { - constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_); + constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second); } for (const auto &iter : GetUniformPrimitiveToImplMap()) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h index 868b002f754..4dfcfd265ac 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -32,7 +32,7 @@ namespace mindspore { namespace abstract { class StandardPrimEvaluator : public TrivialPrimEvaluator { public: - StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl) + StandardPrimEvaluator(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &eval_impl) : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} ~StandardPrimEvaluator() override = default; MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); @@ -43,9 +43,10 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator { private: EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args); - + EvalResultPtr RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base, + const AbstractBasePtrList &args); PrimitivePtr prim_; - const StandardPrimitiveEvalImpl eval_impl_; + const StandardPrimitiveImplReg eval_impl_; }; using StandardPrimEvaluatorPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 3c29cf17013..9f05e8e4669 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -358,8 +358,8 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr } // find prim infer function in the prim function map return a standard evaluator - StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); - if (eval_impl != nullptr) { + auto eval_impl = GetPrimitiveInferImpl(prim); + if (eval_impl.infer_shape_dtype_impl_ != nullptr) { return std::make_shared(prim, eval_impl); } diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 6ed0136aabc..cb9fd5fa5eb 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -213,13 +213,13 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { return prim_backend_eval_implement_map; } -StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { +StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive) { MS_EXCEPTION_IF_NULL(primitive); auto iter = GetPrimitiveToEvalImplMap().find(primitive); if (iter == GetPrimitiveToEvalImplMap().end()) { - return nullptr; + return {nullptr, nullptr, false}; } - return iter->second.impl_; + return iter->second; } void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) { diff --git a/mindspore/core/abstract/primitive_infer_map.h b/mindspore/core/abstract/primitive_infer_map.h index c4b6d984c9c..e4b0710dbbc 100644 --- a/mindspore/core/abstract/primitive_infer_map.h +++ b/mindspore/core/abstract/primitive_infer_map.h @@ -19,21 +19,24 @@ #define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ #include #include +#include #include "ir/primitive.h" +#include "ops/primitive_c.h" #include "base/core_ops.h" #include "abstract/abstract_value.h" #include "ir/anf.h" namespace mindspore { namespace abstract { -using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &); +using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &); using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &); struct StandardPrimitiveImplReg { - StandardPrimitiveEvalImpl impl_; // Implement function of Primitive - InferValueEvalImpl infer_value_func_; // infer value of primitive - // true means this primitive can be executed by vm backend else will be constant folded by frontend + InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive + InferValueEvalImpl infer_value_func_; // infer value of primitive + // in_white_list_ is true means this primitive can be executed by vm backend + // else will be optimized by frontend bool in_white_list_; }; @@ -44,7 +47,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap(); PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap(); -StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); +StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive); std::vector GetDependsFormMap(const CNodePtr &cnode); @@ -52,17 +55,22 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard class RegisterStandardPrimitiveEvalHelper { public: - RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl, + RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl, const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) { - const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list}; + const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list}; RegisterStandardPrimitiveImpl(primitive, impl_reg); } ~RegisterStandardPrimitiveEvalHelper() = default; }; -#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \ - static auto helper_##name = \ - abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list) +#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_wight_list) \ + static auto helper_##name = \ + abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_wight_list); \ + std::shared_ptr GetDefaultPrimC##name() { \ + auto out = std::make_shared(); \ + return out; \ + } \ + ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name); } // namespace abstract } // namespace mindspore #endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_ diff --git a/mindspore/core/ops/gather_d.cc b/mindspore/core/ops/gather_d.cc index b195977a9c0..05d95697def 100644 --- a/mindspore/core/ops/gather_d.cc +++ b/mindspore/core/ops/gather_d.cc @@ -70,6 +70,5 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv return abs; } REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false); -REGISTER_PRIMITIVE_C(kNameGatherD, GatherD); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/primitive_c.cc b/mindspore/core/ops/primitive_c.cc index dcb89530eaf..067a0a3bee2 100644 --- a/mindspore/core/ops/primitive_c.cc +++ b/mindspore/core/ops/primitive_c.cc @@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) { if (iter == infer_map.end()) { MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!"; } - auto infer_function = iter->second.impl_; + auto infer_function = iter->second.infer_shape_dtype_impl_; return infer_function(nullptr, shared_from_base(), abstract_list); } diff --git a/mindspore/core/ops/scalar_summary.cc b/mindspore/core/ops/scalar_summary.cc index 03e0c87adcb..641ae573181 100644 --- a/mindspore/core/ops/scalar_summary.cc +++ b/mindspore/core/ops/scalar_summary.cc @@ -51,6 +51,5 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr return std::make_shared(kInt32, ScalarSummaryInferShape(primitive, input_args)); } REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true); -REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/softmax.cc b/mindspore/core/ops/softmax.cc index 0df32d9275c..5026a4712f3 100644 --- a/mindspore/core/ops/softmax.cc +++ b/mindspore/core/ops/softmax.cc @@ -81,6 +81,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv SoftMaxInferShape(primitive, input_args)->shape()); } REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true); -REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/tensor_summary.cc b/mindspore/core/ops/tensor_summary.cc index efe23d4e436..73266a960c5 100644 --- a/mindspore/core/ops/tensor_summary.cc +++ b/mindspore/core/ops/tensor_summary.cc @@ -51,6 +51,5 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr return std::make_shared(kInt32, TensorSummaryInferShape(primitive, input_args)); } REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true); -REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/zeros.cc b/mindspore/core/ops/zeros.cc index e39a0a58a26..398e40996a0 100644 --- a/mindspore/core/ops/zeros.cc +++ b/mindspore/core/ops/zeros.cc @@ -66,10 +66,8 @@ AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP MS_EXCEPTION_IF_NULL(primitive); auto abs = std::make_shared(ZerosInferType(primitive, input_args), ZerosInferShape(primitive, input_args)); - abs->set_value(ZerosInferValue(primitive, input_args, abs)); return abs; } REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false); -REGISTER_PRIMITIVE_C(kNameZeros, Zeros); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/utils/tensor_construct_utils.h b/mindspore/core/utils/tensor_construct_utils.h index ec08c122e58..fc8b3bebed9 100644 --- a/mindspore/core/utils/tensor_construct_utils.h +++ b/mindspore/core/utils/tensor_construct_utils.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ #define MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ #include +#include #include "ir/tensor.h" namespace mindspore { template @@ -23,10 +24,7 @@ void SetTensorData(void *data, T num, size_t data_length) { MS_EXCEPTION_IF_NULL(data); auto tensor_data = reinterpret_cast(data); MS_EXCEPTION_IF_NULL(tensor_data); - for (size_t index = 0; index < data_length; ++index) { - *tensor_data = num; - ++tensor_data; - } + std::fill(tensor_data, tensor_data + data_length, num); } class TensorConstructUtils { public: diff --git a/tests/ut/cpp/abstract/abstract_test.cc b/tests/ut/cpp/abstract/abstract_test.cc index 1c244246295..380c94b250a 100644 --- a/tests/ut/cpp/abstract/abstract_test.cc +++ b/tests/ut/cpp/abstract/abstract_test.cc @@ -86,10 +86,10 @@ TEST_F(TestAbstract, TestParseDataClass) { AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; - StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); - ASSERT_TRUE(nullptr != eval_impl); + auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); + ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_); - AbstractBasePtr new_cls = eval_impl(nullptr, prim::kPrimMakeRecord, args_list); + AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list); ASSERT_TRUE(nullptr != new_cls); } diff --git a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc index 8bfde95be30..50f01376d7e 100644 --- a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc @@ -47,8 +47,8 @@ AbstractBasePtr InferImplScalarAddStub(const AnalysisEnginePtr &engine, const Pr } EvaluatorPtr InitPrimitiveScalarAddEvaluatorStub() { - EvaluatorPtr PrimitiveScalarAddEvaluator = - std::make_shared(prim::kPrimScalarAdd, InferImplScalarAddStub); + EvaluatorPtr PrimitiveScalarAddEvaluator = std::make_shared( + prim::kPrimScalarAdd, StandardPrimitiveImplReg{InferImplScalarAddStub, nullptr, true}); return PrimitiveScalarAddEvaluator; } @@ -63,8 +63,8 @@ AbstractBasePtr InferImplReturnStub(const AnalysisEnginePtr &engine, const Primi } EvaluatorPtr InitPrimitiveReturnEvaluatorStub() { - EvaluatorPtr PrimitiveReturnEvaluator = - std::make_shared(prim::kPrimReturn, InferImplReturnStub); + EvaluatorPtr PrimitiveReturnEvaluator = std::make_shared( + prim::kPrimReturn, StandardPrimitiveImplReg{InferImplReturnStub, nullptr, true}); return PrimitiveReturnEvaluator; } @@ -396,7 +396,6 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) { ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt64); } - class TestEvalOnePrim : public UT::Common { public: TestEvalOnePrim() : getPyFun("gtest_input.pipeline.infer.infer_test", true), engine_(nullptr) {} @@ -435,7 +434,7 @@ class TestGraphEval : public UT::Common { UT::PyFuncGraphFetcher getPyFun; }; -void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); } +void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); } void TestGraphEval::TearDown() { // destroy resource diff --git a/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc b/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc index ff9f191c6a8..4f829a368ba 100644 --- a/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc +++ b/tests/ut/cpp/pre_activate/common/restore_abs_input_in_backed_infer_test.cc @@ -25,6 +25,14 @@ #include "common/common_test.h" namespace mindspore { namespace opt { +class TestAttr : public ops::PrimitiveC { + public: + TestAttr() : PrimitiveC("") {} +}; +class TestDynamicInput : public ops::PrimitiveC { + public: + TestDynamicInput() : PrimitiveC("") {} +}; constexpr auto kAttrConvertTestName = "attr_convert_test"; constexpr auto kDynamicInputTestName = "dynamic_input_test"; inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared(kAttrConvertTestName);