Support MindSpore function infer value

This commit is contained in:
lianliguang 2023-01-13 12:01:18 +08:00
parent 90c3520b13
commit 52f946067c
7 changed files with 136 additions and 149 deletions

View File

@ -1132,16 +1132,13 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool need_infer_value = !eval_impl_.IsInWhiteList();
if (need_infer_value == false) {
need_infer_value = ((context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs);
auto value = abs->BuildValue();
return (value != nullptr && !value->isa<AnyValue>() && !value->isa<None>() &&
!value->isa<Monad>() && !value->isa<FuncGraph>());
});
}
bool need_infer_value = std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
MS_EXCEPTION_IF_NULL(abs);
auto value = abs->BuildValue();
return (value != nullptr && !value->isa<AnyValue>() && !value->isa<None>() && !value->isa<Monad>() &&
!value->isa<FuncGraph>());
});
AbstractBasePtr abs_base = nullptr;
ValuePtr value = nullptr;
prim_->BeginRecordAddAttr();

View File

@ -64,14 +64,11 @@ void ValidateOperation(const AnfNodePtr &node) {
if (prim->name() == "TensorMove") {
return;
}
if (prim->HasPyEvaluator()) {
if (prim->isa<PrimitivePy>()) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
return;
}
if (prim->prim_type() == PrimType::kPrimTypePyCheck) {
MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python inference checking method.";
return;
}
if (prim->name() == "fake_bprop") {
MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue<std::string>(prim->GetAttr("info"));
}

View File

@ -141,6 +141,7 @@ PrimShapeDependMap &GetInferDependsMap() {
static const auto &kRandomGamma = prim::kPrimRandomGamma->name();
static const auto &kAffineGrid = prim::kPrimAffineGrid->name();
static const auto &kFillV2 = prim::kPrimFillV2->name();
static const auto &kFill = prim::kPrimFill->name();
static const auto &kFractionalAvgPoolGrad = prim::kPrimFractionalAvgPoolGrad->name();
static const auto &kTranspose = prim::kPrimTranspose->name();
static const auto &kResizeLinear1D = prim::kPrimResizeLinear1D->name();
@ -215,6 +216,7 @@ PrimShapeDependMap &GetInferDependsMap() {
{prim::kPrimSparseTensorDenseMatmul->name(), ShapeSet{2}},
{kSliceGrad, ShapeSet{2, 3}},
{kFillV2, ShapeSet{0}},
{kFill, ShapeSet{0, 2}},
{kRandomCategorical, ShapeSet{1}},
{kRandomGamma, ShapeSet{0, 1}},
{kDynamicBroadcastTo, ShapeSet{1}},

View File

@ -57,9 +57,7 @@ AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
}
auto value = DTypeInferValue(primitive, input_args);
MS_EXCEPTION_IF_NULL(value);
auto type = value->cast<TypePtr>();
MS_EXCEPTION_IF_NULL(type);
return abstract::MakeAbstract(std::make_shared<abstract::NoShape>(), type);
return value->ToAbstract();
}
REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false);

View File

@ -102,128 +102,131 @@ static tensor::TensorPtr CreateComplexTensor(const TypePtr &type, const std::vec
return tensor;
}
class FillInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
std::vector<size_t> inputsIndex{kIndex0, kIndex1, kIndex2};
if (input_args.size() == kIndex2) {
inputsIndex[kIndex1] = kIndex0;
inputsIndex[kIndex2] = kIndex1;
}
auto prim_name = primitive->name();
if (input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTuple>()) {
auto out_shape = GetShapeValue(primitive, input_args[inputsIndex[kIndex1]]);
return std::make_shared<abstract::Shape>(out_shape);
}
if (!input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', input[1] must be tensor.";
}
MS_EXCEPTION_IF_NULL(primitive);
const uint32_t kInputDims = 1;
auto input1_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex1]]->BuildShape())[kShape];
if (input1_shape.size() != kInputDims) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
<< "', the shape size of 'input1' must be 1, but got: " << input1_shape.size() << ".";
}
auto input2_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex2]]->BuildShape())[kShape];
if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
<< "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
}
auto shape_arg = input_args[inputsIndex[1]];
MS_EXCEPTION_IF_NULL(shape_arg);
auto output_shape = GetShapeValue(primitive, shape_arg);
if (input_args.size() == kIndex2) {
auto input_value = input_args[0]->BuildValue();
output_shape = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, prim_name);
}
return std::make_shared<abstract::Shape>(output_shape);
BaseShapePtr FillInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
std::vector<size_t> inputsIndex{kIndex0, kIndex1, kIndex2};
if (input_args.size() == kIndex2) {
inputsIndex[kIndex1] = kIndex0;
inputsIndex[kIndex2] = kIndex1;
}
const auto &prim_name = primitive->name();
if (input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTuple>()) {
auto out_shape = GetShapeValue(primitive, input_args[inputsIndex[kIndex1]]);
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
std::vector<size_t> inputsIndex{0, 1, 2};
if (input_args.size() == kIndex2) {
inputsIndex[kIndex1] = 0;
inputsIndex[kIndex2] = 1;
}
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
ValuePtr dtype_value;
TypePtr value_dtype;
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto input2_dtype = input_args[2]->BuildType();
TypePtr input2_element_dtype;
if (input2_dtype->isa<TensorType>()) {
auto tensor_type = input2_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
input2_element_dtype = tensor_type->element();
} else {
input2_element_dtype = input2_dtype;
}
if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the value input only takes scalar or scalar within a tensor!";
}
dtype_value = input_args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(dtype_value);
if (!dtype_value->isa<Type>()) {
MS_EXCEPTION(TypeError)
<< "For '" << prim_name
<< "', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
"'uint64','float16', 'float32', 'float64'], but got an invalid dtype!";
}
auto output_dtype = dtype_value->cast<TypePtr>();
const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
CheckAndConvertUtils::CheckSubClass("dtype", input2_element_dtype, valid_types, prim_name);
return CheckAndConvertUtils::CheckSubClass("dtype", output_dtype, valid_types, prim_name);
if (!input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', input[1] must be tensor.";
}
MS_EXCEPTION_IF_NULL(primitive);
const uint32_t kInputDims = 1;
auto input1_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex1]]->BuildShape())[kShape];
if (input1_shape.size() != kInputDims) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
<< "', the shape size of 'input1' must be 1, but got: " << input1_shape.size() << ".";
}
auto input2_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex2]]->BuildShape())[kShape];
if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
<< "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
}
ValuePtr InferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(prim);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim->name());
auto infered_type = InferType(prim, input_args);
MS_EXCEPTION_IF_NULL(infered_type);
auto input_value_ptr = input_args[2]->BuildValue();
auto input_value_type_id = input_args[2]->BuildType()->type_id();
if (input_value_type_id != infered_type->type_id()) {
MS_LOG(WARNING) << "value type is not same as given dtype, value type id is " << input_value_type_id
<< " and given dtype id is " << infered_type->type_id();
}
auto tmp_shape = InferShape(prim, input_args);
MS_EXCEPTION_IF_NULL(tmp_shape);
auto infered_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tmp_shape)[kShape];
auto input_value_tensor = input_value_ptr->cast<tensor::TensorPtr>();
tensor::TensorPtr infer_result;
if (input_value_type_id == kNumberTypeBool) {
infer_result = CreateValuedTensor<bool>(infered_type, infered_shape, GetValue<bool>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeFloat32) {
infer_result = CreateValuedTensor<float>(infered_type, infered_shape, GetValue<float>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeInt32) {
infer_result = CreateValuedTensor<int32_t>(infered_type, infered_shape, GetValue<int32_t>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeInt64) {
infer_result = CreateValuedTensor<int64_t>(infered_type, infered_shape, GetValue<int64_t>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeComplex64) {
infer_result = CreateComplexTensor<std::complex<float>>(
infered_type, infered_shape, static_cast<std::complex<float> *>(input_value_tensor->data_c())[0]);
} else if (input_value_type_id == kNumberTypeComplex128) {
infer_result = CreateComplexTensor<std::complex<double>>(
infered_type, infered_shape, static_cast<std::complex<double> *>(input_value_tensor->data_c())[0]);
}
return infer_result;
auto shape_arg = input_args[inputsIndex[1]];
MS_EXCEPTION_IF_NULL(shape_arg);
auto output_shape = GetShapeValue(primitive, shape_arg);
if (input_args.size() == kIndex2) {
auto input_value = input_args[0]->BuildValue();
output_shape = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, prim_name);
}
return std::make_shared<abstract::Shape>(output_shape);
}
std::set<int64_t> GetValueDependArgIndices() const override { return {0, 2}; }
};
TypePtr FillInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
std::vector<size_t> inputsIndex{0, 1, 2};
if (input_args.size() == kIndex2) {
inputsIndex[kIndex1] = 0;
inputsIndex[kIndex2] = 1;
}
MS_EXCEPTION_IF_NULL(primitive);
const auto &prim_name = primitive->name();
// check
ValuePtr dtype_value;
TypePtr value_dtype;
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto input2_dtype = input_args[2]->BuildType();
TypePtr input2_element_dtype;
if (input2_dtype->isa<TensorType>()) {
auto tensor_type = input2_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
input2_element_dtype = tensor_type->element();
} else {
input2_element_dtype = input2_dtype;
}
if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the value input only takes scalar or scalar within a tensor!";
}
dtype_value = input_args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(dtype_value);
if (!dtype_value->isa<Type>()) {
MS_EXCEPTION(TypeError)
<< "For '" << prim_name
<< "', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
"'uint64','float16', 'float32', 'float64'], but got an invalid dtype!";
}
auto output_dtype = dtype_value->cast<TypePtr>();
REGISTER_PRIMITIVE_OP_INFER_IMPL(Fill, prim::kPrimFill, FillInfer, true);
const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
CheckAndConvertUtils::CheckSubClass("dtype", input2_element_dtype, valid_types, prim_name);
return CheckAndConvertUtils::CheckSubClass("dtype", output_dtype, valid_types, prim_name);
}
AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
  // Combine the element type and the shape inferred for Fill into one abstract.
  // Type inference runs first so dtype errors surface before shape errors.
  auto inferred_type = FillInferType(primitive, input_args);
  auto inferred_shape = FillInferShape(primitive, input_args);
  return MakeAbstract(inferred_shape, inferred_type);
}
// Constant-folds Fill at compile time: when dtype, shape and fill value are all
// statically known, builds the concrete output tensor.
// Returns kAnyValue while the shape is still dynamic (folding deferred), and a
// null ValuePtr for value types with no folding branch below.
ValuePtr FillInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim->name());
  auto infered_type = FillInferType(prim, input_args);
  MS_EXCEPTION_IF_NULL(infered_type);
  // input_args[2] carries the scalar fill value (or a scalar wrapped in a tensor).
  auto input_value_ptr = input_args[2]->BuildValue();
  MS_EXCEPTION_IF_NULL(input_value_ptr);
  auto input_value_type_id = input_args[2]->BuildType()->type_id();
  if (input_value_type_id != infered_type->type_id()) {
    MS_LOG(WARNING) << "value type is not same as given dtype, value type id is " << input_value_type_id
                    << " and given dtype id is " << infered_type->type_id();
  }
  auto tmp_shape = FillInferShape(prim, input_args);
  // Null-check must come BEFORE IsDynamic(): the original dereferenced the
  // pointer first and only checked for null afterwards.
  MS_EXCEPTION_IF_NULL(tmp_shape);
  if (tmp_shape->IsDynamic()) {
    return kAnyValue;
  }
  auto infered_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tmp_shape)[kShape];
  auto input_value_tensor = input_value_ptr->cast<tensor::TensorPtr>();
  tensor::TensorPtr infer_result;
  if (input_value_type_id == kNumberTypeBool) {
    infer_result = CreateValuedTensor<bool>(infered_type, infered_shape, GetValue<bool>(input_value_ptr));
  } else if (input_value_type_id == kNumberTypeFloat32) {
    infer_result = CreateValuedTensor<float>(infered_type, infered_shape, GetValue<float>(input_value_ptr));
  } else if (input_value_type_id == kNumberTypeInt32) {
    infer_result = CreateValuedTensor<int32_t>(infered_type, infered_shape, GetValue<int32_t>(input_value_ptr));
  } else if (input_value_type_id == kNumberTypeInt64) {
    infer_result = CreateValuedTensor<int64_t>(infered_type, infered_shape, GetValue<int64_t>(input_value_ptr));
  } else if (input_value_type_id == kNumberTypeComplex64) {
    // Complex values arrive inside a tensor; the cast above may yield null for
    // non-tensor values, so guard before touching data_c().
    MS_EXCEPTION_IF_NULL(input_value_tensor);
    infer_result = CreateComplexTensor<std::complex<float>>(
      infered_type, infered_shape, static_cast<std::complex<float> *>(input_value_tensor->data_c())[0]);
  } else if (input_value_type_id == kNumberTypeComplex128) {
    MS_EXCEPTION_IF_NULL(input_value_tensor);
    infer_result = CreateComplexTensor<std::complex<double>>(
      infered_type, infered_shape, static_cast<std::complex<double> *>(input_value_tensor->data_c())[0]);
  }
  return infer_result;
}
// Registers the C++ shape/type infer (FillInfer) and value infer (FillInferValue)
// for the Fill primitive. NOTE(review): the trailing `false` flag's meaning is not
// visible here — presumably the white-list/infer-value gate used by
// StandardPrimEvaluator; confirm against REGISTER_PRIMITIVE_EVAL_IMPL's definition.
REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer, FillInferValue, false);
} // namespace ops
} // namespace mindspore

View File

@ -1436,7 +1436,7 @@ class MatrixBandPart(Primitive):
self.init_prim_io_names(inputs=['x', 'lower', 'upper'], outputs=['y'])
class Fill(PrimitiveWithCheck):
class Fill(Primitive):
"""
Create a Tensor of the specified shape and fill it with the specified value.
@ -1463,17 +1463,6 @@ class Fill(PrimitiveWithCheck):
"""Initialize Fill"""
self.init_prim_io_names(inputs=['type', 'shape', 'value'], outputs=['y'])
def infer_value(self, dtype, dims, x):
x_nptype = mstype.dtype_to_nptype(dtype)
if dims is not None and None not in dims and x is not None:
if isinstance(dims, Tensor):
dims = dims.asnumpy()
if isinstance(x, Tensor):
x = x.asnumpy()
ret = np.full(dims, x, x_nptype)
return Tensor(ret)
return None
class Fills(Primitive):
"""

View File

@ -145,11 +145,12 @@ def vm_impl_split(self):
def vm_impl_fill(self):
"""Generate vm_impl function for Fill"""
def vm_impl(dims, x):
def vm_impl(dtype, dims, x):
x_nptype = mstype.dtype_to_nptype(dtype)
if isinstance(x, int):
ret = np.full(dims, x, np.int32)
ret = np.full(dims, x, x_nptype)
else:
ret = np.full(dims, x, np.float32)
ret = np.full(dims, x, x_nptype)
return Tensor(ret)
return vm_impl