gather: add vmap rule for gather operations
Signed-off-by: lvhaoyu <lvhaoyu@huawei.com>
This commit is contained in:
parent
6cf2a84fa7
commit
fadd7da5ef
|
@ -21,8 +21,9 @@ from mindspore.ops import functional as F
|
|||
from mindspore.ops import constexpr
|
||||
from ..primitive import Primitive
|
||||
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim_at_front, _raise_value_error, \
|
||||
_handle_broadcasting
|
||||
_handle_broadcasting, get_unsupported_dynamic_vmap_rule
|
||||
from ..operations.array_ops import Fills
|
||||
from ..operations.array_ops import UniqueConsecutive
|
||||
|
||||
|
||||
@vmap_rules_getters.register("Cast")
|
||||
|
@ -46,21 +47,6 @@ def get_cast_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_transpose_batch_perm(dim, perm, x_rank):
|
||||
"""Generate batch_perm based on the original perm of transpose operation and dim of the input."""
|
||||
if dim < 0:
|
||||
dim = dim + x_rank
|
||||
batch_perm = (dim,)
|
||||
for i in perm:
|
||||
if i < dim:
|
||||
batch_perm = batch_perm + (i,)
|
||||
else:
|
||||
index = i + 1
|
||||
batch_perm = batch_perm + (index,)
|
||||
return batch_perm
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_prefix_expand_shape(indices_shape):
|
||||
"""Generate prefix and expand shape by indices shape."""
|
||||
|
@ -85,6 +71,20 @@ def get_transpose_vmap_rule(prim, axis_size):
|
|||
if isinstance(prim, str):
|
||||
prim = Primitive(prim)
|
||||
|
||||
@constexpr
|
||||
def _get_transpose_batch_perm(dim, perm, x_rank):
|
||||
"""Generate batch_perm based on the original perm of transpose operation and dim of the input."""
|
||||
if dim < 0:
|
||||
dim = dim + x_rank
|
||||
batch_perm = (dim,)
|
||||
for i in perm:
|
||||
if i < dim:
|
||||
batch_perm = batch_perm + (i,)
|
||||
else:
|
||||
index = i + 1
|
||||
batch_perm = batch_perm + (index,)
|
||||
return batch_perm
|
||||
|
||||
def vmap_rule(x_bdim, perm_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, x_bdim, perm_bdim)
|
||||
if is_all_none:
|
||||
|
@ -103,24 +103,23 @@ def get_transpose_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_tile_shape(input_shape, multiples):
|
||||
input_ndim = len(input_shape)
|
||||
multiples_ndim = len(multiples)
|
||||
input_shape = (input_shape[0],) + (1,) * (multiples_ndim - input_ndim + 1) + input_shape[1:]
|
||||
multiples = (1,) * (input_ndim - multiples_ndim) + multiples
|
||||
input_expand_shape = (input_shape[0],) + tuple([j for i in input_shape[1:] for j in [1, i]])
|
||||
repeat_shape = (input_shape[0],) + tuple([k for pair in zip(multiples[1:], input_shape[1:]) for k in pair])
|
||||
output_shape = tuple([a * b for a, b in zip(input_shape, multiples)])
|
||||
return input_expand_shape, repeat_shape, output_shape
|
||||
|
||||
|
||||
@vmap_rules_getters.register(P.Tile)
|
||||
def get_tile_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `P.Tile` operation."""
|
||||
if isinstance(prim, str):
|
||||
prim = Primitive(prim)
|
||||
|
||||
@constexpr
|
||||
def _get_tile_shape(input_shape, multiples):
|
||||
input_ndim = len(input_shape)
|
||||
multiples_ndim = len(multiples)
|
||||
input_shape = (input_shape[0],) + (1,) * (multiples_ndim - input_ndim + 1) + input_shape[1:]
|
||||
multiples = (1,) * (input_ndim - multiples_ndim) + multiples
|
||||
input_expand_shape = (input_shape[0],) + tuple([j for i in input_shape[1:] for j in [1, i]])
|
||||
repeat_shape = (input_shape[0],) + tuple([k for pair in zip(multiples[1:], input_shape[1:]) for k in pair])
|
||||
output_shape = tuple([a * b for a, b in zip(input_shape, multiples)])
|
||||
return input_expand_shape, repeat_shape, output_shape
|
||||
|
||||
def vmap_rule(input_bdim, multiples_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, input_bdim, multiples_bdim)
|
||||
if is_all_none:
|
||||
|
@ -653,3 +652,93 @@ def get_masked_fill_vmap_rule(prim, axis_size):
|
|||
return (out, 0)
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
||||
@vmap_rules_getters.register(P.Gather)
|
||||
def get_gather_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `Gather` operation. """
|
||||
if isinstance(prim, str):
|
||||
prim_name = prim
|
||||
prim = Primitive(prim)
|
||||
else:
|
||||
prim_name = prim.name
|
||||
|
||||
@constexpr
|
||||
def process_axis(axis, x_shape_size, has_xdim: bool, has_idim: bool):
|
||||
if has_xdim and has_idim:
|
||||
if axis < 0:
|
||||
axis = x_shape_size - 1 + axis
|
||||
elif has_xdim:
|
||||
if axis >= 0:
|
||||
axis = axis + 1
|
||||
else:
|
||||
if axis < 0:
|
||||
axis = x_shape_size + axis
|
||||
|
||||
return axis
|
||||
|
||||
@constexpr
|
||||
def get_x_dst_shape(x_shape, axis):
|
||||
target_axis_size = x_shape[axis + 1]
|
||||
x_dst_shape = x_shape[0:axis] + (axis_size * target_axis_size,) + x_shape[axis+2:]
|
||||
max_axis_size = axis_size * target_axis_size
|
||||
|
||||
return target_axis_size, x_dst_shape, max_axis_size
|
||||
|
||||
def vmap_rule(x_bdim, indices_bdim, axis_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, x_bdim, indices_bdim)
|
||||
if is_all_none:
|
||||
return result
|
||||
|
||||
x, x_dim = x_bdim
|
||||
indices, indices_dim = indices_bdim
|
||||
axis, axis_dim = axis_bdim
|
||||
|
||||
if axis_dim is not None:
|
||||
_raise_value_error("The source axis of `axis` in {} must be None, but got {}.".format(prim_name, axis_dim))
|
||||
|
||||
x_shape_len = len(x.shape)
|
||||
|
||||
if x_dim is not None and indices_dim is None:
|
||||
x = _bdim_at_front(x, x_dim, axis_size)
|
||||
axis = process_axis(axis, x_shape_len, True, False)
|
||||
output = prim(x, indices, axis)
|
||||
return (output, 0)
|
||||
|
||||
if x_dim is None and indices_dim is not None:
|
||||
indices = _bdim_at_front(indices, indices_dim, axis_size)
|
||||
axis = process_axis(axis, x_shape_len, False, True)
|
||||
output = prim(x, indices, axis)
|
||||
return (output, axis)
|
||||
|
||||
x = _bdim_at_front(x, x_dim, axis_size)
|
||||
indices = _bdim_at_front(indices, indices_dim, axis_size)
|
||||
|
||||
axis = process_axis(axis, x_shape_len, True, True)
|
||||
|
||||
x = mnp.moveaxis(x, 0, axis)
|
||||
|
||||
x_shape = x.shape
|
||||
target_axis_size, x_dst_shape, max_axis_size = get_x_dst_shape(x_shape, axis)
|
||||
|
||||
x = x.reshape(x_dst_shape)
|
||||
|
||||
counts_shape = indices.shape
|
||||
counts = mnp.arange(0, axis_size, 1)
|
||||
counts = F.mul(counts, target_axis_size)
|
||||
counts = P.BroadcastTo(counts_shape[1:] + (axis_size,))(counts)
|
||||
counts = mnp.moveaxis(counts, -1, 0)
|
||||
|
||||
indices_out_of_bound = mnp.where(indices > target_axis_size - 1, x=max_axis_size, y=0)
|
||||
|
||||
indices = F.add(indices, counts)
|
||||
indices = F.add(indices, indices_out_of_bound)
|
||||
|
||||
output = prim(x, indices, axis)
|
||||
|
||||
return (output, axis)
|
||||
return vmap_rule
|
||||
|
||||
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
|
||||
get_unsupported_dynamic_vmap_rule =\
|
||||
vmap_rules_getters.register(UniqueConsecutive)(get_unsupported_dynamic_vmap_rule)
|
||||
|
|
|
@ -246,3 +246,36 @@ def get_unop_vmap_rule(prim, axis_size):
|
|||
return (out, dim)
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
||||
def get_unsupported_dynamic_vmap_rule(prim, axis_size):
|
||||
"""
|
||||
Vmaprule for the dynamic shape operator whose output shape can not be determined,
|
||||
which means the output shape of every batch will be different, so the operator can not support vmap.
|
||||
Forexample, the `Unique` operator, whose output shape is also base on the value of the input.
|
||||
Otherwise, the other dynamic shape operators need to implement their own vmaprules.
|
||||
"""
|
||||
if isinstance(prim, str):
|
||||
prim_name = prim
|
||||
prim = Primitive(prim)
|
||||
else:
|
||||
prim_name = prim.name
|
||||
|
||||
def get_all_dims(*params_bdim):
|
||||
dims = []
|
||||
for bdim in params_bdim:
|
||||
_, dim = bdim
|
||||
dims.append(dim)
|
||||
|
||||
return dims
|
||||
|
||||
def vmap_rule(*params_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, params_bdim)
|
||||
|
||||
if not is_all_none:
|
||||
dims = get_all_dims(*params_bdim)
|
||||
_raise_value_error("For operator {}, all axis should be none, but got {}".format(prim_name, dims))
|
||||
|
||||
return result
|
||||
|
||||
return vmap_rule
|
||||
|
|
|
@ -306,27 +306,27 @@ def get_reducer_vmap_rule(prim, axis_size):
|
|||
return vmap_rule
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_index_add_batch_axis(axis, x_dim, x_ndim):
|
||||
"""get batch_axis for IndexAdd."""
|
||||
# case1: batch not exists
|
||||
if x_dim is None:
|
||||
return axis
|
||||
# case2: batch exists
|
||||
if axis < -x_ndim or x_dim >= x_ndim:
|
||||
raise ValueError("'axis' of 'IndexAdd' should be in range [{}, {}), but got {}".format(-x_ndim, x_ndim, axis))
|
||||
if axis < 0:
|
||||
axis = axis + x_ndim
|
||||
if x_dim > axis:
|
||||
return axis
|
||||
return axis + 1
|
||||
|
||||
|
||||
@vmap_rules_getters.register(P.IndexAdd)
|
||||
def get_index_add_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for IndexAdd."""
|
||||
axis = prim.axis
|
||||
|
||||
@constexpr
|
||||
def _get_index_add_batch_axis(axis, x_dim, x_ndim):
|
||||
"""get batch_axis for IndexAdd."""
|
||||
# case1: batch not exists
|
||||
if x_dim is None:
|
||||
return axis
|
||||
# case2: batch exists
|
||||
if axis < -x_ndim or x_dim >= x_ndim:
|
||||
raise ValueError("'axis' of 'IndexAdd' should be in range [{}, {}), but got {}"\
|
||||
.format(-x_ndim, x_ndim, axis))
|
||||
if axis < 0:
|
||||
axis = axis + x_ndim
|
||||
if x_dim > axis:
|
||||
return axis
|
||||
return axis + 1
|
||||
|
||||
def vmap_rule(x_bdim, indices_bdim, y_bdim, u_monad):
|
||||
x, x_dim = x_bdim
|
||||
indices, indices_dim = indices_bdim
|
||||
|
|
|
@ -24,56 +24,6 @@ from ..primitive import Primitive
|
|||
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_unop_vmap_rule, _bdim_at_front
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_bias_broadcast_shape(x_shape, bias_shape, bias_dim, data_format):
|
||||
"""Get the broadcast shape for bias and use it in 'BiasAdd' VmapRule."""
|
||||
bias_rank = len(bias_shape)
|
||||
if bias_dim is None and bias_rank == 1:
|
||||
bias_batch = 1
|
||||
bias_channel = bias_shape[0]
|
||||
elif bias_dim is not None and bias_rank == 2:
|
||||
bias_batch = bias_shape[0]
|
||||
bias_channel = bias_shape[1]
|
||||
else:
|
||||
raise ValueError("The rank of 'bias' in 'BiasAdd' operator is invalid, which is rank: {}"
|
||||
" with bias_dim: {}.".format(bias_rank, bias_dim))
|
||||
|
||||
# The 'Biasadd' operator supports 2-5 dimensions input, and another 'batch' dimension is added to the front in
|
||||
# vmap scenario.
|
||||
x_min_rank = 3
|
||||
x_max_rank = 5
|
||||
if data_format == "NCDHW":
|
||||
x_max_rank += 1
|
||||
x_rank = len(x_shape)
|
||||
|
||||
if x_rank < x_min_rank or x_rank > x_max_rank:
|
||||
raise ValueError("For primitive[BiasAdd] in vmap, the dims of input_x must be in [x_min_rank, {}"
|
||||
"], but got {}.".format(x_max_rank, x_rank))
|
||||
|
||||
if data_format == "NHWC":
|
||||
# In the 'NHWC' data format ('BN**C' actually), the last dimension is channel axis.
|
||||
x_channel = x_shape[-1]
|
||||
if x_channel != bias_channel:
|
||||
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, "
|
||||
"but got date format: {}, got bias_channel: {}, "
|
||||
"x_channel: {}.".format(data_format, bias_channel, x_channel))
|
||||
if bias_dim is None:
|
||||
bias_broadcast_shape = (1,) * (x_rank - bias_rank) + (bias_channel,)
|
||||
else:
|
||||
bias_broadcast_shape = (bias_batch,) + (1,) * (x_rank - bias_rank) + (bias_channel,)
|
||||
else:
|
||||
# In the 'NCHW' or 'NCDHW' data format ('BNC**' actually), the third dimension is channel axis.
|
||||
x_channel = x_shape[2]
|
||||
if x_channel != bias_channel:
|
||||
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, but got date format: "
|
||||
"{}, got bias_channel: {}, x_channel: {}.".format(data_format, bias_channel, x_channel))
|
||||
bias_broadcast_shape = (bias_batch, 1, bias_channel)
|
||||
if x_rank == x_min_rank:
|
||||
return bias_broadcast_shape
|
||||
bias_broadcast_shape = bias_broadcast_shape + (1,) * (x_rank - x_min_rank)
|
||||
return bias_broadcast_shape
|
||||
|
||||
|
||||
@vmap_rules_getters.register(P.BiasAdd)
|
||||
def get_bias_add_vmap_rule(prim, axis_size):
|
||||
"""VmapRule for `BiasAdd` operation."""
|
||||
|
@ -84,6 +34,56 @@ def get_bias_add_vmap_rule(prim, axis_size):
|
|||
data_format = prim.data_format
|
||||
add_op = P.Add()
|
||||
|
||||
@constexpr
|
||||
def _get_bias_broadcast_shape(x_shape, bias_shape, bias_dim, data_format):
|
||||
"""Get the broadcast shape for bias and use it in 'BiasAdd' VmapRule."""
|
||||
bias_rank = len(bias_shape)
|
||||
if bias_dim is None and bias_rank == 1:
|
||||
bias_batch = 1
|
||||
bias_channel = bias_shape[0]
|
||||
elif bias_dim is not None and bias_rank == 2:
|
||||
bias_batch = bias_shape[0]
|
||||
bias_channel = bias_shape[1]
|
||||
else:
|
||||
raise ValueError("The rank of 'bias' in 'BiasAdd' operator is invalid, which is rank: {}"
|
||||
" with bias_dim: {}.".format(bias_rank, bias_dim))
|
||||
|
||||
# The 'Biasadd' operator supports 2-5 dimensions input, and another 'batch' dimension is added to the front in
|
||||
# vmap scenario.
|
||||
x_min_rank = 3
|
||||
x_max_rank = 5
|
||||
if data_format == "NCDHW":
|
||||
x_max_rank += 1
|
||||
x_rank = len(x_shape)
|
||||
|
||||
if x_rank < x_min_rank or x_rank > x_max_rank:
|
||||
raise ValueError("For primitive[BiasAdd] in vmap, the dims of input_x must be in [x_min_rank, {}"
|
||||
"], but got {}.".format(x_max_rank, x_rank))
|
||||
|
||||
if data_format == "NHWC":
|
||||
# In the 'NHWC' data format ('BN**C' actually), the last dimension is channel axis.
|
||||
x_channel = x_shape[-1]
|
||||
if x_channel != bias_channel:
|
||||
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, "
|
||||
"but got date format: {}, got bias_channel: {}, "
|
||||
"x_channel: {}.".format(data_format, bias_channel, x_channel))
|
||||
if bias_dim is None:
|
||||
bias_broadcast_shape = (1,) * (x_rank - bias_rank) + (bias_channel,)
|
||||
else:
|
||||
bias_broadcast_shape = (bias_batch,) + (1,) * (x_rank - bias_rank) + (bias_channel,)
|
||||
else:
|
||||
# In the 'NCHW' or 'NCDHW' data format ('BNC**' actually), the third dimension is channel axis.
|
||||
x_channel = x_shape[2]
|
||||
if x_channel != bias_channel:
|
||||
raise ValueError("For 'BiadAdd, bias_channel must be equal to x_channel, but got date format: "
|
||||
"{}, got bias_channel: {}, x_channel: {}."\
|
||||
.format(data_format, bias_channel, x_channel))
|
||||
bias_broadcast_shape = (bias_batch, 1, bias_channel)
|
||||
if x_rank == x_min_rank:
|
||||
return bias_broadcast_shape
|
||||
bias_broadcast_shape = bias_broadcast_shape + (1,) * (x_rank - x_min_rank)
|
||||
return bias_broadcast_shape
|
||||
|
||||
def vmap_rule(input_bdim, bias_bdim):
|
||||
is_all_none, result = vmap_general_preprocess(prim, input_bdim, bias_bdim)
|
||||
if is_all_none:
|
||||
|
|
|
@ -270,7 +270,7 @@ def jet(fn, primals, series):
|
|||
|
||||
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
|
||||
- **out_series** (Union[Tensor, list[Tensor]]) - The `1` to `i+1`-th order of derivative of output with respect
|
||||
to the inputs.
|
||||
to the inputs.
|
||||
|
||||
Raises:
|
||||
TypeError: If `primals` is not a tensor or tuple of tensors.
|
||||
|
@ -304,10 +304,8 @@ def jet(fn, primals, series):
|
|||
[[2.319777 2.4825778]
|
||||
[1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
|
||||
[-1.1400385 -0.3066662 ]]
|
||||
|
||||
[[-1.2748207 -1.8274734 ]
|
||||
[ 0.966121 0.55551505]]
|
||||
|
||||
[[-4.0515366 3.6724353 ]
|
||||
[ 0.5053504 -0.52061415]]]
|
||||
"""
|
||||
|
@ -369,7 +367,7 @@ def derivative(fn, primals, order):
|
|||
|
||||
- **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
|
||||
- **out_series** (Union[Tensor, list[Tensor]]) - The `order`-th order of derivative of output with respect
|
||||
to the inputs.
|
||||
to the inputs.
|
||||
|
||||
Raises:
|
||||
TypeError: If `primals` is not a tensor or tuple of tensors.
|
||||
|
|
|
@ -20,6 +20,7 @@ from mindspore.ops import operations as P
|
|||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
@ -101,7 +102,203 @@ def test_gatherv2_axisN1():
|
|||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
def cal_vmap_gather(x, indices, axis):
|
||||
return P.Gather()(x, indices, axis)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_vmap_basic():
|
||||
"""
|
||||
Feature: gather vmap test on cpu.
|
||||
Description: test the rightness of vmap gather basic.
|
||||
Expectation: use vmap rule's result equal to manually batched.
|
||||
"""
|
||||
|
||||
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
|
||||
indices = Tensor(np.array([[0, 1], [1, 2]]).astype(np.int32))
|
||||
axis = 0
|
||||
|
||||
outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)
|
||||
print("output:\n", outputs)
|
||||
|
||||
expect = np.array([[1, 2],
|
||||
[5, 6]]).astype(np.float32)
|
||||
assert np.allclose(outputs.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_vmap_negative_axis():
|
||||
"""
|
||||
Feature: gather vmap test on cpu.
|
||||
Description: test the rightness of vmap gather when axis is negative.
|
||||
Expectation: use vmap rule's result equal to manually batched.
|
||||
"""
|
||||
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]],
|
||||
[[7, 8, 9],
|
||||
[10, 11, 12]]]).astype(np.float32))
|
||||
indices = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
|
||||
axis = -2
|
||||
|
||||
outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)
|
||||
print("output:\n", outputs)
|
||||
|
||||
expect = np.array([[[1, 2, 3],
|
||||
[4, 5, 6]],
|
||||
[[10, 11, 12],
|
||||
[7, 8, 9]]]).astype(np.float32)
|
||||
assert np.allclose(outputs.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_vmap_with_inaxes():
|
||||
"""
|
||||
Feature: gather vmap test on cpu.
|
||||
Description: test the rightness of vmap gather when in_axes is not zero.
|
||||
Expectation: use vmap rule's result equal to manually batched.
|
||||
"""
|
||||
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]],
|
||||
[[7, 8, 9],
|
||||
[10, 11, 12]]]).astype(np.float32))
|
||||
|
||||
x = np.moveaxis(x, 0, 2)
|
||||
indices = Tensor(np.array([[0, 1], [1, 0]]).astype(np.int32))
|
||||
indices = np.moveaxis(indices, 0, 1)
|
||||
axis = 0
|
||||
|
||||
outputs = vmap(cal_vmap_gather, in_axes=(2, 1, None), out_axes=0)(x, indices, axis)
|
||||
print("output:\n", outputs)
|
||||
|
||||
expect = np.array([[[1, 2, 3],
|
||||
[4, 5, 6]],
|
||||
[[10, 11, 12],
|
||||
[7, 8, 9]]]).astype(np.float32)
|
||||
assert np.allclose(outputs.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_vmap_indices_outofbound():
|
||||
"""
|
||||
Feature: gather vmap test on cpu.
|
||||
Description: test the rightness of vmap gather when indices out of bound.
|
||||
Expectation: use vmap rule's result equal to manually batched.
|
||||
"""
|
||||
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]],
|
||||
[[7, 8, 9],
|
||||
[10, 11, 12]]]).astype(np.float32))
|
||||
|
||||
indices = Tensor(np.array([[0, 2], [2, 0]]).astype(np.int32))
|
||||
axis = 0
|
||||
|
||||
outputs = vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0)(x, indices, axis)
|
||||
print("output:\n", outputs)
|
||||
|
||||
expect = np.array([[[1, 2, 3],
|
||||
[0, 0, 0]],
|
||||
[[0, 0, 0],
|
||||
[7, 8, 9]]]).astype(np.float32)
|
||||
assert np.allclose(outputs.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_vmap_xdim_is_none():
|
||||
"""
|
||||
Feature: gather vmap test on cpu.
|
||||
Description: test the rightness of vmap gather when no xdim.
|
||||
Expectation: use vmap rule's result equal to manually batched.
|
||||
"""
|
||||
|
||||
x = Tensor(np.array([1, 2, 3]).astype(np.float32))
|
||||
|
||||
indices = Tensor(np.array([[0, 1], [2, 0]]).astype(np.int32))
|
||||
axis = 0
|
||||
|
||||
outputs = vmap(cal_vmap_gather, in_axes=(None, 0, None), out_axes=0)(x, indices, axis)
|
||||
print("output:\n", outputs)
|
||||
|
||||
expect = np.array([[1, 2],
|
||||
[3, 1]]).astype(np.float32)
|
||||
assert np.allclose(outputs.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_vmap_idim_is_none():
|
||||
"""
|
||||
Feature: gather vmap test on cpu.
|
||||
Description: test the rightness of nested vmap gather.
|
||||
Expectation: use vmap rule's result equal to manually batched.
|
||||
"""
|
||||
|
||||
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
|
||||
|
||||
indices = Tensor(np.array([0, 1]).astype(np.int32))
|
||||
axis = 0
|
||||
|
||||
outputs = vmap(cal_vmap_gather, in_axes=(0, None, None), out_axes=0)(x, indices, axis)
|
||||
print("output:\n", outputs)
|
||||
|
||||
expect = np.array([[1, 2],
|
||||
[4, 5]]).astype(np.float32)
|
||||
assert np.allclose(outputs.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_gather_vmap_nested():
|
||||
"""
|
||||
Feature: gather vmap test on cpu.
|
||||
Description: test the rightness of nested vmap gather.
|
||||
Expectation: use vmap rule's result equal to manually batched.
|
||||
"""
|
||||
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]],
|
||||
[[7, 8, 9],
|
||||
[10, 11, 12]]]).astype(np.float32))
|
||||
|
||||
indices = Tensor(np.array([[[0, 1],
|
||||
[1, 0]],
|
||||
[[0, 1],
|
||||
[1, 0]]]).astype(np.int32))
|
||||
axis = 0
|
||||
|
||||
outputs = vmap(vmap(cal_vmap_gather, in_axes=(0, 0, None), out_axes=0),
|
||||
in_axes=(0, 0, None), out_axes=0)(x, indices, axis)
|
||||
print("output:\n", outputs)
|
||||
|
||||
expect = np.array([[[1, 2],
|
||||
[5, 4]],
|
||||
[[7, 8],
|
||||
[11, 10]]]).astype(np.float32)
|
||||
assert np.allclose(outputs.asnumpy(), expect)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gatherv2_axis0()
|
||||
test_gatherv2_axis1()
|
||||
test_gatherv2_axisN1()
|
||||
test_gather_vmap_basic()
|
||||
test_gather_vmap_negative_axis()
|
||||
test_gather_vmap_with_inaxes()
|
||||
test_gather_vmap_indices_outofbound()
|
||||
test_gather_vmap_xdim_is_none()
|
||||
test_gather_vmap_idim_is_none()
|
||||
test_gather_vmap_nested()
|
||||
|
|
Loading…
Reference in New Issue