Revert "support ms function infer value"

This reverts commit 52f94606
This commit is contained in:
lei-yuanzhe 2023-02-16 16:16:29 +08:00
parent 0189631634
commit 7481797ad2
3 changed files with 140 additions and 118 deletions

View File

@ -48,8 +48,6 @@ int FillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
MS_LOG(WARNING) << "For '" << kernel_name_ << "' "
<< "walk out FillCpuKernelMod::Resize with unknown shape.";
return ret;
}
return KRET_OK;

View File

@ -102,127 +102,123 @@ static tensor::TensorPtr CreateComplexTensor(const TypePtr &type, const std::vec
return tensor;
}
BaseShapePtr FillInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
std::vector<size_t> inputsIndex{kIndex0, kIndex1, kIndex2};
if (input_args.size() == kIndex2) {
inputsIndex[kIndex1] = kIndex0;
inputsIndex[kIndex2] = kIndex1;
}
const auto &prim_name = primitive->name();
if (input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTuple>()) {
auto out_shape = GetShapeValue(primitive, input_args[inputsIndex[kIndex1]]);
return std::make_shared<abstract::Shape>(out_shape);
class FillInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
std::vector<size_t> inputsIndex{kIndex0, kIndex1, kIndex2};
if (input_args.size() == kIndex2) {
inputsIndex[kIndex1] = kIndex0;
inputsIndex[kIndex2] = kIndex1;
}
auto prim_name = primitive->name();
if (input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTuple>()) {
auto out_shape = GetShapeValue(primitive, input_args[inputsIndex[kIndex1]]);
return std::make_shared<abstract::Shape>(out_shape);
}
if (!input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name << "', input[1] must be tensor.";
}
MS_EXCEPTION_IF_NULL(primitive);
const uint32_t kInputDims = 1;
auto shape_arg = input_args[inputsIndex[1]];
MS_EXCEPTION_IF_NULL(shape_arg);
if (!IsValueKnown(shape_arg->BuildValue()) && shape_arg->isa<abstract::AbstractTensor>()) {
auto abs_tensor = shape_arg->cast<abstract::AbstractTensorPtr>();
auto abs_tensor_shape = abs_tensor->shape()->shape();
if (abs_tensor_shape.size() != kInputDims) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the shape size of 'input1' must be 1, but got: " << abs_tensor_shape.size()
<< ".";
}
}
auto input2_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex2]]->BuildShape())[kShape];
if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
}
auto output_shape = GetShapeValue(primitive, shape_arg);
return std::make_shared<abstract::Shape>(output_shape);
}
if (!input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTensor>()) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', input[1] must be tensor.";
}
MS_EXCEPTION_IF_NULL(primitive);
const uint32_t kInputDims = 1;
auto input1_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex1]]->BuildShape())[kShape];
if (input1_shape.size() != kInputDims) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
<< "', the shape size of 'input1' must be 1, but got: " << input1_shape.size() << ".";
}
auto input2_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex2]]->BuildShape())[kShape];
if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
MS_EXCEPTION(TypeError) << "For '" << primitive->name()
<< "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
std::vector<size_t> inputsIndex{0, 1, 2};
if (input_args.size() == kIndex2) {
inputsIndex[kIndex1] = 0;
inputsIndex[kIndex2] = 1;
}
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
// check
ValuePtr dtype_value;
TypePtr value_dtype;
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto input2_dtype = input_args[2]->BuildType();
TypePtr input2_element_dtype;
if (input2_dtype->isa<TensorType>()) {
auto tensor_type = input2_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
input2_element_dtype = tensor_type->element();
} else {
input2_element_dtype = input2_dtype;
}
if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the value input only takes scalar or scalar within a tensor!";
}
dtype_value = input_args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(dtype_value);
if (!dtype_value->isa<Type>()) {
MS_EXCEPTION(TypeError)
<< "For '" << prim_name
<< "', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
"'uint64','float16', 'float32', 'float64'], but got an invalid dtype!";
}
auto output_dtype = dtype_value->cast<TypePtr>();
const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
CheckAndConvertUtils::CheckSubClass("dtype", input2_element_dtype, valid_types, prim_name);
return CheckAndConvertUtils::CheckSubClass("dtype", output_dtype, valid_types, prim_name);
}
auto shape_arg = input_args[inputsIndex[1]];
MS_EXCEPTION_IF_NULL(shape_arg);
auto output_shape = GetShapeValue(primitive, shape_arg);
if (input_args.size() == kIndex2) {
auto input_value = input_args[0]->BuildValue();
output_shape = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, prim_name);
ValuePtr InferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(prim);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim->name());
auto infered_type = InferType(prim, input_args);
MS_EXCEPTION_IF_NULL(infered_type);
auto input_value_ptr = input_args[2]->BuildValue();
auto input_value_type_id = input_args[2]->BuildType()->type_id();
auto tmp_shape = InferShape(prim, input_args);
MS_EXCEPTION_IF_NULL(tmp_shape);
auto infered_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tmp_shape)[kShape];
auto input_value_tensor = input_value_ptr->cast<tensor::TensorPtr>();
tensor::TensorPtr infer_result;
if (input_value_type_id == kNumberTypeBool) {
infer_result = CreateValuedTensor<bool>(infered_type, infered_shape, GetValue<bool>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeFloat32) {
infer_result = CreateValuedTensor<float>(infered_type, infered_shape, GetValue<float>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeInt32) {
infer_result = CreateValuedTensor<int32_t>(infered_type, infered_shape, GetValue<int32_t>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeInt64) {
infer_result = CreateValuedTensor<int64_t>(infered_type, infered_shape, GetValue<int64_t>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeComplex64) {
infer_result = CreateComplexTensor<std::complex<float>>(
infered_type, infered_shape, static_cast<std::complex<float> *>(input_value_tensor->data_c())[0]);
} else if (input_value_type_id == kNumberTypeComplex128) {
infer_result = CreateComplexTensor<std::complex<double>>(
infered_type, infered_shape, static_cast<std::complex<double> *>(input_value_tensor->data_c())[0]);
}
return infer_result;
}
return std::make_shared<abstract::Shape>(output_shape);
}
TypePtr FillInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
std::vector<size_t> inputsIndex{0, 1, 2};
if (input_args.size() == kIndex2) {
inputsIndex[kIndex1] = 0;
inputsIndex[kIndex2] = 1;
}
MS_EXCEPTION_IF_NULL(primitive);
const auto &prim_name = primitive->name();
// check
ValuePtr dtype_value;
TypePtr value_dtype;
auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto input2_dtype = input_args[2]->BuildType();
TypePtr input2_element_dtype;
if (input2_dtype->isa<TensorType>()) {
auto tensor_type = input2_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
input2_element_dtype = tensor_type->element();
} else {
input2_element_dtype = input2_dtype;
}
if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the value input only takes scalar or scalar within a tensor!";
}
dtype_value = input_args[0]->BuildValue();
MS_EXCEPTION_IF_NULL(dtype_value);
if (!dtype_value->isa<Type>()) {
MS_EXCEPTION(TypeError)
<< "For '" << prim_name
<< "', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
"'uint64','float16', 'float32', 'float64'], but got an invalid dtype!";
}
auto output_dtype = dtype_value->cast<TypePtr>();
std::set<int64_t> GetValueDependArgIndices() const override { return {0, 2}; }
};
const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
CheckAndConvertUtils::CheckSubClass("dtype", input2_element_dtype, valid_types, prim_name);
return CheckAndConvertUtils::CheckSubClass("dtype", output_dtype, valid_types, prim_name);
}
AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto type = FillInferType(primitive, input_args);
auto shape = FillInferShape(primitive, input_args);
return MakeAbstract(shape, type);
}
ValuePtr FillInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim->name());
auto infered_type = FillInferType(prim, input_args);
MS_EXCEPTION_IF_NULL(infered_type);
auto input_value_ptr = input_args[2]->BuildValue();
auto input_value_type_id = input_args[2]->BuildType()->type_id();
auto tmp_shape = FillInferShape(prim, input_args);
if (tmp_shape->IsDynamic()) {
return kAnyValue;
}
MS_EXCEPTION_IF_NULL(tmp_shape);
auto infered_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tmp_shape)[kShape];
auto input_value_tensor = input_value_ptr->cast<tensor::TensorPtr>();
tensor::TensorPtr infer_result;
if (input_value_type_id == kNumberTypeBool) {
infer_result = CreateValuedTensor<bool>(infered_type, infered_shape, GetValue<bool>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeFloat32) {
infer_result = CreateValuedTensor<float>(infered_type, infered_shape, GetValue<float>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeInt32) {
infer_result = CreateValuedTensor<int32_t>(infered_type, infered_shape, GetValue<int32_t>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeInt64) {
infer_result = CreateValuedTensor<int64_t>(infered_type, infered_shape, GetValue<int64_t>(input_value_ptr));
} else if (input_value_type_id == kNumberTypeComplex64) {
infer_result = CreateComplexTensor<std::complex<float>>(
infered_type, infered_shape, static_cast<std::complex<float> *>(input_value_tensor->data_c())[0]);
} else if (input_value_type_id == kNumberTypeComplex128) {
infer_result = CreateComplexTensor<std::complex<double>>(
infered_type, infered_shape, static_cast<std::complex<double> *>(input_value_tensor->data_c())[0]);
}
return infer_result;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer, FillInferValue, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(Fill, prim::kPrimFill, FillInfer, true);
} // namespace ops
} // namespace mindspore

View File

@ -1438,7 +1438,7 @@ class MatrixBandPart(Primitive):
self.init_prim_io_names(inputs=['x', 'lower', 'upper'], outputs=['y'])
class Fill(Primitive):
class Fill(PrimitiveWithCheck):
"""
Create a Tensor of the specified shape and fill it with the specified value.
@ -1465,6 +1465,34 @@ class Fill(Primitive):
"""Initialize Fill"""
self.init_prim_io_names(inputs=['type', 'shape', 'value'], outputs=['y'])
def __call__(self, dtype, dims, x):
if dtype not in mstype.all_types and dtype not in [mstype.uint16, mstype.uint32, mstype.uint64]:
raise TypeError(
f"For \'{self.name}\', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', "
"'uint16', 'uint32', 'uint64','float16', 'float32', 'float64'], but got an invalid dtype!.")
x_nptype = mstype.dtype_to_nptype(dtype)
if not isinstance(dims, Tensor) and not isinstance(dims, tuple):
raise TypeError(f"For \'{self.name}\', input[1] must be tensor.")
if not isinstance(x, Tensor) and not isinstance(x, float) and not isinstance(x, int):
raise TypeError(f"For \'{self.name}\', the value input only takes scalar or scalar within a tensor!.")
if isinstance(dims, Tensor):
dims = dims.asnumpy()
if isinstance(x, Tensor):
x = x.asnumpy()
ret = np.full(dims, x, x_nptype)
return Tensor(ret)
def infer_value(self, dtype, dims, x):
x_nptype = mstype.dtype_to_nptype(dtype)
if dims is not None and None not in dims and x is not None:
if isinstance(dims, Tensor):
dims = dims.asnumpy()
if isinstance(x, Tensor):
x = x.asnumpy()
ret = np.full(dims, x, x_nptype)
return Tensor(ret)
return None
class Fills(Primitive):
"""