!15431 change infer value

From: @lianliguang
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-04-21 18:00:13 +08:00 committed by Gitee
commit 2e43434221
11 changed files with 50 additions and 47 deletions

View File

@ -924,16 +924,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis
auto ret = prim_eval_implement_map.find(prim);
if (ret != prim_eval_implement_map.end()) {
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_);
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list);
return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
} else {
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
auto ret_backend = prim_backend_eval_impl_map.find(prim);
if (ret_backend != prim_backend_eval_impl_map.end()) {
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_);
return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list);
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list);
}
}
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()

View File

@ -576,7 +576,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
prim_py->RunCheck(py_args);
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
@ -602,16 +602,21 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
AbstractBasePtr abs_base = nullptr;
ValuePtr value = nullptr;
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
if (need_infer_value) {
if (eval_impl_.infer_value_func_ != nullptr) {
auto value = eval_impl_.infer_value_func_(prim_, args, abs_base);
abs_base->set_value(value);
if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) {
value = eval_impl_.infer_value_impl_(prim_, args);
if (value != nullptr) {
abs_base = value->ToAbstract();
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
}
}
abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
return eval_result;
}

View File

@ -359,7 +359,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
// find prim infer function in the prim function map return a standard evaluator
auto eval_impl = GetPrimitiveInferImpl(prim);
if (eval_impl.infer_shape_dtype_impl_ != nullptr) {
if (eval_impl.infer_shape_impl_ != nullptr) {
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
}

View File

@ -28,13 +28,13 @@
namespace mindspore {
namespace abstract {
using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &);
using InferShapeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &);
using InferValueImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &);
struct StandardPrimitiveImplReg {
InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive
InferValueEvalImpl infer_value_func_; // infer value of primitive
InferShapeImpl infer_shape_impl_; // infer shape and type for ops
InferValueImpl infer_value_impl_; // infer value for ops
// in_white_list_ is true means this primitive can be executed by vm backend
// else will be optimized by frontend
bool in_white_list_;
@ -55,8 +55,8 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
class RegisterStandardPrimitiveEvalHelper {
public:
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl,
const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) {
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl,
const InferValueImpl &infer_value_impl, const bool is_wight_list = true) {
const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list};
RegisterStandardPrimitiveImpl(primitive, impl_reg);
}

View File

@ -27,8 +27,7 @@
namespace mindspore {
namespace ops {
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
const AbstractBasePtr &infer) {
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name);
@ -41,7 +40,7 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra
AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto value = DTypeInferValue(primitive, input_args, nullptr);
auto value = DTypeInferValue(primitive, input_args);
MS_EXCEPTION_IF_NULL(value);
auto type = value->cast<TypePtr>();
MS_EXCEPTION_IF_NULL(type);

View File

@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
if (iter == infer_map.end()) {
MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!";
}
auto infer_function = iter->second.infer_shape_dtype_impl_;
auto infer_function = iter->second.infer_shape_impl_;
return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list);
}

View File

@ -45,8 +45,7 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
return abs;
}
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
const AbstractBasePtr &infer) {
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);

View File

@ -49,17 +49,6 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
}
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args,
const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(prim);
// check
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape];
auto out_type = abs->BuildType();
MS_EXCEPTION_IF_NULL(out_type);
return TensorConstructUtils::CreateZerosTensor(out_type, out_shape);
}
} // namespace
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@ -67,6 +56,17 @@ AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
ZerosInferShape(primitive, input_args));
return abs;
}
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto abs = ZerosInfer(nullptr, prim, input_args);
// check
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape];
auto out_type = abs->BuildType();
MS_EXCEPTION_IF_NULL(out_type);
return TensorConstructUtils::CreateZerosTensor(out_type, out_shape);
}
} // namespace
REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false);
} // namespace ops
} // namespace mindspore

View File

@ -17,7 +17,7 @@
#include <vector>
#include <memory>
namespace mindspore {
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) {
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
MS_EXCEPTION_IF_NULL(type_ptr);
auto type_id = ExtractTypeId(type_ptr);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
@ -30,7 +30,7 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr
return tensor;
}
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) {
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
MS_EXCEPTION_IF_NULL(type_ptr);
auto type_id = ExtractTypeId(type_ptr);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
@ -43,7 +43,7 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr,
return tensor;
}
tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape,
tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape,
void *data) {
MS_EXCEPTION_IF_NULL(type_ptr);
auto type_id = ExtractTypeId(type_ptr);
@ -51,7 +51,7 @@ tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, con
return tensor;
}
TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) {
TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type_ptr) {
MS_EXCEPTION_IF_NULL(type_ptr);
TypeId type_id;
if (type_ptr->isa<TensorType>()) {

View File

@ -28,12 +28,12 @@ void SetTensorData(void *data, T num, size_t data_length) {
}
class TensorConstructUtils {
public:
static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape);
static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape);
static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data);
static tensor::TensorPtr CreateZerosTensor(const TypePtr &type, const std::vector<int64_t> &shape);
static tensor::TensorPtr CreateOnesTensor(const TypePtr &type, const std::vector<int64_t> &shape);
static tensor::TensorPtr CreateTensor(const TypePtr &type, const std::vector<int64_t> &shape, void *data);
private:
static TypeId ExtractTypeId(const TypePtr type);
static TypeId ExtractTypeId(const TypePtr &type);
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_

View File

@ -87,9 +87,9 @@ TEST_F(TestAbstract, TestParseDataClass) {
AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y};
auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_);
ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_);
AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list);
AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list);
ASSERT_TRUE(nullptr != new_cls);
}