From 39cc9e70cd1d73047fe839862c7717afa17aeb6c Mon Sep 17 00:00:00 2001 From: liubuyu Date: Thu, 7 Jan 2021 10:49:07 +0800 Subject: [PATCH] dynamic op re primitive when infer --- .../pass/convert_const_input_to_attr.cc | 2 -- .../backend/session/anf_runtime_algorithm.cc | 7 ++++++ .../pipeline/pynative/pynative_execute.cc | 6 +++++ .../runtime/device/executor/dynamic_kernel.cc | 8 +++++++ mindspore/ccsrc/utils/utils.h | 3 +++ mindspore/core/abstract/infer_functions.h | 24 +++++++++---------- mindspore/core/abstract/prim_arrays.cc | 18 +++++++------- mindspore/core/abstract/prim_maths.cc | 6 ++--- mindspore/core/abstract/prim_others.cc | 12 +++++----- .../core/abstract/primitive_infer_map.cc | 12 +++++----- mindspore/core/abstract/utils.h | 2 ++ mindspore/core/base/core_ops.h | 6 +++++ 12 files changed, 68 insertions(+), 38 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index 04d2ea723ba..dfcd8523c37 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -43,8 +43,6 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An todos.push_back(node); } - std::set DynamicShapeConstInputToAttr = { - kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName}; for (auto &t : todos) { CNodePtr cnode = t->cast(); ConstInputToAttrInfoRegister reg; diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 412c10ff40d..f6928b3d1e8 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1636,6 +1636,13 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) { args_spec_list.emplace_back(real_input->abstract()); } } + auto prim_name = primitive->name(); + if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) { + auto attrs = primitive->attrs(); + auto new_prim_name = "Dynamic" + prim_name; + primitive = std::make_shared(new_prim_name); + primitive->SetAttrs(attrs); + } auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); auto ret = prim_eval_implement_map.find(primitive); if (ret == prim_eval_implement_map.end()) { diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 912300337e7..364827622dc 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -774,6 +774,12 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, MS_EXCEPTION_IF_NULL(py_shape); auto py_shape_info = py_shape->ToString(); if (py_shape_info.find("-1") != string::npos) { + if (DynamicShapeConstInputToAttr.find(op_name) != DynamicShapeConstInputToAttr.end()) { + auto new_prim_name = "Dynamic" + op_name; + auto attrs = prim->attrs(); + prim = std::make_shared(new_prim_name, py::object()); + prim->SetAttrs(attrs); + } auto c_abstract = abstract::CppInferShape(prim, args_spec_list); MS_EXCEPTION_IF_NULL(c_abstract); auto c_shape = c_abstract->BuildShape(); diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc index eaa442f2f58..5cfa7673501 100644 --- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc @@ -23,6 +23,7 @@ #include "common/trans.h" #include "pipeline/jit/static_analysis/static_analysis.h" #include "abstract/dshape.h" +#include "utils/utils.h" #include "abstract/param_validator.h" namespace mindspore { @@ -123,6 +124,13 @@ void DynamicKernel::InferShape() { args_spec_list.emplace_back(real_input->abstract()); } } + auto prim_name = primitive->name(); + if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) { + auto new_prim_name = "Dynamic" + prim_name; + auto attrs = primitive->attrs(); + primitive = std::make_shared(new_prim_name); + primitive->SetAttrs(attrs); + } auto eval_result = abstract::CppInferShape(primitive, args_spec_list); cnode_ptr_->set_abstract(eval_result); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 5669722fc14..34e3b4133ac 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -490,6 +490,9 @@ const std::set kComputeDepend = {kUniqueOpName, kComputeAccidentalH const std::set k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; +const std::set DynamicShapeConstInputToAttr = { + kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName}; + static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { try { if (chmod(file_name.c_str(), mode) != 0) { diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 92dc45a53b8..b833bc28da7 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -249,30 +249,30 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index a243850099b..4d39bc7875a 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -656,9 +656,9 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv } } -AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); +AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name().substr(kDynamic); CheckArgsSize(op_name, args_spec_list, 2); auto params = CheckArg(op_name, args_spec_list, 0); auto params_shp = params->shape(); @@ -752,9 +752,9 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr return std::make_shared(input_x->element(), output_shape); } -AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - const std::string &op_name = primitive->name(); +AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string &op_name = primitive->name().substr(kDynamic); AbstractTensorPtr input = CheckArg(op_name, args_spec_list, 0); auto input_shp = input->shape()->shape(); ValuePtr perm = primitive->GetAttr("perm"); @@ -779,9 +779,9 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr return std::make_shared(input->element(), std::make_shared(result_shp, min_shp, max_shp)); } -AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); +AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name().substr(kDynamic); auto x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(x->shape()); diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index 1ee95a2ea3c..fe10bc302b5 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -121,9 +121,9 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr return ret; } -AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); +AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name().substr(kDynamic); CheckArgsSize(op_name, args_spec_list, 1); auto input_x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(input_x); diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index a2010058891..e1286e558ad 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -479,9 +479,9 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP return std::make_shared(x->element(), std::make_shared(x->shape()->shape())); } -AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); +AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name().substr(kDynamic); // GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize. auto input_x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(input_x); @@ -491,9 +491,9 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri return ret; } -AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); +AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name().substr(kDynamic); CheckArgsSize(op_name, args_spec_list, 1); auto x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(x); diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 7fcf58fe762..c5219429676 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -44,7 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}}, {prim::kPrimSub, {InferImplSub, true}}, {prim::kPrimEqual, {InferImplEqual, true}}, - {prim::kPrimReduceSum, {InferImplReduceSum, true}}, + {prim::kPrimDynamicReduceSum, {InferImplDynamicReduceSum, true}}, {prim::kPrimMinimum, {InferImplMinimum, true}}, {prim::kPrimDivNoNan, {InferImplDivNoNan, true}}, {prim::kPrimLinSpace, {InferImplLinSpace, true}}, @@ -59,7 +59,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, {prim::kPrimGatherV2, {InferImplGatherV2, true}}, {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, - {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}}, + {prim::kPrimDynamicEmbeddingLookup, {InferImplDynamicEmbeddingLookup, true}}, {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}}, {prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}}, @@ -76,8 +76,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimRealDiv, {InferImplRealDiv, true}}, {prim::kPrimShape, {InferImplShape, false}}, {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, - {prim::kPrimTranspose, {InferImplTranspose, true}}, - {prim::kPrimReshape, {InferImplReshape, true}}, + {prim::kPrimDynamicTranspose, {InferImplDynamicTranspose, true}}, + {prim::kPrimDynamicReshape, {InferImplDynamicReshape, true}}, {prim::kPrimMapUniform, {InferImplMapUniform, true}}, {prim::kPrimSplit, {InferImplSplit, true}}, {prim::kPrimSequenceMask, {InferImplSequenceMask, true}}, @@ -155,8 +155,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimAllSwap, {InferImplAllSwap, true}}, {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, {prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}}, - {prim::kPrimCast, {InferImplCast, true}}, - {prim::kPrimExpandDims, {InferImplExpandDims, true}}, + {prim::kPrimDynamicCast, {InferImplDynamicCast, true}}, + {prim::kPrimDynamicExpandDims, {InferImplDynamicExpandDims, true}}, }; return prim_eval_implement_map; } diff --git a/mindspore/core/abstract/utils.h b/mindspore/core/abstract/utils.h index 8220294f7ef..da07a1f4833 100644 --- a/mindspore/core/abstract/utils.h +++ b/mindspore/core/abstract/utils.h @@ -29,6 +29,8 @@ #include "utils/shape_utils.h" namespace mindspore { +// length of string "dynamic" +const int kDynamic = 7; namespace abstract { ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2); TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2); diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 8ab89abffc5..bd1662f05be 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -80,6 +80,12 @@ inline const PrimitivePtr kPrimBroadcastShape = std::make_shared("bro inline const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); inline const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); inline const PrimitivePtr kPrimCast = std::make_shared("Cast"); +inline const PrimitivePtr kPrimDynamicCast = std::make_shared("DynamicCast"); +inline const PrimitivePtr kPrimDynamicReshape = std::make_shared("DynamicReshape"); +inline const PrimitivePtr kPrimDynamicReduceSum = std::make_shared("DynamicReduceSum"); +inline const PrimitivePtr kPrimDynamicTranspose = std::make_shared("DynamicTranspose"); +inline const PrimitivePtr kPrimDynamicExpandDims = std::make_shared("DynamicExpandDims"); +inline const PrimitivePtr kPrimDynamicEmbeddingLookup = std::make_shared("DynamicEmbeddingLookup"); inline const PrimitivePtr kPrimConcat = std::make_shared("Concat"); inline const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); inline const PrimitivePtr kPrimTranspose = std::make_shared("Transpose");