dynamic op re primitive when infer

This commit is contained in:
liubuyu 2021-01-07 10:49:07 +08:00
parent 1d1f6841b9
commit 39cc9e70cd
12 changed files with 68 additions and 38 deletions

View File

@ -43,8 +43,6 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
todos.push_back(node);
}
std::set<string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
for (auto &t : todos) {
CNodePtr cnode = t->cast<CNodePtr>();
ConstInputToAttrInfoRegister reg;

View File

@ -1636,6 +1636,13 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) {
args_spec_list.emplace_back(real_input->abstract());
}
}
auto prim_name = primitive->name();
if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) {
auto attrs = primitive->attrs();
auto new_prim_name = "Dynamic" + prim_name;
primitive = std::make_shared<Primitive>(new_prim_name);
primitive->SetAttrs(attrs);
}
auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
auto ret = prim_eval_implement_map.find(primitive);
if (ret == prim_eval_implement_map.end()) {

View File

@ -774,6 +774,12 @@ void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL(py_shape);
auto py_shape_info = py_shape->ToString();
if (py_shape_info.find("-1") != string::npos) {
if (DynamicShapeConstInputToAttr.find(op_name) != DynamicShapeConstInputToAttr.end()) {
auto new_prim_name = "Dynamic" + op_name;
auto attrs = prim->attrs();
prim = std::make_shared<PrimitivePy>(new_prim_name, py::object());
prim->SetAttrs(attrs);
}
auto c_abstract = abstract::CppInferShape(prim, args_spec_list);
MS_EXCEPTION_IF_NULL(c_abstract);
auto c_shape = c_abstract->BuildShape();

View File

@ -23,6 +23,7 @@
#include "common/trans.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "abstract/dshape.h"
#include "utils/utils.h"
#include "abstract/param_validator.h"
namespace mindspore {
@ -123,6 +124,13 @@ void DynamicKernel::InferShape() {
args_spec_list.emplace_back(real_input->abstract());
}
}
auto prim_name = primitive->name();
if (DynamicShapeConstInputToAttr.find(prim_name) != DynamicShapeConstInputToAttr.end()) {
auto new_prim_name = "Dynamic" + prim_name;
auto attrs = primitive->attrs();
primitive = std::make_shared<Primitive>(new_prim_name);
primitive->SetAttrs(attrs);
}
auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
cnode_ptr_->set_abstract(eval_result);

View File

@ -490,6 +490,9 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalH
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
const std::set<std::string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
try {
if (chmod(file_name.c_str(), mode) != 0) {

View File

@ -249,21 +249,21 @@ AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const Primitiv
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSub(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMinimum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
@ -271,7 +271,7 @@ AbstractBasePtr InferImplDivNoNan(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLinSpace(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGpuConvertToDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);

View File

@ -656,9 +656,9 @@ AbstractBasePtr InferImplDynamicAssign(const AnalysisEnginePtr &, const Primitiv
}
}
AbstractBasePtr InferImplEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicEmbeddingLookup(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
const std::string op_name = primitive->name().substr(kDynamic);
CheckArgsSize(op_name, args_spec_list, 2);
auto params = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto params_shp = params->shape();
@ -752,9 +752,9 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<AbstractTensor>(input_x->element(), output_shape);
}
AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicTranspose(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
const std::string &op_name = primitive->name().substr(kDynamic);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto input_shp = input->shape()->shape();
ValuePtr perm = primitive->GetAttr("perm");
@ -779,9 +779,9 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
}
AbstractBasePtr InferImplReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicReshape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
const std::string op_name = primitive->name().substr(kDynamic);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);
MS_EXCEPTION_IF_NULL(x->shape());

View File

@ -121,9 +121,9 @@ AbstractBasePtr InferImplEqual(const AnalysisEnginePtr &, const PrimitivePtr &pr
return ret;
}
AbstractBasePtr InferImplReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
const std::string op_name = primitive->name().substr(kDynamic);
CheckArgsSize(op_name, args_spec_list, 1);
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);

View File

@ -479,9 +479,9 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(x->shape()->shape()));
}
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
const std::string op_name = primitive->name().substr(kDynamic);
// GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize.
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
@ -491,9 +491,9 @@ AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &pri
return ret;
}
AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplDynamicExpandDims(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
const std::string op_name = primitive->name().substr(kDynamic);
CheckArgsSize(op_name, args_spec_list, 1);
auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(x);

View File

@ -44,7 +44,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
{prim::kPrimSub, {InferImplSub, true}},
{prim::kPrimEqual, {InferImplEqual, true}},
{prim::kPrimReduceSum, {InferImplReduceSum, true}},
{prim::kPrimDynamicReduceSum, {InferImplDynamicReduceSum, true}},
{prim::kPrimMinimum, {InferImplMinimum, true}},
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
@ -59,7 +59,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGatherV2, {InferImplGatherV2, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimDynamicEmbeddingLookup, {InferImplDynamicEmbeddingLookup, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
@ -76,8 +76,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
{prim::kPrimShape, {InferImplShape, false}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
{prim::kPrimTranspose, {InferImplTranspose, true}},
{prim::kPrimReshape, {InferImplReshape, true}},
{prim::kPrimDynamicTranspose, {InferImplDynamicTranspose, true}},
{prim::kPrimDynamicReshape, {InferImplDynamicReshape, true}},
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
{prim::kPrimSplit, {InferImplSplit, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
@ -155,8 +155,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
{prim::kPrimCast, {InferImplCast, true}},
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
{prim::kPrimDynamicCast, {InferImplDynamicCast, true}},
{prim::kPrimDynamicExpandDims, {InferImplDynamicExpandDims, true}},
};
return prim_eval_implement_map;
}

View File

@ -29,6 +29,8 @@
#include "utils/shape_utils.h"
namespace mindspore {
// length of string "dynamic"
const int kDynamic = 7;
namespace abstract {
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2);
TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2);

View File

@ -80,6 +80,12 @@ inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("bro
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
inline const PrimitivePtr kPrimDynamicCast = std::make_shared<Primitive>("DynamicCast");
inline const PrimitivePtr kPrimDynamicReshape = std::make_shared<Primitive>("DynamicReshape");
inline const PrimitivePtr kPrimDynamicReduceSum = std::make_shared<Primitive>("DynamicReduceSum");
inline const PrimitivePtr kPrimDynamicTranspose = std::make_shared<Primitive>("DynamicTranspose");
inline const PrimitivePtr kPrimDynamicExpandDims = std::make_shared<Primitive>("DynamicExpandDims");
inline const PrimitivePtr kPrimDynamicEmbeddingLookup = std::make_shared<Primitive>("DynamicEmbeddingLookup");
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
inline const PrimitivePtr kPrimTranspose = std::make_shared<Primitive>("Transpose");