forked from mindspore-Ecosystem/mindspore
!14745 refactor infer value
From: @lianliguang Reviewed-by: @ginfung Signed-off-by:
This commit is contained in:
commit
3068d66d63
|
@ -923,14 +923,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis
|
|||
auto ret = prim_eval_implement_map.find(prim);
|
||||
if (ret != prim_eval_implement_map.end()) {
|
||||
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
|
||||
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_);
|
||||
auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
|
||||
return ret->second.impl_(nullptr, prim, infer_spec_list);
|
||||
return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list);
|
||||
} else {
|
||||
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
|
||||
auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
|
||||
auto ret_backend = prim_backend_eval_impl_map.find(prim);
|
||||
if (ret_backend != prim_backend_eval_impl_map.end()) {
|
||||
return ret_backend->second.impl_(nullptr, prim, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_);
|
||||
return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list);
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
|
||||
|
|
|
@ -639,26 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
|
|||
|
||||
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs,
|
||||
nullptr, false);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
|
||||
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
|
||||
InferImplBroadcastGradientArgs, nullptr);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -59,7 +59,9 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
#define REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl) \
|
||||
static auto helper_##name = \
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, false);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
|
||||
|
|
|
@ -530,28 +530,17 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
|
|||
}
|
||||
} // end anonymous namespace
|
||||
|
||||
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
|
||||
const AbstractBasePtrList &args) {
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
||||
if (prim_py == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
|
||||
}
|
||||
|
||||
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
|
||||
// Call checking method 'infer_value' for python primitive
|
||||
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
||||
auto py_args = PreparePyInputs(prim_py, args);
|
||||
prim_py->RunCheck(py_args);
|
||||
|
||||
prim_->BeginRecordAddAttr();
|
||||
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
|
||||
prim_->EndRecordAddAttr();
|
||||
auto added_attrs = prim_->evaluate_added_attrs();
|
||||
|
||||
if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
|
||||
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||
}
|
||||
|
||||
// Call method 'infer_value' for primitive with this method for constant propagation
|
||||
py::tuple py_vals(py_args.size());
|
||||
auto added_attrs = prim_->evaluate_added_attrs();
|
||||
for (size_t i = 0; i < py_args.size(); ++i) {
|
||||
py_vals[i] = py_args[i][ATTR_VALUE];
|
||||
}
|
||||
|
@ -559,7 +548,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
|
|||
if (py::isinstance<py::none>(py_ret)) {
|
||||
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||
}
|
||||
|
||||
// Convert pyobject to Value, then to AbstractValue
|
||||
ValuePtr converted_ret = nullptr;
|
||||
TypePtr dtype = abs_base->BuildType();
|
||||
|
@ -577,6 +565,28 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
|
|||
return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
|
||||
}
|
||||
|
||||
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim_);
|
||||
if (prim_py == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
|
||||
}
|
||||
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
|
||||
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
|
||||
auto py_args = PreparePyInputs(prim_py, args);
|
||||
prim_py->RunCheck(py_args);
|
||||
|
||||
prim_->BeginRecordAddAttr();
|
||||
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
|
||||
prim_->EndRecordAddAttr();
|
||||
auto added_attrs = prim_->evaluate_added_attrs();
|
||||
|
||||
if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
|
||||
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||
}
|
||||
// Call method 'infer_value' for primitive with this method for constant propagation
|
||||
return RunPyInferValue(engine, abs_base, args);
|
||||
}
|
||||
|
||||
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
|
||||
auto ret_abstract = AbstractEval(args);
|
||||
|
@ -589,11 +599,19 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
|
||||
return EvalPyCheckPrim(engine, args);
|
||||
}
|
||||
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
|
||||
prim_->BeginRecordAddAttr();
|
||||
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
|
||||
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
|
||||
prim_->EndRecordAddAttr();
|
||||
auto added_attrs = prim_->evaluate_added_attrs();
|
||||
if (need_infer_value) {
|
||||
if (eval_impl_.infer_value_func_ != nullptr) {
|
||||
auto value = eval_impl_.infer_value_func_(prim_, args, abs_base);
|
||||
abs_base->set_value(value);
|
||||
}
|
||||
}
|
||||
auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||
return eval_result;
|
||||
}
|
||||
|
@ -617,7 +635,6 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
|
|||
auto added_attrs = prim_py_->evaluate_added_attrs();
|
||||
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
|
||||
auto res_spec = PyInferRes2Abstract(prim_py_, output);
|
||||
|
||||
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
|
||||
auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
|
||||
(*evaluator_cache_map_)[args] = infer_result;
|
||||
|
@ -689,7 +706,7 @@ ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
|
|||
// Primitive implementation
|
||||
// static function start
|
||||
namespace {
|
||||
EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) {
|
||||
EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) {
|
||||
EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
|
||||
return prim_evaluator;
|
||||
}
|
||||
|
@ -1279,7 +1296,7 @@ void InitPrimEvaluatorConstructors() {
|
|||
PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
|
||||
|
||||
for (const auto &iter : GetPrimitiveToEvalImplMap()) {
|
||||
constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
|
||||
constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second);
|
||||
}
|
||||
|
||||
for (const auto &iter : GetUniformPrimitiveToImplMap()) {
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace mindspore {
|
|||
namespace abstract {
|
||||
class StandardPrimEvaluator : public TrivialPrimEvaluator {
|
||||
public:
|
||||
StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl)
|
||||
StandardPrimEvaluator(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &eval_impl)
|
||||
: TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
|
||||
~StandardPrimEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
|
||||
|
@ -43,9 +43,10 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
|
|||
|
||||
private:
|
||||
EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args);
|
||||
|
||||
EvalResultPtr RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
|
||||
const AbstractBasePtrList &args);
|
||||
PrimitivePtr prim_;
|
||||
const StandardPrimitiveEvalImpl eval_impl_;
|
||||
const StandardPrimitiveImplReg eval_impl_;
|
||||
};
|
||||
|
||||
using StandardPrimEvaluatorPtr = std::shared_ptr<StandardPrimEvaluator>;
|
||||
|
|
|
@ -358,8 +358,8 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
}
|
||||
|
||||
// find prim infer function in the prim function map return a standard evaluator
|
||||
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
|
||||
if (eval_impl != nullptr) {
|
||||
auto eval_impl = GetPrimitiveInferImpl(prim);
|
||||
if (eval_impl.infer_shape_dtype_impl_ != nullptr) {
|
||||
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
|
||||
}
|
||||
|
||||
|
|
|
@ -213,13 +213,13 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
|||
return prim_backend_eval_implement_map;
|
||||
}
|
||||
|
||||
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
|
||||
StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
|
||||
if (iter == GetPrimitiveToEvalImplMap().end()) {
|
||||
return nullptr;
|
||||
return {nullptr, nullptr, false};
|
||||
}
|
||||
return iter->second.impl_;
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {
|
||||
|
|
|
@ -19,21 +19,24 @@
|
|||
#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ir/primitive.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &);
|
||||
using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &);
|
||||
using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &);
|
||||
|
||||
struct StandardPrimitiveImplReg {
|
||||
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive
|
||||
InferValueEvalImpl infer_value_func_; // infer value of primitive
|
||||
// true means this primitive can be executed by vm backend else will be constant folded by frontend
|
||||
InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive
|
||||
InferValueEvalImpl infer_value_func_; // infer value of primitive
|
||||
// in_white_list_ is true means this primitive can be executed by vm backend
|
||||
// else will be optimized by frontend
|
||||
bool in_white_list_;
|
||||
};
|
||||
|
||||
|
@ -44,7 +47,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap();
|
|||
|
||||
PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
|
||||
|
||||
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
|
||||
StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive);
|
||||
|
||||
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
|
||||
|
||||
|
@ -52,17 +55,22 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
|
|||
|
||||
class RegisterStandardPrimitiveEvalHelper {
|
||||
public:
|
||||
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl,
|
||||
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl,
|
||||
const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) {
|
||||
const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list};
|
||||
const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list};
|
||||
RegisterStandardPrimitiveImpl(primitive, impl_reg);
|
||||
}
|
||||
~RegisterStandardPrimitiveEvalHelper() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \
|
||||
static auto helper_##name = \
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list)
|
||||
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_wight_list) \
|
||||
static auto helper_##name = \
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_wight_list); \
|
||||
std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \
|
||||
auto out = std::make_shared<name>(); \
|
||||
return out; \
|
||||
} \
|
||||
ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
|
||||
|
|
|
@ -70,6 +70,5 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
return abs;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false);
|
||||
REGISTER_PRIMITIVE_C(kNameGatherD, GatherD);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
|
|||
if (iter == infer_map.end()) {
|
||||
MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!";
|
||||
}
|
||||
auto infer_function = iter->second.impl_;
|
||||
auto infer_function = iter->second.infer_shape_dtype_impl_;
|
||||
return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list);
|
||||
}
|
||||
|
||||
|
|
|
@ -51,6 +51,5 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
return std::make_shared<abstract::AbstractTensor>(kInt32, ScalarSummaryInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
|
||||
REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -81,6 +81,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv
|
|||
SoftMaxInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true);
|
||||
REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -51,6 +51,5 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
|
|||
return std::make_shared<abstract::AbstractTensor>(kInt32, TensorSummaryInferShape(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
|
||||
REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -66,10 +66,8 @@ AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args),
|
||||
ZerosInferShape(primitive, input_args));
|
||||
abs->set_value(ZerosInferValue(primitive, input_args, abs));
|
||||
return abs;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false);
|
||||
REGISTER_PRIMITIVE_C(kNameZeros, Zeros);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#ifndef MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
|
||||
#define MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "ir/tensor.h"
|
||||
namespace mindspore {
|
||||
template <typename T>
|
||||
|
@ -23,10 +24,7 @@ void SetTensorData(void *data, T num, size_t data_length) {
|
|||
MS_EXCEPTION_IF_NULL(data);
|
||||
auto tensor_data = reinterpret_cast<T *>(data);
|
||||
MS_EXCEPTION_IF_NULL(tensor_data);
|
||||
for (size_t index = 0; index < data_length; ++index) {
|
||||
*tensor_data = num;
|
||||
++tensor_data;
|
||||
}
|
||||
std::fill(tensor_data, tensor_data + data_length, num);
|
||||
}
|
||||
class TensorConstructUtils {
|
||||
public:
|
||||
|
|
|
@ -86,10 +86,10 @@ TEST_F(TestAbstract, TestParseDataClass) {
|
|||
|
||||
AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y};
|
||||
|
||||
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
|
||||
ASSERT_TRUE(nullptr != eval_impl);
|
||||
auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
|
||||
ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_);
|
||||
|
||||
AbstractBasePtr new_cls = eval_impl(nullptr, prim::kPrimMakeRecord, args_list);
|
||||
AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list);
|
||||
ASSERT_TRUE(nullptr != new_cls);
|
||||
}
|
||||
|
||||
|
|
|
@ -47,8 +47,8 @@ AbstractBasePtr InferImplScalarAddStub(const AnalysisEnginePtr &engine, const Pr
|
|||
}
|
||||
|
||||
EvaluatorPtr InitPrimitiveScalarAddEvaluatorStub() {
|
||||
EvaluatorPtr PrimitiveScalarAddEvaluator =
|
||||
std::make_shared<StandardPrimEvaluator>(prim::kPrimScalarAdd, InferImplScalarAddStub);
|
||||
EvaluatorPtr PrimitiveScalarAddEvaluator = std::make_shared<StandardPrimEvaluator>(
|
||||
prim::kPrimScalarAdd, StandardPrimitiveImplReg{InferImplScalarAddStub, nullptr, true});
|
||||
return PrimitiveScalarAddEvaluator;
|
||||
}
|
||||
|
||||
|
@ -63,8 +63,8 @@ AbstractBasePtr InferImplReturnStub(const AnalysisEnginePtr &engine, const Primi
|
|||
}
|
||||
|
||||
EvaluatorPtr InitPrimitiveReturnEvaluatorStub() {
|
||||
EvaluatorPtr PrimitiveReturnEvaluator =
|
||||
std::make_shared<StandardPrimEvaluator>(prim::kPrimReturn, InferImplReturnStub);
|
||||
EvaluatorPtr PrimitiveReturnEvaluator = std::make_shared<StandardPrimEvaluator>(
|
||||
prim::kPrimReturn, StandardPrimitiveImplReg{InferImplReturnStub, nullptr, true});
|
||||
return PrimitiveReturnEvaluator;
|
||||
}
|
||||
|
||||
|
@ -396,7 +396,6 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
|
|||
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt64);
|
||||
}
|
||||
|
||||
|
||||
class TestEvalOnePrim : public UT::Common {
|
||||
public:
|
||||
TestEvalOnePrim() : getPyFun("gtest_input.pipeline.infer.infer_test", true), engine_(nullptr) {}
|
||||
|
@ -435,7 +434,7 @@ class TestGraphEval : public UT::Common {
|
|||
UT::PyFuncGraphFetcher getPyFun;
|
||||
};
|
||||
|
||||
void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); }
|
||||
void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); }
|
||||
|
||||
void TestGraphEval::TearDown() {
|
||||
// destroy resource
|
||||
|
|
|
@ -25,6 +25,14 @@
|
|||
#include "common/common_test.h"
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TestAttr : public ops::PrimitiveC {
|
||||
public:
|
||||
TestAttr() : PrimitiveC("") {}
|
||||
};
|
||||
class TestDynamicInput : public ops::PrimitiveC {
|
||||
public:
|
||||
TestDynamicInput() : PrimitiveC("") {}
|
||||
};
|
||||
constexpr auto kAttrConvertTestName = "attr_convert_test";
|
||||
constexpr auto kDynamicInputTestName = "dynamic_input_test";
|
||||
inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAttrConvertTestName);
|
||||
|
|
Loading…
Reference in New Issue