forked from mindspore-Ecosystem/mindspore
parent 0189631634
commit 7481797ad2
@@ -48,8 +48,6 @@ int FillCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
                             const std::vector<KernelTensorPtr> &outputs,
                             const std::map<uint32_t, tensor::TensorPtr> &) {
  if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
    MS_LOG(WARNING) << "For '" << kernel_name_ << "' "
                    << "walk out FillCpuKernelMod::Resize with unknown shape.";
    return ret;
  }
  return KRET_OK;
@@ -102,127 +102,123 @@ static tensor::TensorPtr CreateComplexTensor(const TypePtr &type, const std::vec
  return tensor;
}

BaseShapePtr FillInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  std::vector<size_t> inputsIndex{kIndex0, kIndex1, kIndex2};
  if (input_args.size() == kIndex2) {
    inputsIndex[kIndex1] = kIndex0;
    inputsIndex[kIndex2] = kIndex1;
  }
  const auto &prim_name = primitive->name();
  if (input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTuple>()) {
    auto out_shape = GetShapeValue(primitive, input_args[inputsIndex[kIndex1]]);
    return std::make_shared<abstract::Shape>(out_shape);
class FillInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    std::vector<size_t> inputsIndex{kIndex0, kIndex1, kIndex2};
    if (input_args.size() == kIndex2) {
      inputsIndex[kIndex1] = kIndex0;
      inputsIndex[kIndex2] = kIndex1;
    }
    auto prim_name = primitive->name();
    if (input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTuple>()) {
      auto out_shape = GetShapeValue(primitive, input_args[inputsIndex[kIndex1]]);
      return std::make_shared<abstract::Shape>(out_shape);
    }

    if (!input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTensor>()) {
      MS_EXCEPTION(TypeError) << "For '" << prim_name << "', input[1] must be tensor.";
    }
    MS_EXCEPTION_IF_NULL(primitive);
    const uint32_t kInputDims = 1;
    auto shape_arg = input_args[inputsIndex[1]];
    MS_EXCEPTION_IF_NULL(shape_arg);
    if (!IsValueKnown(shape_arg->BuildValue()) && shape_arg->isa<abstract::AbstractTensor>()) {
      auto abs_tensor = shape_arg->cast<abstract::AbstractTensorPtr>();
      auto abs_tensor_shape = abs_tensor->shape()->shape();
      if (abs_tensor_shape.size() != kInputDims) {
        MS_EXCEPTION(TypeError) << "For '" << prim_name
                                << "', the shape size of 'input1' must be 1, but got: " << abs_tensor_shape.size()
                                << ".";
      }
    }
    auto input2_shape =
      CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex2]]->BuildShape())[kShape];
    if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
      MS_EXCEPTION(TypeError) << "For '" << prim_name
                              << "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
    }

    auto output_shape = GetShapeValue(primitive, shape_arg);
    return std::make_shared<abstract::Shape>(output_shape);
  }

  if (!input_args[inputsIndex[kIndex1]]->isa<abstract::AbstractTensor>()) {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name() << "', input[1] must be tensor.";
  }
  MS_EXCEPTION_IF_NULL(primitive);
  const uint32_t kInputDims = 1;
  auto input1_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex1]]->BuildShape())[kShape];
  if (input1_shape.size() != kInputDims) {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name()
                            << "', the shape size of 'input1' must be 1, but got: " << input1_shape.size() << ".";
  }
  auto input2_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[inputsIndex[kIndex2]]->BuildShape())[kShape];
  if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
    MS_EXCEPTION(TypeError) << "For '" << primitive->name()
                            << "', the shape size of 'input2' must be 0, but got: " << input2_shape.size() << ".";
  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    std::vector<size_t> inputsIndex{0, 1, 2};
    if (input_args.size() == kIndex2) {
      inputsIndex[kIndex1] = 0;
      inputsIndex[kIndex2] = 1;
    }
    MS_EXCEPTION_IF_NULL(primitive);
    auto prim_name = primitive->name();
    // check
    ValuePtr dtype_value;
    TypePtr value_dtype;
    auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
    auto input2_dtype = input_args[2]->BuildType();
    TypePtr input2_element_dtype;
    if (input2_dtype->isa<TensorType>()) {
      auto tensor_type = input2_dtype->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_type);
      input2_element_dtype = tensor_type->element();
    } else {
      input2_element_dtype = input2_dtype;
    }
    if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
      MS_EXCEPTION(TypeError) << "For '" << prim_name
                              << "', the value input only takes scalar or scalar within a tensor!";
    }
    dtype_value = input_args[0]->BuildValue();
    MS_EXCEPTION_IF_NULL(dtype_value);
    if (!dtype_value->isa<Type>()) {
      MS_EXCEPTION(TypeError)
        << "For '" << prim_name
        << "', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
           "'uint64','float16', 'float32', 'float64'], but got an invalid dtype!";
    }
    auto output_dtype = dtype_value->cast<TypePtr>();

    const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
                                           kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
    CheckAndConvertUtils::CheckSubClass("dtype", input2_element_dtype, valid_types, prim_name);
    return CheckAndConvertUtils::CheckSubClass("dtype", output_dtype, valid_types, prim_name);
  }

  auto shape_arg = input_args[inputsIndex[1]];
  MS_EXCEPTION_IF_NULL(shape_arg);
  auto output_shape = GetShapeValue(primitive, shape_arg);
  if (input_args.size() == kIndex2) {
    auto input_value = input_args[0]->BuildValue();
    output_shape = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, prim_name);
  ValuePtr InferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(prim);
    const int64_t input_num = 2;
    CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim->name());
    auto infered_type = InferType(prim, input_args);
    MS_EXCEPTION_IF_NULL(infered_type);
    auto input_value_ptr = input_args[2]->BuildValue();
    auto input_value_type_id = input_args[2]->BuildType()->type_id();
    auto tmp_shape = InferShape(prim, input_args);
    MS_EXCEPTION_IF_NULL(tmp_shape);
    auto infered_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tmp_shape)[kShape];
    auto input_value_tensor = input_value_ptr->cast<tensor::TensorPtr>();
    tensor::TensorPtr infer_result;
    if (input_value_type_id == kNumberTypeBool) {
      infer_result = CreateValuedTensor<bool>(infered_type, infered_shape, GetValue<bool>(input_value_ptr));
    } else if (input_value_type_id == kNumberTypeFloat32) {
      infer_result = CreateValuedTensor<float>(infered_type, infered_shape, GetValue<float>(input_value_ptr));
    } else if (input_value_type_id == kNumberTypeInt32) {
      infer_result = CreateValuedTensor<int32_t>(infered_type, infered_shape, GetValue<int32_t>(input_value_ptr));
    } else if (input_value_type_id == kNumberTypeInt64) {
      infer_result = CreateValuedTensor<int64_t>(infered_type, infered_shape, GetValue<int64_t>(input_value_ptr));
    } else if (input_value_type_id == kNumberTypeComplex64) {
      infer_result = CreateComplexTensor<std::complex<float>>(
        infered_type, infered_shape, static_cast<std::complex<float> *>(input_value_tensor->data_c())[0]);
    } else if (input_value_type_id == kNumberTypeComplex128) {
      infer_result = CreateComplexTensor<std::complex<double>>(
        infered_type, infered_shape, static_cast<std::complex<double> *>(input_value_tensor->data_c())[0]);
    }
    return infer_result;
  }
  return std::make_shared<abstract::Shape>(output_shape);
}

TypePtr FillInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  std::vector<size_t> inputsIndex{0, 1, 2};
  if (input_args.size() == kIndex2) {
    inputsIndex[kIndex1] = 0;
    inputsIndex[kIndex2] = 1;
  }
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &prim_name = primitive->name();
  // check
  ValuePtr dtype_value;
  TypePtr value_dtype;
  auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
  auto input2_dtype = input_args[2]->BuildType();
  TypePtr input2_element_dtype;
  if (input2_dtype->isa<TensorType>()) {
    auto tensor_type = input2_dtype->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_type);
    input2_element_dtype = tensor_type->element();
  } else {
    input2_element_dtype = input2_dtype;
  }
  if (input2_shape.size() > 1 || (input2_shape.size() == 1 && input2_shape[0] > 1)) {
    MS_EXCEPTION(TypeError) << "For '" << prim_name
                            << "', the value input only takes scalar or scalar within a tensor!";
  }
  dtype_value = input_args[0]->BuildValue();
  MS_EXCEPTION_IF_NULL(dtype_value);
  if (!dtype_value->isa<Type>()) {
    MS_EXCEPTION(TypeError)
      << "For '" << prim_name
      << "', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16','uint32', "
         "'uint64','float16', 'float32', 'float64'], but got an invalid dtype!";
  }
  auto output_dtype = dtype_value->cast<TypePtr>();
  std::set<int64_t> GetValueDependArgIndices() const override { return {0, 2}; }
};

  const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
                                         kUInt32, kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
  CheckAndConvertUtils::CheckSubClass("dtype", input2_element_dtype, valid_types, prim_name);
  return CheckAndConvertUtils::CheckSubClass("dtype", output_dtype, valid_types, prim_name);
}

AbstractBasePtr FillInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
  auto type = FillInferType(primitive, input_args);
  auto shape = FillInferShape(primitive, input_args);
  return MakeAbstract(shape, type);
}

ValuePtr FillInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim->name());
  auto infered_type = FillInferType(prim, input_args);
  MS_EXCEPTION_IF_NULL(infered_type);
  auto input_value_ptr = input_args[2]->BuildValue();
  auto input_value_type_id = input_args[2]->BuildType()->type_id();
  auto tmp_shape = FillInferShape(prim, input_args);
  if (tmp_shape->IsDynamic()) {
    return kAnyValue;
  }
  MS_EXCEPTION_IF_NULL(tmp_shape);
  auto infered_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(tmp_shape)[kShape];
  auto input_value_tensor = input_value_ptr->cast<tensor::TensorPtr>();
  tensor::TensorPtr infer_result;
  if (input_value_type_id == kNumberTypeBool) {
    infer_result = CreateValuedTensor<bool>(infered_type, infered_shape, GetValue<bool>(input_value_ptr));
  } else if (input_value_type_id == kNumberTypeFloat32) {
    infer_result = CreateValuedTensor<float>(infered_type, infered_shape, GetValue<float>(input_value_ptr));
  } else if (input_value_type_id == kNumberTypeInt32) {
    infer_result = CreateValuedTensor<int32_t>(infered_type, infered_shape, GetValue<int32_t>(input_value_ptr));
  } else if (input_value_type_id == kNumberTypeInt64) {
    infer_result = CreateValuedTensor<int64_t>(infered_type, infered_shape, GetValue<int64_t>(input_value_ptr));
  } else if (input_value_type_id == kNumberTypeComplex64) {
    infer_result = CreateComplexTensor<std::complex<float>>(
      infered_type, infered_shape, static_cast<std::complex<float> *>(input_value_tensor->data_c())[0]);
  } else if (input_value_type_id == kNumberTypeComplex128) {
    infer_result = CreateComplexTensor<std::complex<double>>(
      infered_type, infered_shape, static_cast<std::complex<double> *>(input_value_tensor->data_c())[0]);
  }
  return infer_result;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Fill, prim::kPrimFill, FillInfer, FillInferValue, false);
REGISTER_PRIMITIVE_OP_INFER_IMPL(Fill, prim::kPrimFill, FillInfer, true);
} // namespace ops
} // namespace mindspore
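To make the value-inference dispatch above concrete, here is an illustrative numpy analogue of folding Fill to a constant by branching on the value's type; the helper and table names (fill_constant, NP_BY_TYPE) are hypothetical and not MindSpore APIs, just a sketch of the same idea.

import numpy as np

# Illustrative analogue of FillInfer::InferValue: pick a numpy dtype from the
# declared type, then broadcast the scalar fill value to the inferred shape,
# as the CreateValuedTensor/CreateComplexTensor branches do in the C++ above.
NP_BY_TYPE = {
    "bool": np.bool_, "int32": np.int32, "int64": np.int64,
    "float32": np.float32, "complex64": np.complex64, "complex128": np.complex128,
}

def fill_constant(type_name, shape, value):
    np_dtype = NP_BY_TYPE.get(type_name)
    if np_dtype is None:
        return None  # unsupported type: no constant result, like the untouched infer_result above
    return np.full(tuple(shape), value, np_dtype)

print(fill_constant("float32", (2, 3), 0.5))
print(fill_constant("complex64", (2,), 1 + 2j))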
@@ -1438,7 +1438,7 @@ class MatrixBandPart(Primitive):
        self.init_prim_io_names(inputs=['x', 'lower', 'upper'], outputs=['y'])


class Fill(Primitive):
class Fill(PrimitiveWithCheck):
    """
    Create a Tensor of the specified shape and fill it with the specified value.

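For context, a minimal usage sketch of the Fill primitive touched by this diff, assuming the public mindspore.ops.Fill interface with a dtype, a shape tuple, and a scalar fill value:

import mindspore
import numpy as np
from mindspore import Tensor, ops

# Fill takes the target dtype, the output shape and a scalar fill value.
fill = ops.Fill()
print(fill(mindspore.float32, (2, 3), 0.5))

# Per the checks added below, the value may also be a scalar wrapped in a tensor.
print(fill(mindspore.int32, (4,), Tensor(np.array(7, dtype=np.int32))))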
@@ -1465,6 +1465,34 @@ class Fill(Primitive):
        """Initialize Fill"""
        self.init_prim_io_names(inputs=['type', 'shape', 'value'], outputs=['y'])

    def __call__(self, dtype, dims, x):
        if dtype not in mstype.all_types and dtype not in [mstype.uint16, mstype.uint32, mstype.uint64]:
            raise TypeError(
                f"For \'{self.name}\', the supported data type is ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', "
                "'uint16', 'uint32', 'uint64','float16', 'float32', 'float64'], but got an invalid dtype!.")
        x_nptype = mstype.dtype_to_nptype(dtype)
        if not isinstance(dims, Tensor) and not isinstance(dims, tuple):
            raise TypeError(f"For \'{self.name}\', input[1] must be tensor.")
        if not isinstance(x, Tensor) and not isinstance(x, float) and not isinstance(x, int):
            raise TypeError(f"For \'{self.name}\', the value input only takes scalar or scalar within a tensor!.")
        if isinstance(dims, Tensor):
            dims = dims.asnumpy()
        if isinstance(x, Tensor):
            x = x.asnumpy()
        ret = np.full(dims, x, x_nptype)
        return Tensor(ret)

    def infer_value(self, dtype, dims, x):
        x_nptype = mstype.dtype_to_nptype(dtype)
        if dims is not None and None not in dims and x is not None:
            if isinstance(dims, Tensor):
                dims = dims.asnumpy()
            if isinstance(x, Tensor):
                x = x.asnumpy()
            ret = np.full(dims, x, x_nptype)
            return Tensor(ret)
        return None
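The infer_value hook added above lets Fill be reduced to a constant during compilation when both the shape and the value are statically known; a small standalone sketch of that decision follows (plain numpy, with a hypothetical helper name fold_fill, not MindSpore API).

import numpy as np

# Hypothetical helper mirroring Fill.infer_value: fold to a constant only when
# the shape tuple and the fill value are fully known, otherwise defer to runtime.
def fold_fill(np_dtype, dims, value):
    if dims is None or value is None or None in tuple(dims):
        return None  # unknown shape or value: leave the op for runtime
    return np.full(tuple(dims), value, np_dtype)

print(fold_fill(np.float32, (2, 3), 1.0))     # folded: a 2x3 array of 1.0
print(fold_fill(np.float32, (2, None), 1.0))  # None: dynamic dimension, no folding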


class Fills(Primitive):
    """