forked from mindspore-Ecosystem/mindspore
change macro of reg infer
This commit is contained in:
parent
609b8edb9f
commit
8cabd090de
|
@ -30,45 +30,6 @@ namespace prim {
|
|||
ValuePtr GetPythonOps(const std::string &op_name,
|
||||
const std::string &module_name = "mindspore._extends.parse.standard_method",
|
||||
bool use_signature = false);
|
||||
|
||||
// Primitives only used by frontend;
|
||||
// Type introspection
|
||||
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
|
||||
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
|
||||
|
||||
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
|
||||
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
|
||||
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
|
||||
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
|
||||
|
||||
// Other miscellaneous
|
||||
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
|
||||
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
|
||||
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
|
||||
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
|
||||
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
|
||||
|
||||
// Structures
|
||||
|
||||
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
|
||||
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
|
||||
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
|
||||
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
|
||||
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
|
||||
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
|
||||
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
|
||||
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
|
||||
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
|
||||
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
|
||||
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
|
||||
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
|
||||
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
|
||||
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
|
||||
|
||||
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
|
||||
|
||||
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
|
||||
|
||||
class UnpackGraphPrimitive : public Primitive {
|
||||
public:
|
||||
explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args)
|
||||
|
|
|
@ -639,55 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
|
|||
|
||||
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: Ref, value, [universal]
|
||||
CheckRequiredArgsSize(primitive->name(), args_spec_list, 2);
|
||||
|
||||
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
|
||||
auto type = args_spec_list[0]->BuildType();
|
||||
if (type->type_id() == kObjectTypeRefKey) {
|
||||
return args_spec_list[1]->Broaden();
|
||||
} else {
|
||||
return args_spec_list[0];
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: Ref/Tensor, universal
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
||||
auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
|
||||
if (ref_abs != nullptr) {
|
||||
// Return tensor value if input is Ref.
|
||||
return ref_abs->CloneAsTensor();
|
||||
}
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
|
||||
InferImplBroadcastGradientArgs);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Load, prim::kPrimLoad, InferImplLoad);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs,
|
||||
nullptr, false);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -59,18 +59,6 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
class RegisterFrontendPrimitiveEvalHelper {
|
||||
public:
|
||||
RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
|
||||
const StandardPrimitiveImplReg impl_reg{impl, false};
|
||||
RegisterStandardPrimitiveImpl(primitive, impl_reg);
|
||||
}
|
||||
~RegisterFrontendPrimitiveEvalHelper() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
|
||||
static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl)
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -308,6 +308,10 @@ AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEngin
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list or dict.
|
||||
|
|
|
@ -577,5 +577,31 @@ AbstractBasePtr InferImplDType(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
abstract->set_value(value);
|
||||
return abstract;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: Ref/Tensor, universal
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
||||
auto ref_abs = dyn_cast<abstract::AbstractRef>(args_spec_list[0]);
|
||||
if (ref_abs != nullptr) {
|
||||
// Return tensor value if input is Ref.
|
||||
return ref_abs->CloneAsTensor();
|
||||
}
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: Ref, value, [universal]
|
||||
CheckRequiredArgsSize(primitive->name(), args_spec_list, 2);
|
||||
|
||||
MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0];
|
||||
auto type = args_spec_list[0]->BuildType();
|
||||
if (type->type_id() == kObjectTypeRefKey) {
|
||||
return args_spec_list[1]->Broaden();
|
||||
} else {
|
||||
return args_spec_list[0];
|
||||
}
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,158 +55,161 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
|
|||
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||
static PrimitiveEvalImplMap prim_eval_implement_map = {
|
||||
// Statements
|
||||
{prim::kPrimReturn, {InferImplReturn, true}},
|
||||
{prim::kPrimSwitch, {InferImplSwitch, true}},
|
||||
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
|
||||
{prim::kPrimIs_, {InferImplIs_, true}},
|
||||
{prim::kPrimIsNot, {InferImplIsNot, true}},
|
||||
{prim::kPrimInDict, {InferImplInDict, true}},
|
||||
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
|
||||
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
|
||||
{prim::kPrimReturn, {InferImplReturn, nullptr, true}},
|
||||
{prim::kPrimSwitch, {InferImplSwitch, nullptr, true}},
|
||||
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, nullptr, true}},
|
||||
{prim::kPrimIs_, {InferImplIs_, nullptr, true}},
|
||||
{prim::kPrimIsNot, {InferImplIsNot, nullptr, true}},
|
||||
{prim::kPrimInDict, {InferImplInDict, nullptr, true}},
|
||||
{prim::kPrimNotInDict, {InferImplNotInDict, nullptr, true}},
|
||||
{prim::kPrimIsConsant, {InferImplIsConstant, nullptr, true}},
|
||||
// Maths
|
||||
{prim::kPrimSquare, {InferImplSquare, true}},
|
||||
{prim::kPrimMatMul, {InferImplMatMul, true}},
|
||||
{prim::kPrimBatchMatMul, {InferImplBatchMatMul, true}},
|
||||
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
{prim::kPrimSqrt, {InferImplSqrt, true}},
|
||||
{prim::kPrimSquare, {InferImplSquare, nullptr, true}},
|
||||
{prim::kPrimMatMul, {InferImplMatMul, nullptr, true}},
|
||||
{prim::kPrimBatchMatMul, {InferImplBatchMatMul, nullptr, true}},
|
||||
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, nullptr, true}},
|
||||
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, nullptr, true}},
|
||||
{prim::kPrimSqrt, {InferImplSqrt, nullptr, true}},
|
||||
// Array
|
||||
{prim::kPrimRange, {InferImplRange, true}},
|
||||
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
|
||||
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
||||
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
|
||||
{prim::kPrimUnique, {InferImplUnique, true}},
|
||||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||
{prim::kPrimGather, {InferImplGatherV2, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
|
||||
{prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
|
||||
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
|
||||
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, true}},
|
||||
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
|
||||
{prim::kPrimSubAndFilter, {InferImplSubAndFilter, true}},
|
||||
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
|
||||
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
|
||||
{prim::kPrimDynamicAssign, {InferImplDynamicAssign, true}},
|
||||
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}},
|
||||
{prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
|
||||
{prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, true}},
|
||||
{prim::kPrimPadAndShift, {InferImplPadAndShift, true}},
|
||||
{prim::kPrimDynamicShape, {InferImplDynamicShape, true}},
|
||||
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
|
||||
{prim::kPrimSplit, {InferImplSplit, true}},
|
||||
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
|
||||
{prim::kPrimRange, {InferImplRange, nullptr, true}},
|
||||
{prim::kPrimScalarToArray, {InferImplScalarToArray, nullptr, true}},
|
||||
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, nullptr, true}},
|
||||
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, nullptr, true}},
|
||||
{prim::kPrimUnique, {InferImplUnique, nullptr, true}},
|
||||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, nullptr, true}},
|
||||
{prim::kPrimGather, {InferImplGatherV2, nullptr, true}},
|
||||
{prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, nullptr, true}},
|
||||
{prim::kPrimSparseGatherV2, {InferImplGatherV2, nullptr, true}},
|
||||
{prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, nullptr, true}},
|
||||
{prim::kPrimUnsortedSegmentMin, {InferImplUnsortedSegmentMin, nullptr, true}},
|
||||
{prim::kPrimScatterAdd, {InferImplScatterAdd, nullptr, true}},
|
||||
{prim::kPrimSubAndFilter, {InferImplSubAndFilter, nullptr, true}},
|
||||
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, nullptr, true}},
|
||||
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, nullptr, true}},
|
||||
{prim::kPrimDynamicAssign, {InferImplDynamicAssign, nullptr, true}},
|
||||
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, nullptr, true}},
|
||||
{prim::kPrimUpdateCache, {InferImplUpdateCache, nullptr, true}},
|
||||
{prim::kPrimComputeAccidentalHits, {InferImplComputeAccidentalHits, nullptr, true}},
|
||||
{prim::kPrimPadAndShift, {InferImplPadAndShift, nullptr, true}},
|
||||
{prim::kPrimDynamicShape, {InferImplDynamicShape, nullptr, true}},
|
||||
{prim::kPrimMapUniform, {InferImplMapUniform, nullptr, true}},
|
||||
{prim::kPrimSplit, {InferImplSplit, nullptr, true}},
|
||||
{prim::kPrimSequenceMask, {InferImplSequenceMask, nullptr, true}},
|
||||
// Structure
|
||||
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
||||
{prim::kPrimMakeList, {InferImplMakeList, true}},
|
||||
{prim::kPrimMakeDict, {InferImplMakeDict, true}},
|
||||
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
|
||||
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
|
||||
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
|
||||
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
|
||||
{prim::kPrimListGetItem, {InferImplListGetItem, true}},
|
||||
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
|
||||
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
|
||||
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
|
||||
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
|
||||
{prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}},
|
||||
{prim::kPrimDictGetValues, {InferImplDictGetValues, true}},
|
||||
{prim::kPrimListAppend, {InferImplListAppend, true}},
|
||||
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
|
||||
{prim::kPrimListLen, {InferImplListLen, true}},
|
||||
{prim::kPrimArrayLen, {InferImplArrayLen, true}},
|
||||
{prim::kPrimMakeTuple, {InferImplMakeTuple, nullptr, true}},
|
||||
{prim::kPrimMakeList, {InferImplMakeList, nullptr, true}},
|
||||
{prim::kPrimMakeDict, {InferImplMakeDict, nullptr, true}},
|
||||
{prim::kPrimMakeSlice, {InferImplMakeSlice, nullptr, true}},
|
||||
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, nullptr, true}},
|
||||
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, nullptr, true}},
|
||||
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, nullptr, true}},
|
||||
{prim::kPrimListGetItem, {InferImplListGetItem, nullptr, true}},
|
||||
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, nullptr, true}},
|
||||
{prim::kPrimListSetItem, {InferImplListSetItem, nullptr, true}},
|
||||
{prim::kPrimDictGetItem, {InferImplDictGetItem, nullptr, true}},
|
||||
{prim::kPrimDictSetItem, {InferImplDictSetItem, nullptr, true}},
|
||||
{prim::kPrimDictGetKeys, {InferImplDictGetKeys, nullptr, true}},
|
||||
{prim::kPrimDictGetValues, {InferImplDictGetValues, nullptr, true}},
|
||||
{prim::kPrimListAppend, {InferImplListAppend, nullptr, true}},
|
||||
{prim::kPrimTupleLen, {InferImplTupleLen, nullptr, true}},
|
||||
{prim::kPrimListLen, {InferImplListLen, nullptr, true}},
|
||||
{prim::kPrimArrayLen, {InferImplArrayLen, nullptr, true}},
|
||||
// NN
|
||||
{prim::kPrimPooling, {InferImplPooling, true}},
|
||||
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
|
||||
{prim::kPrimBatchNorm, {InferImplBatchNorm, true}},
|
||||
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
|
||||
{prim::kPrimConv2D, {InferImplConv2D, true}},
|
||||
{prim::kPrimBiasAdd, {InferImplBiasAdd, true}},
|
||||
{prim::kPrimRelu, {InferImplRelu, true}},
|
||||
{prim::kPrimRelu6, {InferImplRelu, true}},
|
||||
{prim::kPrimZerosLike, {InferImplZerosLike, true}},
|
||||
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
|
||||
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
|
||||
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
|
||||
{prim::kPrimDropout, {InferImplDropout, true}},
|
||||
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
|
||||
{prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
|
||||
{prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
|
||||
{prim::kPrimSGD, {InferImplSGD, true}},
|
||||
{prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, true}},
|
||||
{prim::kPrimPooling, {InferImplPooling, nullptr, true}},
|
||||
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, nullptr, true}},
|
||||
{prim::kPrimBatchNorm, {InferImplBatchNorm, nullptr, true}},
|
||||
{prim::kPrimReluGrad, {InferImplReluGrad, nullptr, true}},
|
||||
{prim::kPrimConv2D, {InferImplConv2D, nullptr, true}},
|
||||
{prim::kPrimBiasAdd, {InferImplBiasAdd, nullptr, true}},
|
||||
{prim::kPrimRelu, {InferImplRelu, nullptr, true}},
|
||||
{prim::kPrimRelu6, {InferImplRelu, nullptr, true}},
|
||||
{prim::kPrimZerosLike, {InferImplZerosLike, nullptr, true}},
|
||||
{prim::kPrimBpropCut, {InferImplBpropCut, nullptr, true}},
|
||||
{prim::kPrimLayerNorm, {InferImplLayerNorm, nullptr, true}},
|
||||
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, nullptr, true}},
|
||||
{prim::kPrimDropout, {InferImplDropout, nullptr, true}},
|
||||
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, nullptr, true}},
|
||||
{prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, nullptr, true}},
|
||||
{prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}},
|
||||
{prim::kPrimSGD, {InferImplSGD, nullptr, true}},
|
||||
{prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}},
|
||||
// Others
|
||||
{prim::kPrimIdentity, {InferImplIdentity, true}},
|
||||
{prim::kPrimIdentity, {InferImplIdentity, nullptr, true}},
|
||||
{prim::kPrimLoad, {InferImplLoad, nullptr, true}},
|
||||
{prim::kPrimAssign, {InferImplAssign, nullptr, true}},
|
||||
// Set impl to null as it will use PartialEvaluator;
|
||||
{prim::kPrimPartial, {nullptr, true}},
|
||||
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
|
||||
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
|
||||
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
|
||||
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
|
||||
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
|
||||
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
|
||||
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
|
||||
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
|
||||
{prim::kPrimDepend, {InferImplDepend, true}},
|
||||
{prim::kPrimUpdateState, {InferImplUpdateState, true}},
|
||||
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
|
||||
{prim::kPrimPartial, {nullptr, nullptr, true}},
|
||||
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, nullptr, true}},
|
||||
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, nullptr, true}},
|
||||
{prim::kPrimEnvAdd, {InferImplEnvAdd, nullptr, true}},
|
||||
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, nullptr, true}},
|
||||
{prim::kPrimMakeRef, {InferImplMakeRef, nullptr, true}},
|
||||
{prim::kPrimGetRefKey, {InferImplGetRefKey, nullptr, true}},
|
||||
{prim::kPrimGetRefValue, {InferImplGetRefValue, nullptr, true}},
|
||||
{prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}},
|
||||
{prim::kPrimDepend, {InferImplDepend, nullptr, true}},
|
||||
{prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}},
|
||||
{prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}},
|
||||
// Debug
|
||||
{prim::kPrimDebug, {InferImplDebug, true}},
|
||||
{prim::kPrimDebug, {InferImplDebug, nullptr, true}},
|
||||
// Dynamic shape testing
|
||||
{prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, true}},
|
||||
{prim::kPrimGpuConvertToDynamicShape, {InferImplGpuConvertToDynamicShape, nullptr, true}},
|
||||
// SparseTensor
|
||||
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
|
||||
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
|
||||
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
|
||||
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
|
||||
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, nullptr, true}},
|
||||
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, nullptr, true}},
|
||||
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, nullptr, true}},
|
||||
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, nullptr, true}},
|
||||
// RowTensor
|
||||
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}},
|
||||
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}},
|
||||
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}},
|
||||
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
|
||||
{prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}},
|
||||
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, nullptr, true}},
|
||||
|
||||
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, nullptr, true}},
|
||||
{prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, nullptr, false}},
|
||||
// Comm Ops
|
||||
{prim::kPrimAllSwap, {InferImplAllSwap, true}},
|
||||
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, true}},
|
||||
{prim::kPrimAllSwap, {InferImplAllSwap, nullptr, true}},
|
||||
{prim::kPrimMemCpyAsync, {InferImplMemCpyAsync, nullptr, true}},
|
||||
};
|
||||
return prim_eval_implement_map;
|
||||
}
|
||||
|
||||
PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
||||
static PrimitiveEvalImplMap prim_backend_eval_implement_map = {
|
||||
{prim::kPrimMul, {InferImplMul, true}},
|
||||
{prim::kPrimAdd, {InferImplAdd, true}},
|
||||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, true}},
|
||||
{prim::kPrimSub, {InferImplSub, true}},
|
||||
{prim::kPrimEqual, {InferImplEqual, true}},
|
||||
{prim::kPrimReduceSum, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMean, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceAll, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceAny, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMax, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimReduceMin, {InferImplReduceFunc, true}},
|
||||
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
|
||||
{prim::kPrimReduceScatter, {InferImplReduceScatter, true}},
|
||||
{prim::kPrimCast, {InferImplCast, true}},
|
||||
{prim::kPrimExpandDims, {InferImplExpandDims, true}},
|
||||
{prim::kPrimAllReduce, {InferImplAllReduce, true}},
|
||||
{prim::kPrimBroadcast, {InferImplBroadcast, true}},
|
||||
{prim::kPrimAllGather, {InferImplAllGather, true}},
|
||||
{prim::kPrimMinimum, {InferImplMinimum, true}},
|
||||
{prim::kPrimDivNoNan, {InferImplDivNoNan, true}},
|
||||
{prim::kPrimLinSpace, {InferImplLinSpace, true}},
|
||||
{prim::kPrimAddN, {InferImplAddN, true}},
|
||||
{prim::kPrimMul, {InferImplMul, nullptr, true}},
|
||||
{prim::kPrimAdd, {InferImplAdd, nullptr, true}},
|
||||
{prim::kPrimSqrtGrad, {InferImplSqrtGrad, nullptr, true}},
|
||||
{prim::kPrimSub, {InferImplSub, nullptr, true}},
|
||||
{prim::kPrimEqual, {InferImplEqual, nullptr, true}},
|
||||
{prim::kPrimReduceSum, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceMean, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceAll, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceAny, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceMax, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimReduceMin, {InferImplReduceFunc, nullptr, true}},
|
||||
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, nullptr, true}},
|
||||
{prim::kPrimReduceScatter, {InferImplReduceScatter, nullptr, true}},
|
||||
{prim::kPrimCast, {InferImplCast, nullptr, true}},
|
||||
{prim::kPrimExpandDims, {InferImplExpandDims, nullptr, true}},
|
||||
{prim::kPrimAllReduce, {InferImplAllReduce, nullptr, true}},
|
||||
{prim::kPrimBroadcast, {InferImplBroadcast, nullptr, true}},
|
||||
{prim::kPrimAllGather, {InferImplAllGather, nullptr, true}},
|
||||
{prim::kPrimMinimum, {InferImplMinimum, nullptr, true}},
|
||||
{prim::kPrimDivNoNan, {InferImplDivNoNan, nullptr, true}},
|
||||
{prim::kPrimLinSpace, {InferImplLinSpace, nullptr, true}},
|
||||
{prim::kPrimAddN, {InferImplAddN, nullptr, true}},
|
||||
|
||||
{prim::kPrimLess, {InferImplLess, true}},
|
||||
{prim::kPrimStack, {InferImplStack, true}},
|
||||
{prim::kPrimPad, {InferImplPad, true}},
|
||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
|
||||
{prim::kPrimDiv, {InferImplDiv, true}},
|
||||
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
|
||||
{prim::kPrimShape, {InferImplShape, false}},
|
||||
{prim::kPrimTranspose, {InferImplTranspose, true}},
|
||||
{prim::kPrimReshape, {InferImplReshape, true}},
|
||||
{prim::kPrimConcat, {InferImplConcat, true}},
|
||||
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, true}},
|
||||
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
|
||||
{prim::kPrimLess, {InferImplLess, nullptr, true}},
|
||||
{prim::kPrimStack, {InferImplStack, nullptr, true}},
|
||||
{prim::kPrimPad, {InferImplPad, nullptr, true}},
|
||||
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, nullptr, true}},
|
||||
{prim::kPrimDiv, {InferImplDiv, nullptr, true}},
|
||||
{prim::kPrimRealDiv, {InferImplRealDiv, nullptr, true}},
|
||||
{prim::kPrimShape, {InferImplShape, nullptr, false}},
|
||||
{prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
|
||||
{prim::kPrimReshape, {InferImplReshape, nullptr, true}},
|
||||
{prim::kPrimConcat, {InferImplConcat, nullptr, true}},
|
||||
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}},
|
||||
{prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, nullptr, true}},
|
||||
};
|
||||
return prim_backend_eval_implement_map;
|
||||
}
|
||||
|
|
|
@ -28,9 +28,13 @@ namespace mindspore {
|
|||
namespace abstract {
|
||||
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &);
|
||||
using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &);
|
||||
|
||||
struct StandardPrimitiveImplReg {
|
||||
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive.
|
||||
bool in_white_list_; // true if this Primitive in white list, else false.
|
||||
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive
|
||||
InferValueEvalImpl infer_value_func_; // infer value of primitive
|
||||
// true means this primitive can be executed by vm backend else will be constant folded by frontend
|
||||
bool in_white_list_;
|
||||
};
|
||||
|
||||
using PrimitiveEvalImplMap =
|
||||
|
@ -48,15 +52,17 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
|
|||
|
||||
class RegisterStandardPrimitiveEvalHelper {
|
||||
public:
|
||||
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
|
||||
const StandardPrimitiveImplReg impl_reg{impl, true};
|
||||
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl,
|
||||
const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) {
|
||||
const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list};
|
||||
RegisterStandardPrimitiveImpl(primitive, impl_reg);
|
||||
}
|
||||
~RegisterStandardPrimitiveEvalHelper() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
|
||||
static auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl)
|
||||
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \
|
||||
static auto helper_##name = \
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list)
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
|
||||
|
|
|
@ -539,6 +539,40 @@ inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("Reduc
|
|||
inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion");
|
||||
inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType");
|
||||
|
||||
// Type introspection
|
||||
inline const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
|
||||
inline const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
|
||||
|
||||
inline const PrimitivePtr kPrimResolve = std::make_shared<Primitive>("resolve");
|
||||
inline const PrimitivePtr kPrimEmbed = std::make_shared<Primitive>("embed");
|
||||
inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEmbed");
|
||||
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
|
||||
|
||||
// Other miscellaneous
|
||||
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
|
||||
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
|
||||
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
|
||||
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
|
||||
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
|
||||
|
||||
// Structures
|
||||
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
|
||||
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
|
||||
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
|
||||
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
|
||||
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
|
||||
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
|
||||
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
|
||||
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
|
||||
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
|
||||
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
|
||||
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
|
||||
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
|
||||
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
|
||||
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
|
||||
inline const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
|
||||
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
|
||||
|
||||
class DoSignaturePrimitive : public Primitive {
|
||||
public:
|
||||
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
|
||||
|
|
|
@ -49,6 +49,5 @@ AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr
|
|||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Add, prim::kPrimAdd, AddInfer);
|
||||
REGISTER_PRIMITIVE_C(kNameAdd, Add);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,7 +42,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
|
||||
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
|
||||
REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,7 +42,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
|
||||
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
|
||||
REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,7 +36,7 @@ AbstractBasePtr InferImplAttrTest(const abstract::AnalysisEnginePtr &, const Pri
|
|||
EXPECT_EQ(args_spec_list[1]->isa<abstract::AbstractTuple>(), true);
|
||||
return args_spec_list[0];
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr,kPrimAttrConvertTest,InferImplAttrTest);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TestAttr, kPrimAttrConvertTest, InferImplAttrTest, nullptr, true);
|
||||
AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
EXPECT_EQ(args_spec_list.size(), 3);
|
||||
|
@ -45,7 +45,7 @@ AbstractBasePtr InferImplDynamicInputTest(const abstract::AnalysisEnginePtr &, c
|
|||
auto item = args_spec_list[1]->cast<abstract::AbstractTuplePtr>();
|
||||
return args_spec_list[0];
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput,kPrimDynamicInputTest,InferImplDynamicInputTest);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TestDynamicInput, kPrimDynamicInputTest, InferImplDynamicInputTest, nullptr, true);
|
||||
class TestAttrAndDynamicBackendInfer : public UT::Common {
|
||||
public:
|
||||
TestAttrAndDynamicBackendInfer() {}
|
||||
|
|
Loading…
Reference in New Issue