change macro of reg infer

This commit is contained in:
LianLiguang 2021-04-01 16:53:45 +08:00
parent 609b8edb9f
commit 8cabd090de
12 changed files with 238 additions and 246 deletions

View File

@ -30,45 +30,6 @@ namespace prim {
ValuePtr GetPythonOps(const std::string &op_name,
const std::string &module_name = "mindspore._extends.parse.standard_method",
bool use_signature = false);
// Primitives only used by frontend;
// Type introspection
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
// Other miscellaneous
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
// Structures
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
class UnpackGraphPrimitive : public Primitive {
public:
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)

View File

@ -639,55 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
}
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref, value, [universal]
CheckRequiredArgsSize(primitive->name(), args_spec_list, 2);
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
auto type = args_spec_list[0]->BuildType();
if (type->type_id() == kObjectTypeRefKey) {
return args_spec_list[1]->Broaden();
} else {
return args_spec_list[0];
}
}
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref/Tensor, universal
CheckArgsSize(primitive->name(), args_spec_list, 2);
auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
if (ref_abs != nullptr) {
// Return tensor value if input is Ref.
return ref_abs->CloneAsTensor();
}
return args_spec_list[0]->Broaden();
}
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
InferImplBroadcastGradientArgs);
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign);
REGISTER_PRIMITIVE_EVAL_IMPL(Load, prim::kPrimLoad, InferImplLoad);
REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs,
nullptr, false);
} // namespace abstract
} // namespace mindspore

View File

@ -59,18 +59,6 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
class RegisterFrontendPrimitiveEvalHelper {
public:
RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
const StandardPrimitiveImplReg impl_reg{impl, false};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}
~RegisterFrontendPrimitiveEvalHelper() = default;
};
#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl)
} // namespace abstract
} // namespace mindspore

View File

@ -308,6 +308,10 @@ AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEngin
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.

View File

@ -577,5 +577,31 @@ AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &pr
abstract->set_value(value);
return abstract;
}
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref/Tensor, universal
CheckArgsSize(primitive->name(), args_spec_list, 2);
auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
if (ref_abs != nullptr) {
// Return tensor value if input is Ref.
return ref_abs->CloneAsTensor();
}
return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: Ref, value, [universal]
CheckRequiredArgsSize(primitive->name(), args_spec_list, 2);
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
auto type = args_spec_list[0]->BuildType();
if (type->type_id() == kObjectTypeRefKey) {
return args_spec_list[1]->Broaden();
} else {
return args_spec_list[0];
}
}
} // namespace abstract
} // namespace mindspore

View File

@ -55,158 +55,161 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
static PrimitiveEvalImplMap prim_eval_implement_map = {
// Statements
{prim::kPrimReturn, {InferImplReturn, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
{prim::kPrimIs_, {InferImplIs_, true}},
{prim::kPrimIsNot, {InferImplIsNot, true}},
{prim::kPrimInDict, {InferImplInDict, true}},
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
{prim::kPrimReturn, {InferImplReturn, nullptr, true}},
{prim::kPrimSwitch, {InferImplSwitch, nullptr, true}},
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, nullptr, true}},
{prim::kPrimIs_, {InferImplIs_, nullptr, true}},
{prim::kPrimIsNot, {InferImplIsNot, nullptr, true}},
{prim::kPrimInDict, {InferImplInDict, nullptr, true}},
{prim::kPrimNotInDict, {InferImplNotInDict, nullptr, true}},
{prim::kPrimIsConsant, {InferImplIsConstant, nullptr, true}},
// Maths
{prim::kPrimSquare, {InferImplSquare, true}},
{prim::kPrimMatMul, {InferImplMatMul, true}},
{prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimSqrt, {InferImplSqrt, true}},
{prim::kPrimSquare, {InferImplSquare, nullptr, true}},
{prim::kPrimMatMul, {InferImplMatMul, nullptr, true}},
{prim::kPrimBatchMatMul, {InferImplBatchMatMul, nullptr, true}},
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, nullptr, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, nullptr, true}},
{prim::kPrimSqrt, {InferImplSqrt, nullptr, true}},
// Array
{prim::kPrimRange, {InferImplRange, true}},
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
{prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
{prim::kPrimGather, {InferImplGatherV2, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
{prim::kPrimSubAndFilter, {InferImplSubAndFilter, true}},
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
{prim::kPrimDynamicAssign, {InferImplDynamicAssign, true}},
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}},
{prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
{prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}},
{prim::kPrimPadAndShift, {InferImplPadAndShift, true}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
{prim::kPrimSplit, {InferImplSplit, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
{prim::kPrimRange, {InferImplRange, nullptr, true}},
{prim::kPrimScalarToArray, {InferImplScalarToArray, nullptr, true}},
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, nullptr, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}},
{prim::kPrimUnique, {InferImplUnique, nullptr, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}},
{prim::kPrimGather, {InferImplGatherV2, nullptr, true}},
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}},
{prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}},
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}},
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, nullptr, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, nullptr, true}},
{prim::kPrimSubAndFilter, {InferImplSubAndFilter, nullptr, true}},
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, nullptr, true}},
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, nullptr, true}},
{prim::kPrimDynamicAssign, {InferImplDynamicAssign, nullptr, true}},
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}},
{prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}},
{prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}},
{prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}},
{prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}},
{prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}},
{prim::kPrimSplit, {InferImplSplit, nullptr, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},
{prim::kPrimMakeDict, {InferImplMakeDict, true}},
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
{prim::kPrimListGetItem, {InferImplListGetItem, true}},
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
{prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}},
{prim::kPrimDictGetValues, {InferImplDictGetValues, true}},
{prim::kPrimListAppend, {InferImplListAppend, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
{prim::kPrimListLen, {InferImplListLen, true}},
{prim::kPrimArrayLen, {InferImplArrayLen, true}},
{prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}},
{prim::kPrimMakeList, {InferImplMakeList, nullptr, true}},
{prim::kPrimMakeDict, {InferImplMakeDict, nullptr, true}},
{prim::kPrimMakeSlice, {InferImplMakeSlice, nullptr, true}},
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, nullptr, true}},
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, nullptr, true}},
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, nullptr, true}},
{prim::kPrimListGetItem, {InferImplListGetItem, nullptr, true}},
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, nullptr, true}},
{prim::kPrimListSetItem, {InferImplListSetItem, nullptr, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, nullptr, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}},
{prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}},
{prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}},
{prim::kPrimListAppend, {InferImplListAppend, nullptr, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}},
{prim::kPrimListLen, {InferImplListLen, nullptr, true}},
{prim::kPrimArrayLen, {InferImplArrayLen, nullptr, true}},
// NN
{prim::kPrimPooling, {InferImplPooling, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
{prim::kPrimBatchNorm, {InferImplBatchNorm, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
{prim::kPrimConv2D, {InferImplConv2D, true}},
{prim::kPrimBiasAdd, {InferImplBiasAdd, true}},
{prim::kPrimRelu, {InferImplRelu, true}},
{prim::kPrimRelu6, {InferImplRelu, true}},
{prim::kPrimZerosLike, {InferImplZerosLike, true}},
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
{prim::kPrimDropout, {InferImplDropout, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
{prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
{prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
{prim::kPrimSGD, {InferImplSGD, true}},
{prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, true}},
{prim::kPrimPooling, {InferImplPooling, nullptr, true}},
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, nullptr, true}},
{prim::kPrimBatchNorm, {InferImplBatchNorm, nullptr, true}},
{prim::kPrimReluGrad, {InferImplReluGrad, nullptr, true}},
{prim::kPrimConv2D, {InferImplConv2D, nullptr, true}},
{prim::kPrimBiasAdd, {InferImplBiasAdd, nullptr, true}},
{prim::kPrimRelu, {InferImplRelu, nullptr, true}},
{prim::kPrimRelu6, {InferImplRelu, nullptr, true}},
{prim::kPrimZerosLike, {InferImplZerosLike, nullptr, true}},
{prim::kPrimBpropCut, {InferImplBpropCut, nullptr, true}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, nullptr, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, nullptr, true}},
{prim::kPrimDropout, {InferImplDropout, nullptr, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, nullptr, true}},
{prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, nullptr, true}},
{prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}},
{prim::kPrimSGD, {InferImplSGD, nullptr, true}},
{prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}},
// Others
{prim::kPrimIdentity, {InferImplIdentity, true}},
{prim::kPrimIdentity, {InferImplIdentity, nullptr, true}},
{prim::kPrimLoad, {InferImplLoad, nullptr, true}},
{prim::kPrimAssign, {InferImplAssign, nullptr, true}},
// Set impl to null as it will use PartialEvaluator;
{prim::kPrimPartial, {nullptr, true}},
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
{prim::kPrimDepend, {InferImplDepend, true}},
{prim::kPrimUpdateState, {InferImplUpdateState, true}},
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
{prim::kPrimPartial, {nullptr, nullptr, true}},
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}},
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, nullptr, true}},
{prim::kPrimEnvAdd, {InferImplEnvAdd, nullptr, true}},
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, nullptr, true}},
{prim::kPrimMakeRef, {InferImplMakeRef, nullptr, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, nullptr, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, nullptr, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}},
{prim::kPrimDepend, {InferImplDepend, nullptr, true}},
{prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}},
{prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}},
// Debug
{prim::kPrimDebug, {InferImplDebug, true}},
{prim::kPrimDebug, {InferImplDebug, nullptr, true}},
// Dynamic shape testing
{prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, true}},
{prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, nullptr, true}},
// SparseTensor
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, nullptr, true}},
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, nullptr, true}},
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, nullptr, true}},
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}},
// RowTensor
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}},
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}},
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}},
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
{prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}},
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}},
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}},
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}},
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}},
{prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}},
// Comm Ops
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
{prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}},
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}},
};
return prim_eval_implement_map;
}
PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
static PrimitiveEvalImplMap prim_backend_eval_implement_map = {
{prim::kPrimMul, {InferImplMul, true}},
{prim::kPrimAdd, {InferImplAdd, true}},
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
{prim::kPrimSub, {InferImplSub, true}},
{prim::kPrimEqual, {InferImplEqual, true}},
{prim::kPrimReduceSum, {InferImplReduceFunc, true}},
{prim::kPrimReduceMean, {InferImplReduceFunc, true}},
{prim::kPrimReduceAll, {InferImplReduceFunc, true}},
{prim::kPrimReduceAny, {InferImplReduceFunc, true}},
{prim::kPrimReduceMax, {InferImplReduceFunc, true}},
{prim::kPrimReduceMin, {InferImplReduceFunc, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
{prim::kPrimCast, {InferImplCast, true}},
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
{prim::kPrimAllReduce, {InferImplAllReduce, true}},
{prim::kPrimBroadcast, {InferImplBroadcast, true}},
{prim::kPrimAllGather, {InferImplAllGather, true}},
{prim::kPrimMinimum, {InferImplMinimum, true}},
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
{prim::kPrimAddN, {InferImplAddN, true}},
{prim::kPrimMul, {InferImplMul, nullptr, true}},
{prim::kPrimAdd, {InferImplAdd, nullptr, true}},
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
{prim::kPrimSub, {InferImplSub, nullptr, true}},
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceAny, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceMax, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimReduceMin, {InferImplReduceFunc, nullptr, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}},
{prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}},
{prim::kPrimCast, {InferImplCast, nullptr, true}},
{prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}},
{prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}},
{prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}},
{prim::kPrimAllGather, {InferImplAllGather, nullptr, true}},
{prim::kPrimMinimum, {InferImplMinimum, nullptr, true}},
{prim::kPrimDivNoNan, {InferImplDivNoNan, nullptr, true}},
{prim::kPrimLinSpace, {InferImplLinSpace, nullptr, true}},
{prim::kPrimAddN, {InferImplAddN, nullptr, true}},
{prim::kPrimLess, {InferImplLess, true}},
{prim::kPrimStack, {InferImplStack, true}},
{prim::kPrimPad, {InferImplPad, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimDiv, {InferImplDiv, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
{prim::kPrimShape, {InferImplShape, false}},
{prim::kPrimTranspose, {InferImplTranspose, true}},
{prim::kPrimReshape, {InferImplReshape, true}},
{prim::kPrimConcat, {InferImplConcat, true}},
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
{prim::kPrimLess, {InferImplLess, nullptr, true}},
{prim::kPrimStack, {InferImplStack, nullptr, true}},
{prim::kPrimPad, {InferImplPad, nullptr, true}},
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}},
{prim::kPrimDiv, {InferImplDiv, nullptr, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, nullptr, true}},
{prim::kPrimShape, {InferImplShape, nullptr, false}},
{prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
{prim::kPrimReshape, {InferImplReshape, nullptr, true}},
{prim::kPrimConcat, {InferImplConcat, nullptr, true}},
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}},
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, nullptr, true}},
};
return prim_backend_eval_implement_map;
}

View File

@ -28,9 +28,13 @@ namespace mindspore {
namespace abstract {
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &);
struct StandardPrimitiveImplReg {
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive.
bool in_white_list_; // true if this Primitive in white list, else false.
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive
InferValueEvalImpl infer_value_func_; // infer value of primitive
// true means this primitive can be executed by vm backend else will be constant folded by frontend
bool in_white_list_;
};
using PrimitiveEvalImplMap =
@ -48,15 +52,17 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
class RegisterStandardPrimitiveEvalHelper {
public:
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
const StandardPrimitiveImplReg impl_reg{impl, true};
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl,
const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) {
const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}
~RegisterStandardPrimitiveEvalHelper() = default;
};
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
static auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl)
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \
static auto helper_##name = \
abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list)
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_

View File

@ -539,6 +539,40 @@ inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("Reduc
inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion");
inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType");
// Type introspection
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
// Other miscellaneous
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
// Structures
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)

View File

@ -49,6 +49,5 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
InferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer);
REGISTER_PRIMITIVE_C(kNameAdd, Add);
} // namespace mindspore

View File

@ -42,7 +42,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer);
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary);
} // namespace ops
} // namespace mindspore

View File

@ -42,7 +42,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
}
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer);
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary);
} // namespace ops
} // namespace mindspore

View File

@ -36,7 +36,7 @@ AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const Pri
EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
return args_spec_list[0];
}
REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest);
REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest, nullptr, true);
AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
EXPECT_EQ(args_spec_list.size(), 3);
@ -45,7 +45,7 @@ AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, c
auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>();
return args_spec_list[0];
}
REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest);
REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput, kPrimDynamicInputTest, InferImplDynamicInputTest, nullptr, true);
class TestAttrAndDynamicBackendInfer : public UT::Common {
public:
TestAttrAndDynamicBackendInfer() {}