!49641 FillV2 vmap support when shape is a tensor.

Merge pull request !49641 from haozhang/fillv2_vmap
This commit is contained in:
i-robot 2023-03-06 08:29:21 +00:00 committed by Gitee
commit 0b3a90cfc8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 58 additions and 12 deletions

View File

@ -30,6 +30,7 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations.array_ops import Fills, UniqueConsecutive, Col2Im, NonZero, IndexFill, \
TensorScatterElements
from mindspore.ops.operations.random_ops import RandomPoisson
from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
from mindspore.ops.primitive import Primitive
from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
_raise_value_error, _vmap_clone_prim, _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis, \
@ -841,6 +842,11 @@ def get_fill_vmap_rule(prim, axis_size):
return vmap_rule
@constexpr
def to_tensor_with_type(x, type):
return Tensor(x, type)
@vmap_rules_getters.register(P.FillV2)
def get_fill_v2_vmap_rule(prim, axis_size):
"""VmapRule for `FillV2` operation."""
@ -848,19 +854,16 @@ def get_fill_v2_vmap_rule(prim, axis_size):
prim = Primitive(prim)
def vmap_rule(shape_bdim, value_bdim):
is_all_none, result = vmap_general_preprocess(prim, shape_bdim,
value_bdim)
is_all_none, result = vmap_general_preprocess(prim, shape_bdim, value_bdim)
if is_all_none:
return result
value_shape, shape_dim = shape_bdim
if isinstance(value_shape, (Tensor_, Tensor)):
_raise_value_error(
"For `P.FillV2`, when the `shape` is a tensor, VMAP is not supported!"
)
if shape_dim is not None:
_raise_value_error(
"The source axis of `shape` in `P.FillV2` must be None, but got {}."
.format(shape_dim))
value, vdim = value_bdim
value_rank = F.rank(value)
if value_rank != 1 or vdim != 0:
@ -869,7 +872,26 @@ def get_fill_v2_vmap_rule(prim, axis_size):
"can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
"{} with source axis: {}.".format(value_rank, vdim))
value = F.reshape(value, (axis_size,) + (1,) * len(value_shape))
out = P.BroadcastTo((axis_size,) + value_shape)(value)
out = None
if isinstance(value_shape, (Tensor_, Tensor)):
value_shape_rank = F.rank(value_shape)
if value_shape_rank != 1:
_raise_value_error(
"The `shape` in `P.FillV2` must be 1-D tensor, thus the shape only "
"can be rank: 1, but got shape rank: "
"{}.".format(value_shape_rank))
axis_size_tensor = to_tensor_with_type((axis_size,),
F.dtype(value_shape))
broad_cast_shape = F.concat((axis_size_tensor, value_shape))
out = DynamicBroadcastTo()(value, broad_cast_shape)
elif isinstance(value_shape, tuple):
out = P.BroadcastTo((axis_size,) + value_shape)(value)
else:
_raise_value_error(
f"For `P.FillV2`, the input `shape` should be Tuple or Tensor, but got `shape`: {value_shape}."
)
return out, 0
return vmap_rule

View File

@ -44,11 +44,11 @@ def dyn_case():
assert out.asnumpy().shape == (2, 3)
def vmap_case():
def cla_fillv2(shape, value):
return P.FillV2()(shape, value)
def cla_fillv2(shape, value):
return P.FillV2()(shape, value)
def vmap_tuple_case():
shape = (2, 2)
value = Tensor([1, 2], ms.float32)
outputs = vmap(cla_fillv2, in_axes=(None, 0), out_axes=0)(shape, value)
@ -57,6 +57,15 @@ def vmap_case():
assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)
def vmap_tensor_case():
shape = Tensor((2, 2), ms.int32)
value = Tensor([1, 2], ms.float32)
outputs = vmap(cla_fillv2, in_axes=(None, 0), out_axes=0)(shape, value)
expect = np.array([[[1, 1], [1, 1]], [[2, 2], [2, 2]]]).astype(np.float32)
assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
@ -68,7 +77,22 @@ def test_fill_v2_dyn():
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
dyn_case()
vmap_case()
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
dyn_case()
vmap_case()
@pytest.mark.level2
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
def test_fill_v2_vmap():
"""
Feature: test FillV2 vmap in gpu.
Description: inputs is static shape.
Expectation: expect correct out result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
vmap_tuple_case()
vmap_tensor_case()
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
vmap_tuple_case()
vmap_tensor_case()