forked from mindspore-Ecosystem/mindspore
!40882 Add vmap for StridedSliceGrad
Merge pull request !40882 from OwenSec/master
This commit is contained in:
commit
e523715b7b
|
@ -1820,6 +1820,52 @@ def get_stridedslice_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(G.StridedSliceGrad)
|
||||
def get_stridedslice_grad_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `StridedSliceGrad`."""
|
||||
new_begin_mask = prim.begin_mask * 2 + 1
|
||||
new_end_mask = prim.end_mask * 2 + 1
|
||||
new_ellipsis_mask = prim.ellipsis_mask
|
||||
new_new_axis_mask = prim.new_axis_mask * 2
|
||||
new_shrink_axis_mask = prim.shrink_axis_mask * 2
|
||||
batch_stridedslice_grad = G.StridedSliceGrad(new_begin_mask, new_end_mask, new_ellipsis_mask, new_new_axis_mask, \
|
||||
new_shrink_axis_mask)
|
||||
|
||||
@constexpr
|
||||
def get_new_xshape_begin_end_strided(xshape, begin, end, strided):
|
||||
new_xshape = (axis_size,) + xshape
|
||||
new_begin = (0,) + begin
|
||||
new_end = (axis_size,) + end
|
||||
new_strided = (1,) + strided
|
||||
return new_xshape, new_begin, new_end, new_strided
|
||||
|
||||
def vmap_rule(dy_bdim, xshape_bdim, begin_bdim, end_bdim, strided_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, dy_bdim, xshape_bdim, begin_bdim, end_bdim, strided_bdim)
|
||||
if is_all_none:
|
||||
return result
|
||||
|
||||
dy, dy_dim = dy_bdim
|
||||
xshape, xshape_dim = xshape_bdim
|
||||
begin, begin_dim = begin_bdim
|
||||
end, end_dim = end_bdim
|
||||
strided, strided_dim = strided_bdim
|
||||
|
||||
if any(dim is not None for dim in [xshape_dim, begin_dim, end_dim, strided_dim]):
|
||||
_raise_value_error("vmap of `StridedSliceGrad` not support `xshape`, `begin`, "
|
||||
"`end` or `strided` has batch dimension, "
|
||||
"but got {}, {}, {}, {}".format(xshape_dim, begin_dim, end_dim, strided_dim))
|
||||
|
||||
# dy_dim and x_dim are not None, and others are None
|
||||
dy = _bdim_at_front(dy, dy_dim, axis_size)
|
||||
|
||||
new_xshape, new_begin, new_end, new_strided = get_new_xshape_begin_end_strided(xshape, begin, end, strided)
|
||||
|
||||
result = batch_stridedslice_grad(dy, new_xshape, new_begin, new_end, new_strided)
|
||||
return result, 0
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
||||
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(NonZero)(get_unsupported_dynamic_vmap_rule)
|
||||
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
|
||||
get_unsupported_dynamic_vmap_rule = \
|
||||
|
|
|
@ -23,6 +23,7 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
@ -60,6 +61,7 @@ class StridedSliceGrad2(nn.Cell):
|
|||
def construct(self, dy, x):
|
||||
return self.ssg(dy, self.shape(x), (0, 0, 0), (1, 4, 2), (1, 1, 1))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -71,6 +73,21 @@ def test_slice2():
|
|||
expect = [[[0., 1.], [2., 3.], [4., 5.], [6., 7.]], [[0., 0.], [0., 0.], [0., 0.], [0., 0.]]]
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_slice()
|
||||
test_slice2()
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_slice_vmap():
|
||||
"""
|
||||
Feature: Test stridedslicegrad CPU vmap.
|
||||
Description: The inputs are two tensors, x and dy.
|
||||
Expectation: The output matches to the np benchmark.
|
||||
"""
|
||||
x = Tensor(np.array([[[1., 1., 1.], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
|
||||
[[5, 5, 5], [6, 7, 8]]]).astype(np.float32))
|
||||
dy = P.Stack()([Tensor(np.array([[[5., 1., 5.], [6., 1., 8.]]]).astype(np.float32)) for _ in range(8)])
|
||||
ssg_vmap = vmap(StridedSliceGrad(), in_axes=(0, None))
|
||||
output = ssg_vmap(dy, x)
|
||||
expect = P.Stack()([Tensor(np.array([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]],
|
||||
[[5, 1, 5], [6, 1, 8]]])) for _ in range(8)]).asnumpy()
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
|
Loading…
Reference in New Issue