diff --git a/docs/api/api_python/mindspore.ops.primitive.rst b/docs/api/api_python/mindspore.ops.primitive.rst index 01ab0cb724a..580b04986fb 100644 --- a/docs/api/api_python/mindspore.ops.primitive.rst +++ b/docs/api/api_python/mindspore.ops.primitive.rst @@ -493,6 +493,7 @@ Array操作 mindspore.ops.Expand mindspore.ops.ExpandDims mindspore.ops.FFTWithSize + mindspore.ops.FillV2 mindspore.ops.FloatStatus mindspore.ops.FillDiagonal mindspore.ops.Gather diff --git a/docs/api/api_python/ops/mindspore.ops.FillV2.rst b/docs/api/api_python/ops/mindspore.ops.FillV2.rst new file mode 100644 index 00000000000..0e96b2bcc8a --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.FillV2.rst @@ -0,0 +1,22 @@ +mindspore.ops.FillV2 +==================== + +.. py:class:: mindspore.ops.FillV2 + + 创建一个Tensor,其shape由 `shape` 指定,其值则由 `value` 进行填充。 + + 输入: + - **shape** (Tensor) - 1-D Tensor,指定了输出Tensor的shape。 + 其数据类型必须是int32或int64。 + - **value** (Tensor) - 一个标量Tensor,其值用于填充输出Tensor。 + `value` 必须是0-D的,且其数据类型必须是以下之一: + bool、int8、int16、int32、int64、uint8、uint16、uint32、uint64、float16、float32、float64。 + + 输出: + - **y** (Tensor) - Tensor,其shape和值如上所述。 + + 异常: + - **ValueError** - 如果 `shape` 不是1-D Tensor。 + - **TypeError** - 如果 `shape` 的数据类型不是int32或者int64。 + - **ValueError** - 如果 `value` 不是0-D Tensor。 + - **ValueError** - 如果输出元素的数量多于1000000。 diff --git a/docs/api/api_python_en/mindspore.ops.primitive.rst b/docs/api/api_python_en/mindspore.ops.primitive.rst index 6788c988f76..92874951085 100644 --- a/docs/api/api_python_en/mindspore.ops.primitive.rst +++ b/docs/api/api_python_en/mindspore.ops.primitive.rst @@ -492,6 +492,7 @@ Array Operation mindspore.ops.Expand mindspore.ops.ExpandDims mindspore.ops.FFTWithSize + mindspore.ops.FillV2 mindspore.ops.FloatStatus mindspore.ops.FillDiagonal mindspore.ops.Gather diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/reg_ascend_vm_op_adaptation_info.h 
b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/reg_ascend_vm_op_adaptation_info.h index bed7ae1962e..3b3b3bc0aae 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/reg_ascend_vm_op_adaptation_info.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/reg_ascend_vm_op_adaptation_info.h @@ -268,7 +268,13 @@ REG_ASCEND_VM_OP_ADAPTATION_INFO(kEuclideanNormOpName) REG_ASCEND_VM_OP_ADAPTATION_INFO(kExpandDimsOpName).set_target_op_name(kExpandDimsOpName).set_input_attr_info(1, "int"); -REG_ASCEND_VM_OP_ADAPTATION_INFO(kFillOpName).set_backend_op_name(kFillDOpName); +REG_ASCEND_VM_OP_ADAPTATION_INFO(kFillOpName).set_target_op_name(kFillDOpName); + +REG_ASCEND_VM_OP_ADAPTATION_INFO(kFillV2OpName) + .set_backend_op_name(kFillOpName) + .set_target_op_name(kFillDOpName) + .set_need_tbe_check_supported(true) + .set_input_attr_info(0, "listInt"); // In hisi code, first check dynamic impl in GatherV2 REG_ASCEND_VM_OP_ADAPTATION_INFO(kGatherOpName) diff --git a/mindspore/core/ops/fill_v2.cc b/mindspore/core/ops/fill_v2.cc index 8d3461a2e66..c6ae06603f0 100644 --- a/mindspore/core/ops/fill_v2.cc +++ b/mindspore/core/ops/fill_v2.cc @@ -33,8 +33,6 @@ namespace { abstract::ShapePtr FillV2InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - auto input2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; auto max_length_ptr = primitive->GetAttr("max_length"); MS_EXCEPTION_IF_NULL(max_length_ptr); @@ -42,26 +40,41 @@ abstract::ShapePtr FillV2InferShape(const PrimitivePtr &primitive, const std::ve const int64_t kDimOne = 1; const int64_t kDimZero = 0; - CheckAndConvertUtils::CheckInteger("rank of shape", SizeToLong(input1_shape.size()), kEqual, kDimOne, prim_name); + auto input2_shape = 
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; + if (!IsDynamic(input2_shape)) { CheckAndConvertUtils::CheckInteger("rank of value", SizeToLong(input2_shape.size()), kEqual, kDimZero, prim_name); } - if (input_args[kInputIndex0]->isa() && - input_args[kInputIndex0]->BuildValue()->isa()) { - auto value_ptr = input_args[kInputIndex0]->BuildValue(); - MS_EXCEPTION_IF_NULL(value_ptr); - auto output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", value_ptr, prim_name); - for (size_t i = 0; i < output_shape.size(); ++i) { - CheckAndConvertUtils::CheckInteger("the " + std::to_string(i) + "th dimension of input shape", output_shape[i], - kGreaterThan, kDimZero, prim_name); - } - CheckAndConvertUtils::CheckInteger("the number of elements of output", SizeToLong(SizeOf(output_shape)), kLessEqual, - max_length, prim_name); - return std::make_shared(output_shape); - } else { - return std::make_shared(std::vector{-2}); + auto input1_type = input_args[kInputIndex0]->BuildType(); + auto value_ptr = input_args[kInputIndex0]->BuildValue(); + MS_EXCEPTION_IF_NULL(value_ptr); + + if (!IsValueKnown(value_ptr)) { + return std::make_shared(ShapeVector{abstract::Shape::kShapeRankAny}); } + + ShapeVector output_shape{}; + if (input1_type->isa()) { + auto input1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + CheckAndConvertUtils::CheckInteger("rank of shape", SizeToLong(input1_shape.size()), kEqual, kDimOne, prim_name); + output_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", value_ptr, prim_name); + } else if (IsIdentidityOrSubclass(input1_type, kTuple)) { + output_shape = CheckAndConvertUtils::CheckTupleInt("shape", value_ptr, prim_name); + } else { + MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the `shape` " + << " must be a tuple or tensor with all Int elements, but got " << value_ptr->type_name() + << "."; + } + + for (size_t i = 0; i < output_shape.size(); 
++i) { + CheckAndConvertUtils::CheckInteger("the " + std::to_string(i) + "th dimension of input shape", output_shape[i], + kGreaterThan, kDimZero, prim_name); + } + CheckAndConvertUtils::CheckInteger("the number of elements of output", SizeToLong(SizeOf(output_shape)), kLessEqual, + max_length, prim_name); + + return std::make_shared(output_shape); } TypePtr FillV2InferType(const PrimitivePtr &primitive, const std::vector &input_args) { @@ -71,8 +84,10 @@ TypePtr FillV2InferType(const PrimitivePtr &primitive, const std::vectorBuildType(); // Check the data type of the first input - const std::set input1_valid_types = {kInt32, kInt64}; - (void)CheckAndConvertUtils::CheckTensorTypeValid("input1 datatype", input1_type, input1_valid_types, prim_name); + if (input1_type->isa()) { + const std::set input1_valid_types = {kInt32, kInt64}; + (void)CheckAndConvertUtils::CheckTensorTypeValid("input1 datatype", input1_type, input1_valid_types, prim_name); + } // Check the data type of the second input and infer the data type of the output from the second input (void)CheckAndConvertUtils::CheckTensorTypeValid("output datatype", input2_type, common_valid_types_with_complex_and_bool, prim_name); diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py index 5429bb8e5bb..67629df99b5 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_array_ops.py @@ -57,6 +57,26 @@ from mindspore.ops.operations import _grad_ops as G from mindspore import context +@bprop_getters.register(P.FillV2) +def get_bprop_fill_v2(self): + """Generate bprop for FillV2""" + sum_op = P.ReduceSum() + cast_op = P.Cast() + + def bprop(shape, value, out, dout): + dout_type = F.dtype(dout) + type_list = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, + mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, mstype.float16] + 
if dout_type in type_list: + dout = cast_op(dout, mstype.float32) + if dout_type == mstype.float64: + dout = cast_op(dout, mstype.float32) + dvalue = sum_op(dout) + return zeros_like(shape), cast_op(dvalue, dout_type) + + return bprop + + @bprop_getters.register(StridedSliceV2) def get_bprop_strided_slice_v2(self): """Generate bprop for StridedSliceV2""" diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_inner_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_inner_ops.py index fefb52b7df8..c54f29154f2 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_inner_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_inner_ops.py @@ -23,7 +23,6 @@ from mindspore.ops.operations import _grad_ops as G from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like -from mindspore.common import dtype as mstype @bprop_getters.register(inner.TensorCopySlices) @@ -82,26 +81,6 @@ def get_bprop_parallel_resize_bilinear(self): return bprop -@bprop_getters.register(inner.FillV2) -def get_bprop_fill_v2(self): - """Generate bprop for FillV2""" - sum_op = P.ReduceSum() - cast_op = P.Cast() - - def bprop(shape, value, out, dout): - dout_type = F.dtype(dout) - type_list = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, - mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, mstype.float16] - if dout_type in type_list: - dout = cast_op(dout, mstype.float32) - if dout_type == mstype.float64: - dout = cast_op(dout, mstype.float32) - dvalue = sum_op(dout) - return zeros_like(shape), cast_op(dvalue, dout_type) - - return bprop - - @bprop_getters.register(inner.ConvertToDynamic) def get_bprop_gpu_convert_to_dynamic_rank(self): """Get backprop for ConvertToDynamic.""" diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index 44865ae58bd..727e9d1db50 
100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -21,7 +21,7 @@ A collection of operators to build neural networks or to compute functions. from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter, MapUniform, DynamicAssign, PadAndShift) -from ._inner_ops import (FillV2, MatmulDDS, DSDMatmul, Cummin, ExtractImagePatches) +from ._inner_ops import (MatmulDDS, DSDMatmul, Cummin, ExtractImagePatches) from ._quant_ops import * from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, @@ -48,7 +48,7 @@ from .array_ops import (ArgMaxWithValue, ArgMinWithValue, Argmax, Argmin, BatchT MatrixDiagPartV3, MatrixDiagV3, MatrixSetDiagV3, NonZero, Expand, Col2Im, ConjugateTranspose, FillDiagonal, Fills, ResizeNearestNeighborV2, RightShift, ScatterAddWithAxis, ScatterNdMul, SegmentMean, SegmentProd, SegmentSum, SegmentMax, SegmentMin, Tril, Triu, - UniqueConsecutive, UnravelIndex) + UniqueConsecutive, UnravelIndex, FillV2) from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast, _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, diff --git a/mindspore/python/mindspore/ops/operations/_inner_ops.py b/mindspore/python/mindspore/ops/operations/_inner_ops.py index 3bf457c9ae9..d8798a051df 100755 --- a/mindspore/python/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/python/mindspore/ops/operations/_inner_ops.py @@ -50,50 +50,6 @@ string_mul = Primitive("string_mul") string_getitem = Primitive("string_getitem") -class FillV2(Primitive): - """ - Creates a tensor filled with a scalar value. - - Creates a tensor with shape described by the first argument and fills it with values in the second argument. 
- - Inputs: - - **shape** (tensor) - The specified shape of output tensor. The shape of the input1 must be 1D and - the data type of the input1 must be int32 or int64. - - **value** (tensor) - Value to fill the returned tensor. The shape of the input2 must be 0D and - the data type of the input2 must be one of the following types: - bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64. - - Outputs: - A tensor, has the same type and shape as input value. - - Raises: - ValueError: If `shape` is not a 1D tensor. - TypeError: If the data type of `shape` is not int32 or int64. - ValueError: If `value` is not a 0D tensor. - ValueError: If the number of output elements is greater than 1000000. - - Supported Platforms: - ``CPU`` - - Examples: - >>> fillV2 = ops.FillV2() - >>> output = fillV2(Tensor([2, 3], mindspore.int32), Tensor(1, mindspore.float32)) - >>> print(output) - [[1. 1. 1.] - [1. 1. 1.]] - >>> output = fillV2(Tensor([3, 3], mindspore.int64), Tensor(0, mindspore.int32)) - >>> print(output) - [[0 0 0] - [0 0 0] - [0 0 0]] - """ - - @prim_attr_register - def __init__(self): - """Initialize FillV2""" - self.add_prim_attr("max_length", 1000000) - - class ExtractImagePatches(Primitive): r""" Extracts patches from images. diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 578fd2267df..92afa9684b4 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -1501,6 +1501,49 @@ class Fills(Primitive): self.init_prim_io_names(inputs=['x', 'value'], outputs=['y']) +class FillV2(Primitive): + """ + Creates a tensor with shape described by `shape` and fills it with values in `value` . + + Inputs: + - **shape** (Tensor) - 1-D Tensor, the specified shape of output tensor. + Its dtype must be int32 or int64. + - **value** (Tensor) - A scalar tensor, the value to fill the output tensor. 
+ The shape of `value` must be 0-D and its dtype must be one of the following types: + bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64. + + Outputs: + - **y** (Tensor) - A tensor, its shape and value are described above. + + Raises: + ValueError: If `shape` is not a 1-D tensor. + TypeError: If the data type of `shape` is not int32 or int64. + ValueError: If `value` is not a 0-D tensor. + ValueError: If the number of output elements is greater than 1000000. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> fillV2 = ops.FillV2() + >>> output = fillV2(Tensor([2, 3], mindspore.int32), Tensor(1, mindspore.float32)) + >>> print(output) + [[1. 1. 1.] + [1. 1. 1.]] + >>> output = fillV2(Tensor([3, 3], mindspore.int64), Tensor(0, mindspore.int32)) + >>> print(output) + [[0 0 0] + [0 0 0] + [0 0 0]] + """ + + @prim_attr_register + def __init__(self): + """Initialize FillV2""" + self.add_prim_attr("max_length", 1000000) + self.init_prim_io_names(inputs=['shape', 'value'], outputs=['y']) + + class Ones(Primitive): r""" Creates a tensor filled with value ones. 
diff --git a/tests/st/ops/gpu/test_fill_v2_op.py b/tests/st/ops/gpu/test_fill_v2_op.py index 499087d0384..96c93d2a4b3 100644 --- a/tests/st/ops/gpu/test_fill_v2_op.py +++ b/tests/st/ops/gpu/test_fill_v2_op.py @@ -15,14 +15,14 @@ import pytest import mindspore as ms from mindspore import context, nn, Tensor -from mindspore.ops.operations import _inner_ops as inner +from mindspore.ops.operations import array_ops as op class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.op = inner.FillV2() + self.op = op.FillV2() def construct(self, shape, value): return self.op(shape, value) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index d39c429060d..1484e8c432c 100644 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -4694,7 +4694,7 @@ test_case_other_ops = [ 'desc_inputs': [Tensor(np.array([[[[0.5, 0.5, 0.5]]]], np.float32))], 'desc_bprop': [Tensor(np.array([[[[0., 0., 0.5]]]], np.float32))]}), ('FillV2', { - 'block': inner.FillV2(), + 'block': P.FillV2(), 'desc_inputs': [Tensor([2, 3], mstype.int32), Tensor(1, mstype.float32)], 'desc_bprop': [Tensor([[1, 1, 1], [1, 1, 1]], mstype.float32)]}),