forked from mindspore-Ecosystem/mindspore
!15431 change infer value
From: @lianliguang Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
2e43434221
|
@ -924,16 +924,16 @@ AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrLis
|
|||
auto ret = prim_eval_implement_map.find(prim);
|
||||
if (ret != prim_eval_implement_map.end()) {
|
||||
// fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
|
||||
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_dtype_impl_);
|
||||
MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
|
||||
auto infer_spec_list = RectifyAbstract(prim, args_spec_list);
|
||||
return ret->second.infer_shape_dtype_impl_(nullptr, prim, infer_spec_list);
|
||||
return ret->second.infer_shape_impl_(nullptr, prim, infer_spec_list);
|
||||
} else {
|
||||
// if the infer function has been not founded in the front infer map find it in the backend infer map instead
|
||||
auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
|
||||
auto ret_backend = prim_backend_eval_impl_map.find(prim);
|
||||
if (ret_backend != prim_backend_eval_impl_map.end()) {
|
||||
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_dtype_impl_);
|
||||
return ret_backend->second.infer_shape_dtype_impl_(nullptr, prim, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
|
||||
return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list);
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
|
||||
|
|
|
@ -576,7 +576,7 @@ EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &en
|
|||
prim_py->RunCheck(py_args);
|
||||
|
||||
prim_->BeginRecordAddAttr();
|
||||
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
|
||||
AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
|
||||
prim_->EndRecordAddAttr();
|
||||
auto added_attrs = prim_->evaluate_added_attrs();
|
||||
|
||||
|
@ -602,16 +602,21 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
|
||||
AbstractBasePtr abs_base = nullptr;
|
||||
ValuePtr value = nullptr;
|
||||
prim_->BeginRecordAddAttr();
|
||||
AbstractBasePtr abs_base = eval_impl_.infer_shape_dtype_impl_(engine, prim_, args);
|
||||
prim_->EndRecordAddAttr();
|
||||
auto added_attrs = prim_->evaluate_added_attrs();
|
||||
if (need_infer_value) {
|
||||
if (eval_impl_.infer_value_func_ != nullptr) {
|
||||
auto value = eval_impl_.infer_value_func_(prim_, args, abs_base);
|
||||
abs_base->set_value(value);
|
||||
if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) {
|
||||
value = eval_impl_.infer_value_impl_(prim_, args);
|
||||
if (value != nullptr) {
|
||||
abs_base = value->ToAbstract();
|
||||
prim_->EndRecordAddAttr();
|
||||
auto added_attrs = prim_->evaluate_added_attrs();
|
||||
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||
}
|
||||
}
|
||||
abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
|
||||
prim_->EndRecordAddAttr();
|
||||
auto added_attrs = prim_->evaluate_added_attrs();
|
||||
auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||
return eval_result;
|
||||
}
|
||||
|
|
|
@ -359,7 +359,7 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
|
||||
// find prim infer function in the prim function map return a standard evaluator
|
||||
auto eval_impl = GetPrimitiveInferImpl(prim);
|
||||
if (eval_impl.infer_shape_dtype_impl_ != nullptr) {
|
||||
if (eval_impl.infer_shape_impl_ != nullptr) {
|
||||
return std::make_shared<StandardPrimEvaluator>(prim, eval_impl);
|
||||
}
|
||||
|
||||
|
|
|
@ -28,13 +28,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
using InferShapeAndTypeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &);
|
||||
using InferValueEvalImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &, const AbstractBasePtr &);
|
||||
using InferShapeImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &);
|
||||
using InferValueImpl = ValuePtr (*)(const PrimitivePtr &, const AbstractBasePtrList &);
|
||||
|
||||
struct StandardPrimitiveImplReg {
|
||||
InferShapeAndTypeImpl infer_shape_dtype_impl_; // Implement function of Primitive
|
||||
InferValueEvalImpl infer_value_func_; // infer value of primitive
|
||||
InferShapeImpl infer_shape_impl_; // infer shape and type for ops
|
||||
InferValueImpl infer_value_impl_; // infer value for ops
|
||||
// in_white_list_ is true means this primitive can be executed by vm backend
|
||||
// else will be optimized by frontend
|
||||
bool in_white_list_;
|
||||
|
@ -55,8 +55,8 @@ void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const Standard
|
|||
|
||||
class RegisterStandardPrimitiveEvalHelper {
|
||||
public:
|
||||
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeAndTypeImpl &infer_impl,
|
||||
const InferValueEvalImpl &infer_value_impl, const bool is_wight_list = true) {
|
||||
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const InferShapeImpl &infer_impl,
|
||||
const InferValueImpl &infer_value_impl, const bool is_wight_list = true) {
|
||||
const StandardPrimitiveImplReg impl_reg{infer_impl, infer_value_impl, is_wight_list};
|
||||
RegisterStandardPrimitiveImpl(primitive, impl_reg);
|
||||
}
|
||||
|
|
|
@ -27,8 +27,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||
const AbstractBasePtr &infer) {
|
||||
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name);
|
||||
|
@ -41,7 +40,7 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra
|
|||
|
||||
AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto value = DTypeInferValue(primitive, input_args, nullptr);
|
||||
auto value = DTypeInferValue(primitive, input_args);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto type = value->cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
|
|
@ -32,7 +32,7 @@ AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
|
|||
if (iter == infer_map.end()) {
|
||||
MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << "infer function in the infer map!";
|
||||
}
|
||||
auto infer_function = iter->second.infer_shape_dtype_impl_;
|
||||
auto infer_function = iter->second.infer_shape_impl_;
|
||||
return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list);
|
||||
}
|
||||
|
||||
|
|
|
@ -45,8 +45,7 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
return abs;
|
||||
}
|
||||
|
||||
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||
const AbstractBasePtr &infer) {
|
||||
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);
|
||||
|
|
|
@ -49,17 +49,6 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
|
|||
kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
|
||||
return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
|
||||
}
|
||||
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args,
|
||||
const abstract::AbstractBasePtr &abs) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
// check
|
||||
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape];
|
||||
auto out_type = abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(out_type);
|
||||
return TensorConstructUtils::CreateZerosTensor(out_type, out_shape);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -67,6 +56,17 @@ AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
ZerosInferShape(primitive, input_args));
|
||||
return abs;
|
||||
}
|
||||
|
||||
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto abs = ZerosInfer(nullptr, prim, input_args);
|
||||
// check
|
||||
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(abs->BuildShape())[kShape];
|
||||
auto out_type = abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(out_type);
|
||||
return TensorConstructUtils::CreateZerosTensor(out_type, out_shape);
|
||||
}
|
||||
} // namespace
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
namespace mindspore {
|
||||
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) {
|
||||
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
auto type_id = ExtractTypeId(type_ptr);
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
|
||||
|
@ -30,7 +30,7 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr
|
|||
return tensor;
|
||||
}
|
||||
|
||||
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) {
|
||||
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape) {
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
auto type_id = ExtractTypeId(type_ptr);
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
|
||||
|
@ -43,7 +43,7 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr,
|
|||
return tensor;
|
||||
}
|
||||
|
||||
tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape,
|
||||
tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr &type_ptr, const std::vector<int64_t> &shape,
|
||||
void *data) {
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
auto type_id = ExtractTypeId(type_ptr);
|
||||
|
@ -51,7 +51,7 @@ tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, con
|
|||
return tensor;
|
||||
}
|
||||
|
||||
TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) {
|
||||
TypeId TensorConstructUtils::ExtractTypeId(const TypePtr &type_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||
TypeId type_id;
|
||||
if (type_ptr->isa<TensorType>()) {
|
||||
|
|
|
@ -28,12 +28,12 @@ void SetTensorData(void *data, T num, size_t data_length) {
|
|||
}
|
||||
class TensorConstructUtils {
|
||||
public:
|
||||
static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape);
|
||||
static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape);
|
||||
static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data);
|
||||
static tensor::TensorPtr CreateZerosTensor(const TypePtr &type, const std::vector<int64_t> &shape);
|
||||
static tensor::TensorPtr CreateOnesTensor(const TypePtr &type, const std::vector<int64_t> &shape);
|
||||
static tensor::TensorPtr CreateTensor(const TypePtr &type, const std::vector<int64_t> &shape, void *data);
|
||||
|
||||
private:
|
||||
static TypeId ExtractTypeId(const TypePtr type);
|
||||
static TypeId ExtractTypeId(const TypePtr &type);
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
|
||||
|
|
|
@ -87,9 +87,9 @@ TEST_F(TestAbstract, TestParseDataClass) {
|
|||
AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y};
|
||||
|
||||
auto eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord);
|
||||
ASSERT_TRUE(nullptr != eval_impl.infer_shape_dtype_impl_);
|
||||
ASSERT_TRUE(nullptr != eval_impl.infer_shape_impl_);
|
||||
|
||||
AbstractBasePtr new_cls = eval_impl.infer_shape_dtype_impl_(nullptr, prim::kPrimMakeRecord, args_list);
|
||||
AbstractBasePtr new_cls = eval_impl.infer_shape_impl_(nullptr, prim::kPrimMakeRecord, args_list);
|
||||
ASSERT_TRUE(nullptr != new_cls);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue