gather: add vmap rule for gather operations
Signed-off-by: lvhaoyu <>
This commit is contained in:
@ -21,8 +21,9 @@ from mindspore.ops import functional as F
from mindspore.ops import constexpr
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, _raise_value_error, \
_handle_broadcasting, get_unsupported_dynamic_vmap_rule
from ..operations.array_ops import Fills
from ..operations.array_ops import UniqueConsecutive
@ -46,21 +47,6 @@ def get_cast_vmap_rule(prim, axis_size):
return vmap_rule
def _get_transpose_batch_perm(dim, perm, x_rank):
"""Generate batch_perm based on the original perm of transpose operation and dim of the input."""
if dim < 0:
dim = dim + x_rank
batch_perm = (dim,)
for i in perm:
if i < dim:
batch_perm = batch_perm + (i,)
index = i + 1
batch_perm = batch_perm + (index,)
return batch_perm
def _get_prefix_expand_shape(indices_shape):
"""Generate prefix and expand shape by indices shape."""
@ -85,6 +71,20 @@ def get_transpose_vmap_rule(prim, axis_size):
if isinstance(prim, str):
prim = Primitive(prim)
def _get_transpose_batch_perm(dim, perm, x_rank):
"""Generate batch_perm based on the original perm of transpose operation and dim of the input."""
if dim < 0:
dim = dim + x_rank
batch_perm = (dim,)
for i in perm:
if i < dim:
batch_perm = batch_perm + (i,)
index = i + 1
batch_perm = batch_perm + (index,)
return batch_perm
def vmap_rule(x_bdim, perm_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim, perm_bdim)
if is_all_none:
@ -103,24 +103,23 @@ def get_transpose_vmap_rule(prim, axis_size):
return vmap_rule
def _get_tile_shape(input_shape, multiples):
input_ndim = len(input_shape)
multiples_ndim = len(multiples)
input_shape = (input_shape[0],) + (1,) * (multiples_ndim - input_ndim + 1) + input_shape[1:]
multiples = (1,) * (input_ndim - multiples_ndim) + multiples
input_expand_shape = (input_shape[0],) + tuple([j for i in input_shape[1:] for j in [1, i]])
repeat_shape = (input_shape[0],) + tuple([k for pair in zip(multiples[1:], input_shape[1:]) for k in pair])
output_shape = tuple([a * b for a, b in zip(input_shape, multiples)])
return input_expand_shape, repeat_shape, output_shape
def get_tile_vmap_rule(prim, axis_size):
"""VmapRule for `P.Tile` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
def _get_tile_shape(input_shape, multiples):
input_ndim = len(input_shape)
multiples_ndim = len(multiples)
input_shape = (input_shape[0],) + (1,) * (multiples_ndim - input_ndim + 1) + input_shape[1:]
multiples = (1,) * (input_ndim - multiples_ndim) + multiples
input_expand_shape = (input_shape[0],) + tuple([j for i in input_shape[1:] for j in [1, i]])
repeat_shape = (input_shape[0],) + tuple([k for pair in zip(multiples[1:], input_shape[1:]) for k in pair])
output_shape = tuple([a * b for a, b in zip(input_shape, multiples)])
return input_expand_shape, repeat_shape, output_shape
def vmap_rule(input_bdim, multiples_bdim):
is_all_none, result = vmap_general_preprocess(prim, input_bdim, multiples_bdim)
if is_all_none:
@ -653,3 +652,93 @@ def get_masked_fill_vmap_rule(prim, axis_size):
return (out, 0)
return vmap_rule
def get_gather_vmap_rule(prim, axis_size):
    """VmapRule for `Gather` operation.

    Args:
        prim (Union[str, Primitive]): The gather primitive, or its name as a string.
        axis_size (int): Size of the batch dimension being mapped.

    Returns:
        function, a rule taking `(x_bdim, indices_bdim, axis_bdim)` pairs and returning
        `(output, out_dim)`.
    """
    if isinstance(prim, str):
        prim_name = prim
        prim = Primitive(prim)
    else:
        prim_name = prim.name

    def process_axis(axis, x_shape_size, has_xdim: bool, has_idim: bool):
        """Translate the caller's gather axis into the axis to use on the batched operands."""
        if has_xdim and has_idim:
            # `x_shape_size` counts the batch dimension, so normalize against the rank without it.
            if axis < 0:
                axis = x_shape_size - 1 + axis
        elif has_xdim:
            # The batch dimension is moved to the front of `x`, shifting non-negative axes by one.
            if axis >= 0:
                axis = axis + 1
        else:
            # Only `indices` is batched: the axis must be non-negative because it is also
            # returned as the position of the batch dimension in the output.
            if axis < 0:
                axis = x_shape_size + axis
        return axis

    def get_x_dst_shape(x_shape, axis):
        """Shape of `x` after merging the batch axis (at `axis`) with the gathered axis (at `axis + 1`)."""
        target_axis_size = x_shape[axis + 1]
        x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis+2:]
        max_axis_size = axis_size * target_axis_size
        return target_axis_size, x_dst_shape, max_axis_size

    def vmap_rule(x_bdim, indices_bdim, axis_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, indices_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        indices, indices_dim = indices_bdim
        axis, axis_dim = axis_bdim

        if axis_dim is not None:
            _raise_value_error("The source axis of `axis` in {} must be None, but got {}.".format(prim_name, axis_dim))

        # Rank of `x` as received, i.e. including its batch dimension when it has one.
        x_shape_len = len(x.shape)

        # Only `x` is batched: gather on the shifted axis, batch stays at the front.
        if x_dim is not None and indices_dim is None:
            x = _bdim_at_front(x, x_dim, axis_size)
            axis = process_axis(axis, x_shape_len, True, False)
            output = prim(x, indices, axis)
            return (output, 0)

        # Only `indices` is batched: the leading batch axis of `indices` lands at `axis` in the output.
        if x_dim is None and indices_dim is not None:
            indices = _bdim_at_front(indices, indices_dim, axis_size)
            axis = process_axis(axis, x_shape_len, False, True)
            output = prim(x, indices, axis)
            return (output, axis)

        # Both are batched: fold the batch axis of `x` into the gathered axis and offset each
        # batch's indices so they address that batch's own slice of the merged axis.
        x = _bdim_at_front(x, x_dim, axis_size)
        indices = _bdim_at_front(indices, indices_dim, axis_size)
        axis = process_axis(axis, x_shape_len, True, True)

        x = mnp.moveaxis(x, 0, axis)
        x_shape = x.shape
        target_axis_size, x_dst_shape, max_axis_size = get_x_dst_shape(x_shape, axis)
        x = x.reshape(x_dst_shape)

        # counts[b, ...] == b * target_axis_size, broadcast to the shape of `indices`.
        counts_shape = indices.shape
        counts = mnp.arange(0, axis_size, 1)
        counts = F.mul(counts, target_axis_size)
        counts = P.BroadcastTo(counts_shape[1:] + (axis_size,))(counts)
        counts = mnp.moveaxis(counts, -1, 0)

        # Push indices that exceed the per-batch bound past the merged axis so they remain out of
        # range instead of reading another batch's data.
        # NOTE(review): negative indices are not shifted here and would appear to address the
        # previous batch's slice after the offset is added -- confirm whether they can occur.
        indices_out_of_bound = mnp.where(indices > target_axis_size - 1, x=max_axis_size, y=0)

        indices = F.add(indices, counts)
        indices = F.add(indices, indices_out_of_bound)

        output = prim(x, indices, axis)
        return (output, axis)

    return vmap_rule
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule =\
@ -246,3 +246,36 @@ def get_unop_vmap_rule(prim, axis_size):
return (out, dim)
return vmap_rule
def get_unsupported_dynamic_vmap_rule(prim, axis_size):
    """
    Vmaprule for the dynamic shape operator whose output shape can not be determined,
    which means the output shape of every batch will be different, so the operator can not support vmap.
    For example, the `Unique` operator, whose output shape is also based on the value of the input.
    Otherwise, the other dynamic shape operators need to implement their own vmaprules.

    Args:
        prim (Union[str, Primitive]): The primitive, or its name as a string.
        axis_size (int): Size of the batch dimension; unused by this rule.

    Returns:
        function, a rule that returns the preprocessed result when no input is batched and
        reports an error through `_raise_value_error` otherwise.
    """
    if isinstance(prim, str):
        prim_name = prim
        prim = Primitive(prim)
    else:
        prim_name = prim.name

    def get_all_dims(*params_bdim):
        """Collect the batch dim of every `(value, dim)` pair for the error message."""
        dims = []
        for bdim in params_bdim:
            _, dim = bdim
            dims.append(dim)
        return dims

    def vmap_rule(*params_bdim):
        # NOTE(review): other rules in this commit pass each bdim as a separate argument to
        # vmap_general_preprocess; here the whole tuple is passed as one -- confirm the helper
        # accepts that form.
        is_all_none, result = vmap_general_preprocess(prim, params_bdim)
        if not is_all_none:
            dims = get_all_dims(*params_bdim)
            _raise_value_error("For operator {}, all axis should be none, but got {}".format(prim_name, dims))
        return result

    return vmap_rule
@ -306,27 +306,27 @@ def get_reducer_vmap_rule(prim, axis_size):
return vmap_rule
def _get_index_add_batch_axis(axis, x_dim, x_ndim):
"""get batch_axis for IndexAdd."""
# case1: batch not exists
if x_dim is None:
return axis
# case2: batch exists
if axis < -x_ndim or x_dim >= x_ndim:
raise ValueError("'axis' of 'IndexAdd' should be in range [{}, {}), but got {}".format(-x_ndim, x_ndim, axis))
if axis < 0:
axis = axis + x_ndim
if x_dim > axis:
return axis
return axis + 1
def get_index_add_vmap_rule(prim, axis_size):
"""VmapRule for IndexAdd."""
axis = prim.axis
def _get_index_add_batch_axis(axis, x_dim, x_ndim):
"""get batch_axis for IndexAdd."""
# case1: batch not exists
if x_dim is None:
return axis
# case2: batch exists
if axis < -x_ndim or x_dim >= x_ndim:
raise ValueError("'axis' of 'IndexAdd' should be in range [{}, {}), but got {}"\
.format(-x_ndim, x_ndim, axis))
if axis < 0:
axis = axis + x_ndim
if x_dim > axis:
return axis
return axis + 1
def vmap_rule(x_bdim, indices_bdim, y_bdim, u_monad):
x, x_dim = x_bdim
indices, indices_dim = indices_bdim
@ -24,56 +24,6 @@ from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, _bdim_at_front
def _get_bias_broadcast_shape(x_shape, bias_shape, bias_dim, data_format):
"""Get the broadcast shape for bias and use it in 'BiasAdd' VmapRule."""
bias_rank = len(bias_shape)
if bias_dim is None and bias_rank == 1:
bias_batch = 1
bias_channel = bias_shape[0]
elif bias_dim is not None and bias_rank == 2:
bias_batch = bias_shape[0]
bias_channel = bias_shape[1]
raise ValueError("The rank of 'bias' in 'BiasAdd' operator is invalid, which is rank: {}"
" with bias_dim: {}.".format(bias_rank, bias_dim))
# The 'Biasadd' operator supports 2-5 dimensions input, and another 'batch' dimension is added to the front in
# vmap scenario.
x_min_rank = 3
x_max_rank = 5
if data_format == "NCDHW":
x_max_rank += 1
x_rank = len(x_shape)
if x_rank < x_min_rank or x_rank > x_max_rank:
raise ValueError("For primitive[BiasAdd] in vmap, the dims of input_x must be in [x_min_rank, {}"
"], but got {}.".format(x_max_rank, x_rank))
if data_format == "NHWC":
# In the 'NHWC' data format ('BN**C' actually), the last dimension is channel axis.
x_channel = x_shape[-1]
if x_channel != bias_channel:
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, "
"but got date format: {}, got bias_channel: {}, "
"x_channel: {}.".format(data_format, bias_channel, x_channel))
if bias_dim is None:
bias_broadcast_shape = (1,) * (x_rank - bias_rank) + (bias_channel,)
bias_broadcast_shape = (bias_batch,) + (1,) * (x_rank - bias_rank) + (bias_channel,)
# In the 'NCHW' or 'NCDHW' data format ('BNC**' actually), the third dimension is channel axis.
x_channel = x_shape[2]
if x_channel != bias_channel:
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, but got date format: "
"{}, got bias_channel: {}, x_channel: {}.".format(data_format, bias_channel, x_channel))
bias_broadcast_shape = (bias_batch, 1, bias_channel)
if x_rank == x_min_rank:
return bias_broadcast_shape
bias_broadcast_shape = bias_broadcast_shape + (1,) * (x_rank - x_min_rank)
return bias_broadcast_shape
def get_bias_add_vmap_rule(prim, axis_size):
"""VmapRule for `BiasAdd` operation."""
@ -84,6 +34,56 @@ def get_bias_add_vmap_rule(prim, axis_size):
data_format = prim.data_format
add_op = P.Add()
def _get_bias_broadcast_shape(x_shape, bias_shape, bias_dim, data_format):
"""Get the broadcast shape for bias and use it in 'BiasAdd' VmapRule."""
bias_rank = len(bias_shape)
if bias_dim is None and bias_rank == 1:
bias_batch = 1
bias_channel = bias_shape[0]
elif bias_dim is not None and bias_rank == 2:
bias_batch = bias_shape[0]
bias_channel = bias_shape[1]
raise ValueError("The rank of 'bias' in 'BiasAdd' operator is invalid, which is rank: {}"
" with bias_dim: {}.".format(bias_rank, bias_dim))
# The 'Biasadd' operator supports 2-5 dimensions input, and another 'batch' dimension is added to the front in
# vmap scenario.
x_min_rank = 3
x_max_rank = 5
if data_format == "NCDHW":
x_max_rank += 1
x_rank = len(x_shape)
if x_rank < x_min_rank or x_rank > x_max_rank:
raise ValueError("For primitive[BiasAdd] in vmap, the dims of input_x must be in [x_min_rank, {}"
"], but got {}.".format(x_max_rank, x_rank))
if data_format == "NHWC":
# In the 'NHWC' data format ('BN**C' actually), the last dimension is channel axis.
x_channel = x_shape[-1]
if x_channel != bias_channel:
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, "
"but got date format: {}, got bias_channel: {}, "
"x_channel: {}.".format(data_format, bias_channel, x_channel))
if bias_dim is None:
bias_broadcast_shape = (1,) * (x_rank - bias_rank) + (bias_channel,)
bias_broadcast_shape = (bias_batch,) + (1,) * (x_rank - bias_rank) + (bias_channel,)
# In the 'NCHW' or 'NCDHW' data format ('BNC**' actually), the third dimension is channel axis.
x_channel = x_shape[2]
if x_channel != bias_channel:
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, but got date format: "
"{}, got bias_channel: {}, x_channel: {}."\
.format(data_format, bias_channel, x_channel))
bias_broadcast_shape = (bias_batch, 1, bias_channel)
if x_rank == x_min_rank:
return bias_broadcast_shape
bias_broadcast_shape = bias_broadcast_shape + (1,) * (x_rank - x_min_rank)
return bias_broadcast_shape
def vmap_rule(input_bdim, bias_bdim):
is_all_none, result = vmap_general_preprocess(prim, input_bdim, bias_bdim)
if is_all_none:
@ -270,7 +270,7 @@ def jet(fn, primals, series):
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
- **out_series** (Union[Tensor, list[Tensor]]) - The `1` to `i+1`-th order of derivative of output with respect
to the inputs.
to the inputs.
TypeError: If `primals` is not a tensor or tuple of tensors.
@ -304,10 +304,8 @@ def jet(fn, primals, series):
[[2.319777 2.4825778]
[1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
[-1.1400385 -0.3066662 ]]
[[-1.2748207 -1.8274734 ]
[ 0.966121 0.55551505]]
[[-4.0515366 3.6724353 ]
[ 0.5053504 -0.52061415]]]
@ -369,7 +367,7 @@ def derivative(fn, primals, order):
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
- **out_series** (Union[Tensor, list[Tensor]]) - The `order`-th order of derivative of output with respect
to the inputs.
to the inputs.
TypeError: If `primals` is not a tensor or tuple of tensors.
@ -20,6 +20,7 @@ from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.ops.functional import vmap
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@ -101,7 +102,203 @@ def test_gatherv2_axisN1():
assert np.all(diff < error)
assert np.all(-diff < error)
def cal_vmap_gather(x, indices, axis):
    """Apply `P.Gather` to `x` with `indices` along `axis`; the function mapped by the vmap tests."""
    gather_op = P.Gather()
    return gather_op(x, indices, axis)
def test_gather_vmap_basic():
    """
    Feature: gather vmap test on cpu.
    Description: test the rightness of vmap gather basic.
    Expectation: use vmap rule's result equal to manually batched.
    """
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
    indices = Tensor(np.array([[0, 1], [1, 2]]).astype(np.int32))
    axis = 0

    # Both `x` and `indices` are batched along their first axis; `axis` is shared.
    outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)

    print("output:\n", outputs)
    expect = np.array([[1, 2],
                       [5, 6]]).astype(np.float32)
    assert np.allclose(outputs.asnumpy(), expect)
def test_gather_vmap_negative_axis():
    """
    Feature: gather vmap test on cpu.
    Description: test the rightness of vmap gather when axis is negative.
    Expectation: use vmap rule's result equal to manually batched.
    """
    x = Tensor(np.array([[[1, 2, 3],
                          [4, 5, 6]],
                         [[7, 8, 9],
                          [10, 11, 12]]]).astype(np.float32))
    indices = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
    # For each 2-D slice of `x`, axis -2 is its row axis.
    axis = -2

    outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)

    print("output:\n", outputs)
    expect = np.array([[[1, 2, 3],
                        [4, 5, 6]],
                       [[10, 11, 12],
                        [7, 8, 9]]]).astype(np.float32)
    assert np.allclose(outputs.asnumpy(), expect)
def test_gather_vmap_with_inaxes():
    """
    Feature: gather vmap test on cpu.
    Description: test the rightness of vmap gather when in_axes is not zero.
    Expectation: use vmap rule's result equal to manually batched.
    """
    x = Tensor(np.array([[[1, 2, 3],
                          [4, 5, 6]],
                         [[7, 8, 9],
                          [10, 11, 12]]]).astype(np.float32))
    # Move the batch axis of `x` to position 2 so that in_axes=2 is exercised.
    x = np.moveaxis(x, 0, 2)
    indices = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
    # Likewise move the batch axis of `indices` to position 1.
    indices = np.moveaxis(indices, 0, 1)
    axis = 0

    outputs = vmap(cal_vmap_gather, in_axes=(2, 1, None), out_axes=0)(x, indices, axis)

    print("output:\n", outputs)
    expect = np.array([[[1, 2, 3],
                        [4, 5, 6]],
                       [[10, 11, 12],
                        [7, 8, 9]]]).astype(np.float32)
    assert np.allclose(outputs.asnumpy(), expect)
def test_gather_vmap_indices_outofbound():
    """
    Feature: gather vmap test on cpu.
    Description: test the rightness of vmap gather when indices out of bound.
    Expectation: use vmap rule's result equal to manually batched.
    """
    x = Tensor(np.array([[[1, 2, 3],
                          [4, 5, 6]],
                         [[7, 8, 9],
                          [10, 11, 12]]]).astype(np.float32))
    # Index 2 is past the end of the gathered axis (size 2) in every batch.
    indices = Tensor(np.array([[0, 2], [2, 0]]).astype(np.int32))
    axis = 0

    outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)

    print("output:\n", outputs)
    # The expected rows for the out-of-bound index are all zeros.
    expect = np.array([[[1, 2, 3],
                        [0, 0, 0]],
                       [[0, 0, 0],
                        [7, 8, 9]]]).astype(np.float32)
    assert np.allclose(outputs.asnumpy(), expect)
def test_gather_vmap_xdim_is_none():
    """
    Feature: gather vmap test on cpu.
    Description: test the rightness of vmap gather when no xdim.
    Expectation: use vmap rule's result equal to manually batched.
    """
    x = Tensor(np.array([1, 2, 3]).astype(np.float32))
    indices = Tensor(np.array([[0, 1], [2, 0]]).astype(np.int32))
    axis = 0

    # Only `indices` is batched; `x` is shared by every batch.
    outputs = vmap(cal_vmap_gather, in_axes=(None, 0, None), out_axes=0)(x, indices, axis)

    print("output:\n", outputs)
    expect = np.array([[1, 2],
                       [3, 1]]).astype(np.float32)
    assert np.allclose(outputs.asnumpy(), expect)
def test_gather_vmap_idim_is_none():
    """
    Feature: gather vmap test on cpu.
    Description: test the rightness of vmap gather when no idim.
    Expectation: use vmap rule's result equal to manually batched.
    """
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    axis = 0

    # Only `x` is batched; the same `indices` are used for every batch.
    outputs = vmap(cal_vmap_gather, in_axes=(0, None, None), out_axes=0)(x, indices, axis)

    print("output:\n", outputs)
    expect = np.array([[1, 2],
                       [4, 5]]).astype(np.float32)
    assert np.allclose(outputs.asnumpy(), expect)
def test_gather_vmap_nested():
    """
    Feature: gather vmap test on cpu.
    Description: test the rightness of nested vmap gather.
    Expectation: use vmap rule's result equal to manually batched.
    """
    x = Tensor(np.array([[[1, 2, 3],
                          [4, 5, 6]],
                         [[7, 8, 9],
                          [10, 11, 12]]]).astype(np.float32))
    indices = Tensor(np.array([[[0, 1],
                                [1, 0]],
                               [[0, 1],
                                [1, 0]]]).astype(np.int32))
    axis = 0

    # Two nested vmaps: the outer maps axis 0, the inner maps the next axis of both operands.
    outputs = vmap(vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0),
                   in_axes=(0, 0, None), out_axes=0)(x, indices, axis)

    print("output:\n", outputs)
    expect = np.array([[[1, 2],
                        [5, 4]],
                       [[7, 8],
                        [11, 10]]]).astype(np.float32)
    assert np.allclose(outputs.asnumpy(), expect)
if __name__ == '__main__':
Reference in New Issue