!40882 Add vmap for StridedSliceGrad

Merge pull request !40882 from OwenSec/master
This commit is contained in:
i-robot 2022-08-27 08:30:55 +00:00 committed by Gitee
commit e523715b7b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 66 additions and 3 deletions

View File

@ -1820,6 +1820,52 @@ def get_stridedslice_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(G.StridedSliceGrad)
def get_stridedslice_grad_vmap_rule(prim, axis_size):
"""VmapRule for `StridedSliceGrad`."""
new_begin_mask = prim.begin_mask * 2 + 1
new_end_mask = prim.end_mask * 2 + 1
new_ellipsis_mask = prim.ellipsis_mask
new_new_axis_mask = prim.new_axis_mask * 2
new_shrink_axis_mask = prim.shrink_axis_mask * 2
batch_stridedslice_grad = G.StridedSliceGrad(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
new_shrink_axis_mask)
@constexpr
def get_new_xshape_begin_end_strided(xshape, begin, end, strided):
new_xshape = (axis_size,) + xshape
new_begin = (0,) + begin
new_end = (axis_size,) + end
new_strided = (1,) + strided
return new_xshape, new_begin, new_end, new_strided
def vmap_rule(dy_bdim, xshape_bdim, begin_bdim, end_bdim, strided_bdim):
is_all_none, result = vmap_general_preprocess(prim, dy_bdim, xshape_bdim, begin_bdim, end_bdim, strided_bdim)
if is_all_none:
return result
dy, dy_dim = dy_bdim
xshape, xshape_dim = xshape_bdim
begin, begin_dim = begin_bdim
end, end_dim = end_bdim
strided, strided_dim = strided_bdim
if any(dim is not None for dim in [xshape_dim, begin_dim, end_dim, strided_dim]):
_raise_value_error("vmap of `StridedSliceGrad` not support `xshape`, `begin`, "
"`end` or `strided` has batch dimension, "
"but got {}, {}, {}, {}".format(xshape_dim, begin_dim, end_dim, strided_dim))
# dy_dim and x_dim are not None, and others are None
dy = _bdim_at_front(dy, dy_dim, axis_size)
new_xshape, new_begin, new_end, new_strided = get_new_xshape_begin_end_strided(xshape, begin, end, strided)
result = batch_stridedslice_grad(dy, new_xshape, new_begin, new_end, new_strided)
return result, 0
return vmap_rule
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(NonZero)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule = \

View File

@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.functional import vmap
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@ -60,6 +61,7 @@ class StridedSliceGrad2(nn.Cell):
def construct(self, dy, x):
return self.ssg(dy, self.shape(x), (0, 0, 0), (1, 4, 2), (1, 1, 1))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@ -71,6 +73,21 @@ def test_slice2():
expect = [[[0., 1.], [2., 3.], [4., 5.], [6., 7.]], [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]
assert (output.asnumpy() == expect).all()
if __name__ == '__main__':
test_slice()
test_slice2()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_slice_vmap():
"""
Feature: Test stridedslicegrad CPU vmap.
Description: The inputs are two tensors, x and dy.
Expectation: The output matches to the np benchmark.
"""
x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 7, 8]]]).astype(np.float32))
dy = P.Stack()([Tensor(np.array([[[5., 1., 5.], [6., 1., 8.]]]).astype(np.float32)) for _ in range(8)])
ssg_vmap = vmap(StridedSliceGrad(), in_axes=(0, None))
output = ssg_vmap(dy, x)
expect = P.Stack()([Tensor(np.array([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]],
[[5, 1, 5], [6, 1, 8]]])) for _ in range(8)]).asnumpy()
assert (output.asnumpy() == expect).all()