forked from mindspore-Ecosystem/mindspore
dynamic op re primitive when infer
This commit is contained in:
parent
1d1f6841b9
commit
39cc9e70cd
|
@ -43,8 +43,6 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
|
|||
todos.push_back(node);
|
||||
}
|
||||
|
||||
std::set<string> DynamicShapeConstInputToAttr = {
|
||||
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
|
||||
for (auto &t : todos) {
|
||||
CNodePtr cnode = t->cast<CNodePtr>();
|
||||
ConstInputToAttrInfoRegister reg;
|
||||
|
|
|
@ -1636,6 +1636,13 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
|
|||
args_spec_list.emplace_back(real_input->abstract());
|
||||
}
|
||||
}
|
||||
auto prim_name = primitive->name();
|
||||
if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) {
|
||||
auto attrs = primitive->attrs();
|
||||
auto new_prim_name = "Dynamic" + prim_name;
|
||||
primitive = std::make_shared<Primitive>(new_prim_name);
|
||||
primitive->SetAttrs(attrs);
|
||||
}
|
||||
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
|
||||
auto ret = prim_eval_implement_map.find(primitive);
|
||||
if (ret == prim_eval_implement_map.end()) {
|
||||
|
|
|
@ -774,6 +774,12 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
|
|||
MS_EXCEPTION_IF_NULL(py_shape);
|
||||
auto py_shape_info = py_shape->ToString();
|
||||
if (py_shape_info.find("-1") != string::npos) {
|
||||
if (DynamicShapeConstInputToAttr.find(op_name) != DynamicShapeConstInputToAttr.end()) {
|
||||
auto new_prim_name = "Dynamic" + op_name;
|
||||
auto attrs = prim->attrs();
|
||||
prim = std::make_shared<PrimitivePy>(new_prim_name, py::object());
|
||||
prim->SetAttrs(attrs);
|
||||
}
|
||||
auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(c_abstract);
|
||||
auto c_shape = c_abstract->BuildShape();
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "common/trans.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "utils/utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -123,6 +124,13 @@ void DynamicKernel::InferShape() {
|
|||
args_spec_list.emplace_back(real_input->abstract());
|
||||
}
|
||||
}
|
||||
auto prim_name = primitive->name();
|
||||
if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) {
|
||||
auto new_prim_name = "Dynamic" + prim_name;
|
||||
auto attrs = primitive->attrs();
|
||||
primitive = std::make_shared<Primitive>(new_prim_name);
|
||||
primitive->SetAttrs(attrs);
|
||||
}
|
||||
|
||||
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
|
||||
cnode_ptr_->set_abstract(eval_result);
|
||||
|
|
|
@ -490,6 +490,9 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalH
|
|||
|
||||
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
|
||||
|
||||
const std::set<std::string> DynamicShapeConstInputToAttr = {
|
||||
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
|
||||
|
||||
static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
|
||||
try {
|
||||
if (chmod(file_name.c_str(), mode) != 0) {
|
||||
|
|
|
@ -249,30 +249,30 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -656,9 +656,9 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv
|
|||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name().substr(kDynamic);
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto params_shp = params->shape();
|
||||
|
@ -752,9 +752,9 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return std::make_shared<AbstractTensor>(input_x->element(), output_shape);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string &op_name = primitive->name();
|
||||
AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string &op_name = primitive->name().substr(kDynamic);
|
||||
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto input_shp = input->shape()->shape();
|
||||
ValuePtr perm = primitive->GetAttr("perm");
|
||||
|
@ -779,9 +779,9 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name().substr(kDynamic);
|
||||
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(x->shape());
|
||||
|
|
|
@ -121,9 +121,9 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name().substr(kDynamic);
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
|
|
|
@ -479,9 +479,9 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP
|
|||
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name().substr(kDynamic);
|
||||
// GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize.
|
||||
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(input_x);
|
||||
|
@ -491,9 +491,9 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
|||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name().substr(kDynamic);
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
|
|
|
@ -44,7 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
|
||||
{prim::kPrimSub, {InferImplSub, true}},
|
||||
{prim::kPrimEqual, {InferImplEqual, true}},
|
||||
{prim::kPrimReduceSum, {InferImplReduceSum, true}},
|
||||
{prim::kPrimDynamicReduceSum, {InferImplDynamicReduceSum, true}},
|
||||
{prim::kPrimMinimum, {InferImplMinimum, true}},
|
||||
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
|
||||
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
|
||||
|
@ -59,7 +59,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
|
||||
{prim::kPrimDynamicEmbeddingLookup, {InferImplDynamicEmbeddingLookup, true}},
|
||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
|
||||
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
|
||||
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
|
||||
|
@ -76,8 +76,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
|
||||
{prim::kPrimShape, {InferImplShape, false}},
|
||||
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
|
||||
{prim::kPrimTranspose, {InferImplTranspose, true}},
|
||||
{prim::kPrimReshape, {InferImplReshape, true}},
|
||||
{prim::kPrimDynamicTranspose, {InferImplDynamicTranspose, true}},
|
||||
{prim::kPrimDynamicReshape, {InferImplDynamicReshape, true}},
|
||||
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
|
||||
{prim::kPrimSplit, {InferImplSplit, true}},
|
||||
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
|
||||
|
@ -155,8 +155,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
|
||||
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
|
||||
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
|
||||
{prim::kPrimCast, {InferImplCast, true}},
|
||||
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
|
||||
{prim::kPrimDynamicCast, {InferImplDynamicCast, true}},
|
||||
{prim::kPrimDynamicExpandDims, {InferImplDynamicExpandDims, true}},
|
||||
};
|
||||
return prim_eval_implement_map;
|
||||
}
|
||||
|
|
|
@ -29,6 +29,8 @@
|
|||
#include "utils/shape_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
// length of string "dynamic"
|
||||
const int kDynamic = 7;
|
||||
namespace abstract {
|
||||
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2);
|
||||
TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2);
|
||||
|
|
|
@ -80,6 +80,12 @@ inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("bro
|
|||
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
|
||||
inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
|
||||
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
|
||||
inline const PrimitivePtr kPrimDynamicCast = std::make_shared<Primitive>("DynamicCast");
|
||||
inline const PrimitivePtr kPrimDynamicReshape = std::make_shared<Primitive>("DynamicReshape");
|
||||
inline const PrimitivePtr kPrimDynamicReduceSum = std::make_shared<Primitive>("DynamicReduceSum");
|
||||
inline const PrimitivePtr kPrimDynamicTranspose = std::make_shared<Primitive>("DynamicTranspose");
|
||||
inline const PrimitivePtr kPrimDynamicExpandDims = std::make_shared<Primitive>("DynamicExpandDims");
|
||||
inline const PrimitivePtr kPrimDynamicEmbeddingLookup = std::make_shared<Primitive>("DynamicEmbeddingLookup");
|
||||
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
|
||||
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
||||
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");
|
||||
|
|
Loading…
Reference in New Issue