forked from mindspore-Ecosystem/mindspore
commit 0b3a90cfc8
!49641 FillV2 vmap support when shape is a tensor.
Merge pull request !49641 from haozhang/fillv2_vmap
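What this enables, as a minimal sketch (it mirrors the `vmap_tensor_case` test added below; the public `ms.vmap` and `ops.FillV2` entry points and a GPU context are assumptions, not part of the diff): before this change the FillV2 vmap rule rejected a Tensor `shape` outright, and now it broadcasts through `DynamicBroadcastTo`.

import mindspore as ms
from mindspore import Tensor, vmap, ops

def fill(shape, value):
    return ops.FillV2()(shape, value)

# `shape` is shared across the batch (in_axes=None); `value` is mapped over axis 0.
shape = Tensor((2, 2), ms.int32)      # shape passed as a tensor: newly supported under vmap
value = Tensor([1, 2], ms.float32)
out = vmap(fill, in_axes=(None, 0), out_axes=0)(shape, value)
# out[0] is a 2x2 block of 1s, out[1] is a 2x2 block of 2s.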
@@ -30,6 +30,7 @@ from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.operations.array_ops import Fills, UniqueConsecutive, Col2Im, NonZero, IndexFill, \
     TensorScatterElements
 from mindspore.ops.operations.random_ops import RandomPoisson
+from mindspore.ops.operations._inner_ops import DynamicBroadcastTo
 from mindspore.ops.primitive import Primitive
 from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, \
     _raise_value_error, _vmap_clone_prim, _handle_broadcasting, get_unsupported_dynamic_vmap_rule, _broadcast_by_axis, \
@@ -841,6 +842,11 @@ def get_fill_vmap_rule(prim, axis_size):
     return vmap_rule


+@constexpr
+def to_tensor_with_type(x, type):
+    return Tensor(x, type)
+
+
 @vmap_rules_getters.register(P.FillV2)
 def get_fill_v2_vmap_rule(prim, axis_size):
     """VmapRule for `FillV2` operation."""
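A note on the new helper (my reading, not stated in the commit): `@constexpr` evaluates the function at graph-compile time, so `to_tensor_with_type((axis_size,), F.dtype(value_shape))` bakes the Python tuple `(axis_size,)` into a constant 1-D Tensor whose dtype matches the runtime `shape` tensor, which is what lets `F.concat` join the two in the hunk below. Roughly, assuming axis_size = 2 and an int32 shape tensor:

# Compile-time constant:  to_tensor_with_type((2,), ms.int32)   -> Tensor([2], ms.int32)
# Runtime concatenation:  F.concat((Tensor([2]), shape_tensor)) -> e.g. Tensor([2, 2, 2])
# i.e. the batched target shape (axis_size,) + shape, expressed as a tensor.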
@@ -848,19 +854,16 @@ def get_fill_v2_vmap_rule(prim, axis_size):
         prim = Primitive(prim)

     def vmap_rule(shape_bdim, value_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, shape_bdim,
-                                                      value_bdim)
+        is_all_none, result = vmap_general_preprocess(prim, shape_bdim, value_bdim)
         if is_all_none:
             return result

         value_shape, shape_dim = shape_bdim
-        if isinstance(value_shape, (Tensor_, Tensor)):
-            _raise_value_error(
-                "For `P.FillV2`, when the `shape` is a tensor, VMAP is not supported!"
-            )
         if shape_dim is not None:
             _raise_value_error(
                 "The source axis of `shape` in `P.FillV2` must be None, but got {}."
                 .format(shape_dim))

         value, vdim = value_bdim
         value_rank = F.rank(value)
         if value_rank != 1 or vdim != 0:
@@ -869,7 +872,26 @@ def get_fill_v2_vmap_rule(prim, axis_size):
                 "can be rank: 1 with source axis: 0 in vmap scope, but got value rank: "
                 "{} with source axis: {}.".format(value_rank, vdim))
         value = F.reshape(value, (axis_size,) + (1,) * len(value_shape))
-        out = P.BroadcastTo((axis_size,) + value_shape)(value)
+
+        out = None
+        if isinstance(value_shape, (Tensor_, Tensor)):
+            value_shape_rank = F.rank(value_shape)
+            if value_shape_rank != 1:
+                _raise_value_error(
+                    "The `shape` in `P.FillV2` must be a 1-D tensor, thus the shape only "
+                    "can be rank: 1, but got shape rank: "
+                    "{}.".format(value_shape_rank))
+            axis_size_tensor = to_tensor_with_type((axis_size,),
+                                                   F.dtype(value_shape))
+            broad_cast_shape = F.concat((axis_size_tensor, value_shape))
+            out = DynamicBroadcastTo()(value, broad_cast_shape)
+        elif isinstance(value_shape, tuple):
+            out = P.BroadcastTo((axis_size,) + value_shape)(value)
+        else:
+            _raise_value_error(
+                f"For `P.FillV2`, the input `shape` should be Tuple or Tensor, but got `shape`: {value_shape}."
+            )

         return out, 0

     return vmap_rule
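In plain terms, the reworked rule reshapes the batched `value` to `(axis_size, 1, ..., 1)` and broadcasts it to `(axis_size,) + shape`, so slice `i` of the output is a `shape`-sized block filled with `value[i]`. The tuple branch can use the static `P.BroadcastTo`, while the tensor branch must assemble the target shape at runtime and go through `DynamicBroadcastTo`. A NumPy sketch of the same semantics (illustrative only, not MindSpore code); the hunks that follow patch the accompanying GPU test:

import numpy as np

def fill_v2_vmap_semantics(shape, value):
    # value: 1-D, one fill scalar per batch element; shape: tuple or 1-D integer array
    axis_size = value.shape[0]
    shape = tuple(int(d) for d in np.asarray(shape))   # tuple and tensor inputs treated alike
    v = value.reshape((axis_size,) + (1,) * len(shape))
    return np.broadcast_to(v, (axis_size,) + shape)

out = fill_v2_vmap_semantics(np.array([2, 2]), np.array([1.0, 2.0], np.float32))
assert (out[0] == 1).all() and (out[1] == 2).all()     # matches the expected test output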
@@ -44,11 +44,11 @@ def dyn_case():
     assert out.asnumpy().shape == (2, 3)


-def vmap_case():
-    def cla_fillv2(shape, value):
-        return P.FillV2()(shape, value)
+def cla_fillv2(shape, value):
+    return P.FillV2()(shape, value)


+def vmap_tuple_case():
     shape = (2, 2)
     value = Tensor([1, 2], ms.float32)
     outputs = vmap(cla_fillv2, in_axes=(None, 0), out_axes=0)(shape, value)
@@ -57,6 +57,15 @@ def vmap_case():
     assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)


+def vmap_tensor_case():
+    shape = Tensor((2, 2), ms.int32)
+    value = Tensor([1, 2], ms.float32)
+    outputs = vmap(cla_fillv2, in_axes=(None, 0), out_axes=0)(shape, value)
+
+    expect = np.array([[[1, 1], [1, 1]], [[2, 2], [2, 2]]]).astype(np.float32)
+    assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu
 @pytest.mark.env_onecard
@@ -68,7 +77,22 @@ def test_fill_v2_dyn():
     """
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
     dyn_case()
-    vmap_case()
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
     dyn_case()
-    vmap_case()
+
+
+@pytest.mark.level2
+@pytest.mark.platform_x86_gpu
+@pytest.mark.env_onecard
+def test_fill_v2_vmap():
+    """
+    Feature: test FillV2 vmap on GPU.
+    Description: inputs are static shapes; `shape` is passed both as a tuple and as a tensor.
+    Expectation: correct output result.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    vmap_tuple_case()
+    vmap_tensor_case()
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    vmap_tuple_case()
+    vmap_tensor_case()
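Why `in_axes=(None, 0)` in both cases: `None` pins `shape` as a shared, unbatched argument, which is exactly the situation the vmap rule's `shape_dim is not None` guard enforces, while `0` maps `value` over its leading axis, matching the rule's `vdim != 0` check. For contrast, a hypothetical misuse (not in the diff) that batches `shape` itself would trip the guard:

# Hypothetical, for illustration only: batching `shape` is rejected by the rule.
# vmap(cla_fillv2, in_axes=(0, 0))(Tensor([[2, 2], [3, 3]], ms.int32), value)
#   -> ValueError: "The source axis of `shape` in `P.FillV2` must be None, but got 0."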
|
||||
|
|
Loading…
Reference in New Issue