diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py
index 330c3a456a2..1f0e10bd115 100644
--- a/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py
+++ b/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py
@@ -30,6 +30,7 @@ from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.operations.array_ops import Fills, UniqueConsecutive, Col2Im, NonZero, IndexFill, \
     TensorScatterElements
 from mindspore.ops.operations.random_ops import RandomPoisson
+from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
 from mindspore.ops.primitive import Primitive
 from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
     _raise_value_error, _vmap_clone_prim, _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis, \
@@ -841,6 +842,11 @@ def get_fill_vmap_rule(prim, axis_size):
     return vmap_rule


+@constexpr
+def to_tensor_with_type(x, type):
+    return Tensor(x, type)
+
+
 @vmap_rules_getters.register(P.FillV2)
 def get_fill_v2_vmap_rule(prim, axis_size):
     """VmapRule for `FillV2` operation."""
@@ -848,19 +854,16 @@ def get_fill_v2_vmap_rule(prim, axis_size):
     prim = Primitive(prim)

     def vmap_rule(shape_bdim, value_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, shape_bdim,
-                                                      value_bdim)
+        is_all_none, result = vmap_general_preprocess(prim, shape_bdim, value_bdim)
         if is_all_none:
             return result
+
         value_shape, shape_dim = shape_bdim
-        if isinstance(value_shape, (Tensor_, Tensor)):
-            _raise_value_error(
-                "For `P.FillV2`, when the `shape` is a tensor, VMAP is not supported!"
-            )
         if shape_dim is not None:
             _raise_value_error(
                 "The source axis of `shape` in `P.FillV2` must be None, but got {}."
                 .format(shape_dim))
+
         value, vdim = value_bdim
         value_rank = F.rank(value)
         if value_rank != 1 or vdim != 0:
@@ -869,7 +872,26 @@ def get_fill_v2_vmap_rule(prim, axis_size):
                 "can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
                 "{} with source axis: {}.".format(value_rank, vdim))
         value = F.reshape(value, (axis_size,) + (1,) * len(value_shape))
+
+        out = None
+        if isinstance(value_shape, (Tensor_, Tensor)):
+            value_shape_rank = F.rank(value_shape)
+            if value_shape_rank != 1:
+                _raise_value_error(
+                    "The `shape` in `P.FillV2` must be a 1-D tensor, thus the shape only "
+                    "can be rank: 1, but got shape rank: "
+                    "{}.".format(value_shape_rank))
+            axis_size_tensor = to_tensor_with_type((axis_size,),
+                                                   F.dtype(value_shape))
+            broad_cast_shape = F.concat((axis_size_tensor, value_shape))
+            out = DynamicBroadcastTo()(value, broad_cast_shape)
+        elif isinstance(value_shape, tuple):
+            out = P.BroadcastTo((axis_size,) + value_shape)(value)
+        else:
+            _raise_value_error(
+                f"For `P.FillV2`, the input `shape` should be a Tuple or Tensor, but got `shape`: {value_shape}."
+            )
+
         return out, 0

     return vmap_rule

diff --git a/tests/st/ops/gpu/test_fill_v2_op.py b/tests/st/ops/gpu/test_fill_v2_op.py
index a4e9c2d8ce5..be955ce78fc 100644
--- a/tests/st/ops/gpu/test_fill_v2_op.py
+++ b/tests/st/ops/gpu/test_fill_v2_op.py
@@ -44,11 +44,11 @@ def dyn_case():
     assert out.asnumpy().shape == (2, 3)


-def vmap_case():
+def cla_fillv2(shape, value):
+    return P.FillV2()(shape, value)

-    def cla_fillv2(shape, value):
-        return P.FillV2()(shape, value)

+def vmap_tuple_case():
     shape = (2, 2)
     value = Tensor([1, 2], ms.float32)
     outputs = vmap(cla_fillv2, in_axes=(None, 0), out_axes=0)(shape, value)
@@ -57,6 +57,15 @@ def vmap_case():
     assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)


+def vmap_tensor_case():
+    shape = Tensor((2, 2), ms.int32)
+    value = Tensor([1, 2], ms.float32)
+    outputs = vmap(cla_fillv2, in_axes=(None, 0), out_axes=0)(shape, value)
+
+    expect = np.array([[[1, 1], [1, 1]], [[2, 2], [2, 2]]]).astype(np.float32)
+    assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu
 @pytest.mark.env_onecard
@@ -68,7 +77,22 @@ def test_fill_v2_dyn():
     """
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
     dyn_case()
-    vmap_case()
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
     dyn_case()
-    vmap_case()
+
+
+@pytest.mark.level2
+@pytest.mark.platform_x86_gpu
+@pytest.mark.env_onecard
+def test_fill_v2_vmap():
+    """
+    Feature: test FillV2 vmap on GPU.
+    Description: inputs are static shapes.
+    Expectation: expect correct out result.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    vmap_tuple_case()
+    vmap_tensor_case()
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    vmap_tuple_case()
+    vmap_tensor_case()
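
Note (editor's illustration, not part of the patch): in both the tuple and tensor branches, the rewritten vmap rule reshapes the batched fill values to (axis_size, 1, ..., 1) and broadcasts them to (axis_size,) + shape, so each batch element becomes its own filled tensor. A minimal NumPy sketch of that computation, matching the `expect` arrays in the new tests:

    import numpy as np

    axis_size = 2                 # size of the vmapped axis
    shape = (2, 2)                # FillV2's `shape` input
    value = np.array([1., 2.])    # batched fill values: rank 1, source axis 0

    # reshape to (axis_size, 1, ..., 1), then broadcast to (axis_size,) + shape
    reshaped = value.reshape((axis_size,) + (1,) * len(shape))
    out = np.broadcast_to(reshaped, (axis_size,) + shape)

    # matches `expect` in vmap_tuple_case / vmap_tensor_case
    assert (out == np.array([[[1, 1], [1, 1]], [[2, 2], [2, 2]]])).all()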