gather: add vmap rule for gather operations

Signed-off-by: lvhaoyu <lvhaoyu@huawei.com>
lvhaoyu 2022-05-16 19:42:51 +08:00
parent 6cf2a84fa7
commit fadd7da5ef
6 changed files with 415 additions and 98 deletions

View File

@ -21,8 +21,9 @@ from mindspore.ops import functional as F
from mindspore.ops import constexpr
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, _raise_value_error, \
_handle_broadcasting
_handle_broadcasting, get_unsupported_dynamic_vmap_rule
from ..operations.array_ops import Fills
from ..operations.array_ops import UniqueConsecutive
@vmap_rules_getters.register("Cast")
@ -46,21 +47,6 @@ def get_cast_vmap_rule(prim, axis_size):
return vmap_rule
@constexpr
def _get_transpose_batch_perm(dim, perm, x_rank):
"""Generate batch_perm based on the original perm of transpose operation and dim of the input."""
if dim < 0:
dim = dim + x_rank
batch_perm = (dim,)
for i in perm:
if i < dim:
batch_perm = batch_perm + (i,)
else:
index = i + 1
batch_perm = batch_perm + (index,)
return batch_perm
@constexpr
def _get_prefix_expand_shape(indices_shape):
"""Generate prefix and expand shape by indices shape."""
@ -85,6 +71,20 @@ def get_transpose_vmap_rule(prim, axis_size):
if isinstance(prim, str):
prim = Primitive(prim)
@constexpr
def _get_transpose_batch_perm(dim, perm, x_rank):
"""Generate batch_perm based on the original perm of transpose operation and dim of the input."""
if dim < 0:
dim = dim + x_rank
batch_perm = (dim,)
for i in perm:
if i < dim:
batch_perm = batch_perm + (i,)
else:
index = i + 1
batch_perm = batch_perm + (index,)
return batch_perm
def vmap_rule(x_bdim, perm_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim, perm_bdim)
if is_all_none:
@ -103,24 +103,23 @@ def get_transpose_vmap_rule(prim, axis_size):
return vmap_rule
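For illustration, the batch-perm helper above can be restated in plain Python and spot-checked (a sketch, not part of the commit; it assumes x_rank is the rank of the batched input): the batch axis at dim is placed first in the permutation and every original axis index at or beyond dim is shifted right by one.
def batch_perm_for(dim, perm, x_rank):
    # Plain-Python restatement of _get_transpose_batch_perm, for illustration only.
    if dim < 0:
        dim += x_rank
    # Batch axis first; unbatched axes at or beyond `dim` shift right by one.
    return (dim,) + tuple(i if i < dim else i + 1 for i in perm)

assert batch_perm_for(0, (1, 0), 3) == (0, 2, 1)    # batch at the front of a rank-2 sample
assert batch_perm_for(-1, (1, 0), 3) == (2, 1, 0)   # batch at the last position of the batched input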
@constexpr
def _get_tile_shape(input_shape, multiples):
input_ndim = len(input_shape)
multiples_ndim = len(multiples)
input_shape = (input_shape[0],) + (1,) * (multiples_ndim - input_ndim + 1) + input_shape[1:]
multiples = (1,) * (input_ndim - multiples_ndim) + multiples
input_expand_shape = (input_shape[0],) + tuple([j for i in input_shape[1:] for j in [1, i]])
repeat_shape = (input_shape[0],) + tuple([k for pair in zip(multiples[1:], input_shape[1:]) for k in pair])
output_shape = tuple([a * b for a, b in zip(input_shape, multiples)])
return input_expand_shape, repeat_shape, output_shape
@vmap_rules_getters.register(P.Tile)
def get_tile_vmap_rule(prim, axis_size):
"""VmapRule for `P.Tile` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
@constexpr
def _get_tile_shape(input_shape, multiples):
input_ndim = len(input_shape)
multiples_ndim = len(multiples)
input_shape = (input_shape[0],) + (1,) * (multiples_ndim - input_ndim + 1) + input_shape[1:]
multiples = (1,) * (input_ndim - multiples_ndim) + multiples
input_expand_shape = (input_shape[0],) + tuple([j for i in input_shape[1:] for j in [1, i]])
repeat_shape = (input_shape[0],) + tuple([k for pair in zip(multiples[1:], input_shape[1:]) for k in pair])
output_shape = tuple([a * b for a, b in zip(input_shape, multiples)])
return input_expand_shape, repeat_shape, output_shape
def vmap_rule(input_bdim, multiples_bdim):
is_all_none, result = vmap_general_preprocess(prim, input_bdim, multiples_bdim)
if is_all_none:
@ -653,3 +652,93 @@ def get_masked_fill_vmap_rule(prim, axis_size):
return (out, 0)
return vmap_rule
@vmap_rules_getters.register(P.Gather)
def get_gather_vmap_rule(prim, axis_size):
"""VmapRule for `Gather` operation. """
if isinstance(prim, str):
prim_name = prim
prim = Primitive(prim)
else:
prim_name = prim.name
@constexpr
def process_axis(axis, x_shape_size, has_xdim: bool, has_idim: bool):
if has_xdim and has_idim:
if axis < 0:
axis = x_shape_size - 1 + axis
elif has_xdim:
if axis >= 0:
axis = axis + 1
else:
if axis < 0:
axis = x_shape_size + axis
return axis
@constexpr
def get_x_dst_shape(x_shape, axis):
target_axis_size = x_shape[axis + 1]
x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis+2:]
max_axis_size = axis_size * target_axis_size
return target_axis_size, x_dst_shape, max_axis_size
def vmap_rule(x_bdim, indices_bdim, axis_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim, indices_bdim)
if is_all_none:
return result
x, x_dim = x_bdim
indices, indices_dim = indices_bdim
axis, axis_dim = axis_bdim
if axis_dim is not None:
_raise_value_error("The source axis of `axis` in {} must be None, but got {}.".format(prim_name, axis_dim))
x_shape_len = len(x.shape)
if x_dim is not None and indices_dim is None:
x = _bdim_at_front(x, x_dim, axis_size)
axis = process_axis(axis, x_shape_len, True, False)
output = prim(x, indices, axis)
return (output, 0)
if x_dim is None and indices_dim is not None:
indices = _bdim_at_front(indices, indices_dim, axis_size)
axis = process_axis(axis, x_shape_len, False, True)
output = prim(x, indices, axis)
return (output, axis)
x = _bdim_at_front(x, x_dim, axis_size)
indices = _bdim_at_front(indices, indices_dim, axis_size)
axis = process_axis(axis, x_shape_len, True, True)
x = mnp.moveaxis(x, 0, axis)
x_shape = x.shape
target_axis_size, x_dst_shape, max_axis_size = get_x_dst_shape(x_shape, axis)
x = x.reshape(x_dst_shape)
counts_shape = indices.shape
counts = mnp.arange(0, axis_size, 1)
counts = F.mul(counts, target_axis_size)
counts = P.BroadcastTo(counts_shape[1:] + (axis_size,))(counts)
counts = mnp.moveaxis(counts, -1, 0)
indices_out_of_bound = mnp.where(indices > target_axis_size - 1, x=max_axis_size, y=0)
indices = F.add(indices, counts)
indices = F.add(indices, indices_out_of_bound)
output = prim(x, indices, axis)
return (output, axis)
return vmap_rule
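When both `x` and `indices` carry a batch dimension, the rule above folds the batch dimension of `x` into the gather axis and offsets each batch's indices by `batch_index * target_axis_size`; indices that are already out of range are additionally pushed past the folded axis so the kernel's own out-of-bound handling still applies per batch. A minimal NumPy sketch of the fold-and-offset idea (illustration only, not part of the commit; the out-of-bound handling is omitted):
import numpy as np

def batched_gather_reference(x, indices, axis):
    """Gather x[b] along `axis` with indices[b] for every batch b (batch at dim 0)."""
    batch = x.shape[0]
    target_axis_size = x.shape[axis + 1]             # gather-axis size of a single slice
    # Fold the batch dimension into the gather axis (moveaxis + reshape, as in the rule).
    x_folded = np.moveaxis(x, 0, axis)
    new_shape = x_folded.shape[:axis] + (batch * target_axis_size,) + x_folded.shape[axis + 2:]
    x_folded = x_folded.reshape(new_shape)
    # Offset each batch's indices so they address that batch's block in the folded axis.
    offsets = (np.arange(batch) * target_axis_size).reshape((batch,) + (1,) * (indices.ndim - 1))
    out = np.take(x_folded, indices + offsets, axis=axis)
    return np.moveaxis(out, axis, 0)                 # move the batch dimension back to the front

x = np.array([[1, 2, 3], [4, 5, 6]], np.float32)
idx = np.array([[0, 1], [1, 2]], np.int32)
print(batched_gather_reference(x, idx, 0))           # [[1. 2.] [5. 6.]]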
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
get_unsupported_dynamic_vmap_rule =\
vmap_rules_getters.register(UniqueConsecutive)(get_unsupported_dynamic_vmap_rule)

View File

@ -246,3 +246,36 @@ def get_unop_vmap_rule(prim, axis_size):
return (out, dim)
return vmap_rule
def get_unsupported_dynamic_vmap_rule(prim, axis_size):
"""
Vmaprule for the dynamic shape operator whose output shape can not be determined,
which means the output shape of every batch will be different, so the operator can not support vmap.
Forexample, the `Unique` operator, whose output shape is also base on the value of the input.
Otherwise, the other dynamic shape operators need to implement their own vmaprules.
"""
if isinstance(prim, str):
prim_name = prim
prim = Primitive(prim)
else:
prim_name = prim.name
def get_all_dims(*params_bdim):
dims = []
for bdim in params_bdim:
_, dim = bdim
dims.append(dim)
return dims
def vmap_rule(*params_bdim):
is_all_none, result = vmap_general_preprocess(prim, params_bdim)
if not is_all_none:
dims = get_all_dims(*params_bdim)
_raise_value_error("For operator {}, all axis should be none, but got {}".format(prim_name, dims))
return result
return vmap_rule
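The reason such operators cannot be vmapped can be seen with a tiny NumPy illustration (not part of the commit): the per-batch outputs have different sizes, so there is no single static output shape for the batched call.
import numpy as np

batch = np.array([[1, 1, 2, 3],
                  [4, 4, 4, 4]])
print([np.unique(row).shape for row in batch])   # [(3,), (1,)] -- output shape differs per batch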

View File

@ -306,27 +306,27 @@ def get_reducer_vmap_rule(prim, axis_size):
return vmap_rule
@constexpr
def _get_index_add_batch_axis(axis, x_dim, x_ndim):
"""get batch_axis for IndexAdd."""
# case1: batch not exists
if x_dim is None:
return axis
# case2: batch exists
if axis < -x_ndim or x_dim >= x_ndim:
raise ValueError("'axis' of 'IndexAdd' should be in range [{}, {}), but got {}".format(-x_ndim, x_ndim, axis))
if axis < 0:
axis = axis + x_ndim
if x_dim > axis:
return axis
return axis + 1
@vmap_rules_getters.register(P.IndexAdd)
def get_index_add_vmap_rule(prim, axis_size):
"""VmapRule for IndexAdd."""
axis = prim.axis
@constexpr
def _get_index_add_batch_axis(axis, x_dim, x_ndim):
"""get batch_axis for IndexAdd."""
# case1: batch not exists
if x_dim is None:
return axis
# case2: batch exists
if axis < -x_ndim or x_dim >= x_ndim:
raise ValueError("'axis' of 'IndexAdd' should be in range [{}, {}), but got {}"\
.format(-x_ndim, x_ndim, axis))
if axis < 0:
axis = axis + x_ndim
if x_dim > axis:
return axis
return axis + 1
def vmap_rule(x_bdim, indices_bdim, y_bdim, u_monad):
x, x_dim = x_bdim
indices, indices_dim = indices_bdim
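A few spot checks of the batch-axis adjustment above (a plain-Python restatement for illustration, not part of the commit; the range validation is omitted):
def index_add_batch_axis(axis, x_dim, x_ndim):
    # Restates _get_index_add_batch_axis: normalize a negative axis, then shift it
    # right by one when the batch dimension of x sits at or before it.
    if x_dim is None:
        return axis
    if axis < 0:
        axis += x_ndim
    return axis if x_dim > axis else axis + 1

assert index_add_batch_axis(1, None, 3) == 1   # no batch on x: axis unchanged
assert index_add_batch_axis(1, 0, 3) == 2      # batch dim before the axis: shift right
assert index_add_batch_axis(1, 2, 3) == 1      # batch dim after the axis: unchanged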

View File

@ -24,56 +24,6 @@ from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, _bdim_at_front
@constexpr
def _get_bias_broadcast_shape(x_shape, bias_shape, bias_dim, data_format):
"""Get the broadcast shape for bias and use it in 'BiasAdd' VmapRule."""
bias_rank = len(bias_shape)
if bias_dim is None and bias_rank == 1:
bias_batch = 1
bias_channel = bias_shape[0]
elif bias_dim is not None and bias_rank == 2:
bias_batch = bias_shape[0]
bias_channel = bias_shape[1]
else:
raise ValueError("The rank of 'bias' in 'BiasAdd' operator is invalid, which is rank: {}"
" with bias_dim: {}.".format(bias_rank, bias_dim))
# The 'BiasAdd' operator supports 2-5 dimensional input, and another 'batch' dimension is added to the front in
# the vmap scenario.
x_min_rank = 3
x_max_rank = 5
if data_format == "NCDHW":
x_max_rank += 1
x_rank = len(x_shape)
if x_rank < x_min_rank or x_rank > x_max_rank:
raise ValueError("For primitive[BiasAdd] in vmap, the dims of input_x must be in [x_min_rank, {}"
"], but got {}.".format(x_max_rank, x_rank))
if data_format == "NHWC":
# In the 'NHWC' data format ('BN**C' actually), the last dimension is channel axis.
x_channel = x_shape[-1]
if x_channel != bias_channel:
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, "
"but got date format: {}, got bias_channel: {}, "
"x_channel: {}.".format(data_format, bias_channel, x_channel))
if bias_dim is None:
bias_broadcast_shape = (1,) * (x_rank - bias_rank) + (bias_channel,)
else:
bias_broadcast_shape = (bias_batch,) + (1,) * (x_rank - bias_rank) + (bias_channel,)
else:
# In the 'NCHW' or 'NCDHW' data format ('BNC**' actually), the third dimension is channel axis.
x_channel = x_shape[2]
if x_channel != bias_channel:
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, but got date format: "
"{}, got bias_channel: {}, x_channel: {}.".format(data_format, bias_channel, x_channel))
bias_broadcast_shape = (bias_batch, 1, bias_channel)
if x_rank == x_min_rank:
return bias_broadcast_shape
bias_broadcast_shape = bias_broadcast_shape + (1,) * (x_rank - x_min_rank)
return bias_broadcast_shape
@vmap_rules_getters.register(P.BiasAdd)
def get_bias_add_vmap_rule(prim, axis_size):
"""VmapRule for `BiasAdd` operation."""
@ -84,6 +34,56 @@ def get_bias_add_vmap_rule(prim, axis_size):
data_format = prim.data_format
add_op = P.Add()
@constexpr
def _get_bias_broadcast_shape(x_shape, bias_shape, bias_dim, data_format):
"""Get the broadcast shape for bias and use it in 'BiasAdd' VmapRule."""
bias_rank = len(bias_shape)
if bias_dim is None and bias_rank == 1:
bias_batch = 1
bias_channel = bias_shape[0]
elif bias_dim is not None and bias_rank == 2:
bias_batch = bias_shape[0]
bias_channel = bias_shape[1]
else:
raise ValueError("The rank of 'bias' in 'BiasAdd' operator is invalid, which is rank: {}"
" with bias_dim: {}.".format(bias_rank, bias_dim))
# The 'BiasAdd' operator supports 2-5 dimensional input, and another 'batch' dimension is added to the front in
# the vmap scenario.
x_min_rank = 3
x_max_rank = 5
if data_format == "NCDHW":
x_max_rank += 1
x_rank = len(x_shape)
if x_rank < x_min_rank or x_rank > x_max_rank:
raise ValueError("For primitive[BiasAdd] in vmap, the dims of input_x must be in [x_min_rank, {}"
"], but got {}.".format(x_max_rank, x_rank))
if data_format == "NHWC":
# In the 'NHWC' data format ('BN**C' actually), the last dimension is channel axis.
x_channel = x_shape[-1]
if x_channel != bias_channel:
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, "
"but got date format: {}, got bias_channel: {}, "
"x_channel: {}.".format(data_format, bias_channel, x_channel))
if bias_dim is None:
bias_broadcast_shape = (1,) * (x_rank - bias_rank) + (bias_channel,)
else:
bias_broadcast_shape = (bias_batch,) + (1,) * (x_rank - bias_rank) + (bias_channel,)
else:
# In the 'NCHW' or 'NCDHW' data format ('BNC**' actually), the third dimension is channel axis.
x_channel = x_shape[2]
if x_channel != bias_channel:
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, but got date format: "
"{}, got bias_channel: {}, x_channel: {}."\
.format(data_format, bias_channel, x_channel))
bias_broadcast_shape = (bias_batch, 1, bias_channel)
if x_rank == x_min_rank:
return bias_broadcast_shape
bias_broadcast_shape = bias_broadcast_shape + (1,) * (x_rank - x_min_rank)
return bias_broadcast_shape
def vmap_rule(input_bdim, bias_bdim):
is_all_none, result = vmap_general_preprocess(prim, input_bdim, bias_bdim)
if is_all_none:
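As a hypothetical worked case of the broadcast-shape helper above (not part of the commit), assume data_format 'NCHW', a batched input of shape (8, 4, 16, 32, 32) and a batched bias of shape (8, 16) with bias_dim 0: the helper returns (8, 1, 16, 1, 1), which broadcasts against the input in the Add that follows.
import numpy as np

x = np.zeros((8, 4, 16, 32, 32), np.float32)      # (batch, N, C, H, W)
bias = np.ones((8, 16), np.float32)               # (batch, C)
out = x + bias.reshape((8, 1, 16, 1, 1))          # same broadcast the vmap rule sets up
print(out.shape)                                  # (8, 4, 16, 32, 32)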

View File

@ -270,7 +270,7 @@ def jet(fn, primals, series):
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
- **out_series** (Union[Tensor, list[Tensor]]) - The `1` to `i+1`-th order of derivative of output with respect
to the inputs.
to the inputs.
Raises:
TypeError: If `primals` is not a tensor or tuple of tensors.
@ -304,10 +304,8 @@ def jet(fn, primals, series):
[[2.319777 2.4825778]
[1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
[-1.1400385 -0.3066662 ]]
[[-1.2748207 -1.8274734 ]
[ 0.966121 0.55551505]]
[[-4.0515366 3.6724353 ]
[ 0.5053504 -0.52061415]]]
"""
@ -369,7 +367,7 @@ def derivative(fn, primals, order):
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
- **out_series** (Union[Tensor, list[Tensor]]) - The `order`-th order of derivative of output with respect
to the inputs.
to the inputs.
Raises:
TypeError: If `primals` is not a tensor or tuple of tensors.

View File

@ -20,6 +20,7 @@ from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.ops.functional import vmap
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@ -101,7 +102,203 @@ def test_gatherv2_axisN1():
assert np.all(diff < error)
assert np.all(-diff < error)
def cal_vmap_gather(x, indices, axis):
return P.Gather()(x, indices, axis)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gather_vmap_basic():
"""
Feature: gather vmap test on cpu.
Description: test the correctness of basic vmap gather.
Expectation: the vmap rule's result equals the manually batched result.
"""
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
indices = Tensor(np.array([[0, 1], [1, 2]]).astype(np.int32))
axis = 0
outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)
print("output:\n", outputs)
expect = np.array([[1, 2],
[5, 6]]).astype(np.float32)
assert np.allclose(outputs.asnumpy(), expect)
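For reference, the manually batched result that the expectation refers to could be computed with a per-batch loop like the following (a NumPy sketch on the same data, not part of the test):
x_np = np.array([[1, 2, 3], [4, 5, 6]], np.float32)
idx_np = np.array([[0, 1], [1, 2]], np.int32)
manual = np.stack([np.take(x_np[b], idx_np[b], axis=0) for b in range(x_np.shape[0])])
print(manual)   # [[1. 2.] [5. 6.]] -- matches `expect` above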
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gather_vmap_negative_axis():
"""
Feature: gather vmap test on cpu.
Description: test the correctness of vmap gather when axis is negative.
Expectation: the vmap rule's result equals the manually batched result.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]]).astype(np.float32))
indices = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
axis = -2
outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)
print("output:\n", outputs)
expect = np.array([[[1, 2, 3],
[4, 5, 6]],
[[10, 11, 12],
[7, 8, 9]]]).astype(np.float32)
assert np.allclose(outputs.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gather_vmap_with_inaxes():
"""
Feature: gather vmap test on cpu.
Description: test the correctness of vmap gather when in_axes is not zero.
Expectation: the vmap rule's result equals the manually batched result.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]]).astype(np.float32))
x = np.moveaxis(x, 0, 2)
indices = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
indices = np.moveaxis(indices, 0, 1)
axis = 0
outputs = vmap(cal_vmap_gather, in_axes=(2, 1, None), out_axes=0)(x, indices, axis)
print("output:\n", outputs)
expect = np.array([[[1, 2, 3],
[4, 5, 6]],
[[10, 11, 12],
[7, 8, 9]]]).astype(np.float32)
assert np.allclose(outputs.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gather_vmap_indices_outofbound():
"""
Feature: gather vmap test on cpu.
Description: test the correctness of vmap gather when indices are out of bound.
Expectation: the vmap rule's result equals the manually batched result.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]]).astype(np.float32))
indices = Tensor(np.array([[0, 2], [2, 0]]).astype(np.int32))
axis = 0
outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)
print("output:\n", outputs)
expect = np.array([[[1, 2, 3],
[0, 0, 0]],
[[0, 0, 0],
[7, 8, 9]]]).astype(np.float32)
assert np.allclose(outputs.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gather_vmap_xdim_is_none():
"""
Feature: gather vmap test on cpu.
Description: test the correctness of vmap gather when the x batch dimension is None.
Expectation: the vmap rule's result equals the manually batched result.
"""
x = Tensor(np.array([1, 2, 3]).astype(np.float32))
indices = Tensor(np.array([[0, 1], [2, 0]]).astype(np.int32))
axis = 0
outputs = vmap(cal_vmap_gather, in_axes=(None, 0, None), out_axes=0)(x, indices, axis)
print("output:\n", outputs)
expect = np.array([[1, 2],
[3, 1]]).astype(np.float32)
assert np.allclose(outputs.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gather_vmap_idim_is_none():
"""
Feature: gather vmap test on cpu.
Description: test the correctness of vmap gather when the indices batch dimension is None.
Expectation: the vmap rule's result equals the manually batched result.
"""
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
indices = Tensor(np.array([0, 1]).astype(np.int32))
axis = 0
outputs = vmap(cal_vmap_gather, in_axes=(0, None, None), out_axes=0)(x, indices, axis)
print("output:\n", outputs)
expect = np.array([[1, 2],
[4, 5]]).astype(np.float32)
assert np.allclose(outputs.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gather_vmap_nested():
"""
Feature: gather vmap test on cpu.
Description: test the correctness of nested vmap gather.
Expectation: the vmap rule's result equals the manually batched result.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]]).astype(np.float32))
indices = Tensor(np.array([[[0, 1],
[1, 0]],
[[0, 1],
[1, 0]]]).astype(np.int32))
axis = 0
outputs = vmap(vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0),
in_axes=(0, 0, None), out_axes=0)(x, indices, axis)
print("output:\n", outputs)
expect = np.array([[[1, 2],
[5, 4]],
[[7, 8],
[11, 10]]]).astype(np.float32)
assert np.allclose(outputs.asnumpy(), expect)
if __name__ == '__main__':
test_gatherv2_axis0()
test_gatherv2_axis1()
test_gatherv2_axisN1()
test_gather_vmap_basic()
test_gather_vmap_negative_axis()
test_gather_vmap_with_inaxes()
test_gather_vmap_indices_outofbound()
test_gather_vmap_xdim_is_none()
test_gather_vmap_idim_is_none()
test_gather_vmap_nested()