!14745 refactor infer value

From: @lianliguang
Reviewed-by: @ginfung
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-04-13 11:27:08 +08:00 committed by Gitee
commit 3068d66d63
18 changed files with 113 additions and 84 deletions

View File

@ -923,14 +923,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis
auto ret = prim_eval_implement_map.find(prim);
if (ret != prim_eval_implement_map.end()) {
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_);
auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
return ret->second.impl_(nullptr, prim, infer_spec_list);
return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list);
} else {
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
auto ret_backend = prim_backend_eval_impl_map.find(prim);
if (ret_backend != prim_backend_eval_impl_map.end()) {
return ret_backend->second.impl_(nullptr, prim, args_spec_list);
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_);
return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list);
}
}
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()

View File

@ -639,26 +639,26 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
}
REGISTER_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr, false);
REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs,
nullptr, false);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(J, prim::kPrimJ, InferImplJ, nullptr);
REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
InferImplBroadcastGradientArgs, nullptr);
} // namespace abstract
} // namespace mindspore

View File

@ -59,7 +59,9 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
#define REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl) \
static auto helper_##name = \
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, false);
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_

View File

@ -530,28 +530,17 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
} // end anonymous namespace
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
const AbstractBasePtrList &args) {
auto prim_py = dyn_cast<PrimitivePy>(prim_);
if (prim_py == nullptr) {
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
}
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
// Call checking method 'infer_value' for python primitive
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
auto py_args = PreparePyInputs(prim_py, args);
prim_py->RunCheck(py_args);
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
}
// Call method 'infer_value' for primitive with this method for constant propagation
py::tuple py_vals(py_args.size());
auto added_attrs = prim_->evaluate_added_attrs();
for (size_t i = 0; i < py_args.size(); ++i) {
py_vals[i] = py_args[i][ATTR_VALUE];
}
@ -559,7 +548,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
if (py::isinstance<py::none>(py_ret)) {
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr converted_ret = nullptr;
TypePtr dtype = abs_base->BuildType();
@ -577,6 +565,28 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
}
EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
auto prim_py = dyn_cast<PrimitivePy>(prim_);
if (prim_py == nullptr) {
MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyInferCheck' should be a python primitive.";
}
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
auto py_args = PreparePyInputs(prim_py, args);
prim_py->RunCheck(py_args);
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
}
// Call method 'infer_value' for primitive with this method for constant propagation
return RunPyInferValue(engine, abs_base, args);
}
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
auto ret_abstract = AbstractEval(args);
@ -589,11 +599,19 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
if (prim_->prim_type() == PrimType::kPrimTypePyInferCheck) {
return EvalPyCheckPrim(engine, args);
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
if (need_infer_value) {
if (eval_impl_.infer_value_func_ != nullptr) {
auto value = eval_impl_.infer_value_func_(prim_, args, abs_base);
abs_base->set_value(value);
}
}
auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
return eval_result;
}
@ -617,7 +635,6 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
auto added_attrs = prim_py_->evaluate_added_attrs();
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
auto res_spec = PyInferRes2Abstract(prim_py_, output);
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
(*evaluator_cache_map_)[args] = infer_result;
@ -689,7 +706,7 @@ ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
// Primitive implementation
// static function start
namespace {
EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) {
EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) {
EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
return prim_evaluator;
}
@ -1279,7 +1296,7 @@ void InitPrimEvaluatorConstructors() {
PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
for (const auto &iter : GetPrimitiveToEvalImplMap()) {
constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second);
}
for (const auto &iter : GetUniformPrimitiveToImplMap()) {

View File

@ -32,7 +32,7 @@ namespace mindspore {
namespace abstract {
class StandardPrimEvaluator : public TrivialPrimEvaluator {
public:
StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl)
StandardPrimEvaluator(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &eval_impl)
: TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
~StandardPrimEvaluator() override = default;
MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
@ -43,9 +43,10 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
private:
EvalResultPtr EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args);
EvalResultPtr RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
const AbstractBasePtrList &args);
PrimitivePtr prim_;
const StandardPrimitiveEvalImpl eval_impl_;
const StandardPrimitiveImplReg eval_impl_;
};
using StandardPrimEvaluatorPtr = std::shared_ptr<StandardPrimEvaluator>;

View File

@ -358,8 +358,8 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
}
// find prim infer function in the prim function map return a standard evaluator
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl != nullptr) {
auto eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl.infer_shape_dtype_impl_ != nullptr) {
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
}

View File

@ -213,13 +213,13 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
return prim_backend_eval_implement_map;
}
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(primitive);
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
if (iter == GetPrimitiveToEvalImplMap().end()) {
return nullptr;
return {nullptr, nullptr, false};
}
return iter->second.impl_;
return iter->second;
}
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {

View File

@ -19,21 +19,24 @@
#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include "ir/primitive.h"
#include "ops/primitive_c.h"
#include "base/core_ops.h"
#include "abstract/abstract_value.h"
#include "ir/anf.h"
namespace mindspore {
namespace abstract {
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &);
struct StandardPrimitiveImplReg {
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive
InferValueEvalImpl infer_value_func_; // infer value of primitive
// true means this primitive can be executed by vm backend else will be constant folded by frontend
InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive
InferValueEvalImpl infer_value_func_; // infer value of primitive
// in_white_list_ is true means this primitive can be executed by vm backend
// else will be optimized by frontend
bool in_white_list_;
};
@ -44,7 +47,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap();
PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive);
std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode);
@ -52,17 +55,22 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
class RegisterStandardPrimitiveEvalHelper {
public:
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl,
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl,
const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) {
const StandardPrimitiveImplReg impl_reg{impl, infer_value_impl, is_wight_list};
const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}
~RegisterStandardPrimitiveEvalHelper() = default;
};
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl, infer_value_impl, is_wight_list) \
static auto helper_##name = \
abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl, infer_value_impl, is_wight_list)
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl, is_wight_list) \
static auto helper_##name = \
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, is_wight_list); \
std::shared_ptr<ops::PrimitiveC> GetDefaultPrimC##name() { \
auto out = std::make_shared<name>(); \
return out; \
} \
ops::OpPrimCRegisterHelper primc_gen_##name(#name, GetDefaultPrimC##name);
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_

View File

@ -70,6 +70,5 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
return abs;
}
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false);
REGISTER_PRIMITIVE_C(kNameGatherD, GatherD);
} // namespace ops
} // namespace mindspore

View File

@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
if (iter == infer_map.end()) {
MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!";
}
auto infer_function = iter->second.impl_;
auto infer_function = iter->second.infer_shape_dtype_impl_;
return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list);
}

View File

@ -51,6 +51,5 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
return std::make_shared<abstract::AbstractTensor>(kInt32, ScalarSummaryInferShape(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary);
} // namespace ops
} // namespace mindspore

View File

@ -81,6 +81,5 @@ AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const Primitiv
SoftMaxInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameSoftmax, Softmax);
} // namespace ops
} // namespace mindspore

View File

@ -51,6 +51,5 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
return std::make_shared<abstract::AbstractTensor>(kInt32, TensorSummaryInferShape(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary);
} // namespace ops
} // namespace mindspore

View File

@ -66,10 +66,8 @@ AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
MS_EXCEPTION_IF_NULL(primitive);
auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args),
ZerosInferShape(primitive, input_args));
abs->set_value(ZerosInferValue(primitive, input_args, abs));
return abs;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false);
REGISTER_PRIMITIVE_C(kNameZeros, Zeros);
} // namespace ops
} // namespace mindspore

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
#define MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
#include <vector>
#include <algorithm>
#include "ir/tensor.h"
namespace mindspore {
template <typename T>
@ -23,10 +24,7 @@ void SetTensorData(void *data, T num, size_t data_length) {
MS_EXCEPTION_IF_NULL(data);
auto tensor_data = reinterpret_cast<T *>(data);
MS_EXCEPTION_IF_NULL(tensor_data);
for (size_t index = 0; index < data_length; ++index) {
*tensor_data = num;
++tensor_data;
}
std::fill(tensor_data, tensor_data + data_length, num);
}
class TensorConstructUtils {
public:

View File

@ -86,10 +86,10 @@ TEST_F(TestAbstract, TestParseDataClass) {
AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y};
StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
ASSERT_TRUE(nullptr != eval_impl);
auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_);
AbstractBasePtr new_cls = eval_impl(nullptr, prim::kPrimMakeRecord, args_list);
AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list);
ASSERT_TRUE(nullptr != new_cls);
}

View File

@ -47,8 +47,8 @@ AbstractBasePtr InferImplScalarAddStub(const AnalysisEnginePtr &engine, const Pr
}
EvaluatorPtr InitPrimitiveScalarAddEvaluatorStub() {
EvaluatorPtr PrimitiveScalarAddEvaluator =
std::make_shared<StandardPrimEvaluator>(prim::kPrimScalarAdd, InferImplScalarAddStub);
EvaluatorPtr PrimitiveScalarAddEvaluator = std::make_shared<StandardPrimEvaluator>(
prim::kPrimScalarAdd, StandardPrimitiveImplReg{InferImplScalarAddStub, nullptr, true});
return PrimitiveScalarAddEvaluator;
}
@ -63,8 +63,8 @@ AbstractBasePtr InferImplReturnStub(const AnalysisEnginePtr &engine, const Primi
}
EvaluatorPtr InitPrimitiveReturnEvaluatorStub() {
EvaluatorPtr PrimitiveReturnEvaluator =
std::make_shared<StandardPrimEvaluator>(prim::kPrimReturn, InferImplReturnStub);
EvaluatorPtr PrimitiveReturnEvaluator = std::make_shared<StandardPrimEvaluator>(
prim::kPrimReturn, StandardPrimitiveImplReg{InferImplReturnStub, nullptr, true});
return PrimitiveReturnEvaluator;
}
@ -396,7 +396,6 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt64);
}
class TestEvalOnePrim : public UT::Common {
public:
TestEvalOnePrim() : getPyFun("gtest_input.pipeline.infer.infer_test", true), engine_(nullptr) {}
@ -435,7 +434,7 @@ class TestGraphEval : public UT::Common {
UT::PyFuncGraphFetcher getPyFun;
};
void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); }
void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); }
void TestGraphEval::TearDown() {
// destroy resource

View File

@ -25,6 +25,14 @@
#include "common/common_test.h"
namespace mindspore {
namespace opt {
class TestAttr : public ops::PrimitiveC {
public:
TestAttr() : PrimitiveC("") {}
};
class TestDynamicInput : public ops::PrimitiveC {
public:
TestDynamicInput() : PrimitiveC("") {}
};
constexpr auto kAttrConvertTestName = "attr_convert_test";
constexpr auto kDynamicInputTestName = "dynamic_input_test";
inline const PrimitivePtr kPrimAttrConvertTest = std::make_shared<Primitive>(kAttrConvertTestName);