diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py index 9b0e77889fc..c106601766c 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_array_ops.py @@ -984,10 +984,12 @@ def get_tensor_shape_vmap_rule(prim, axis_size): def _get_one_hot_vmap_axis(orig_axis, ndim, indices_dim): """Find vmap axis for OneHot.""" if orig_axis >= 0 and indices_dim <= orig_axis: - return (orig_axis + 1, indices_dim) - if indices_dim == (ndim - 1) and orig_axis in (-1, (ndim - 1)): - return (ndim - 1, indices_dim + 1) - return (orig_axis, indices_dim + 1) + return orig_axis + 1, indices_dim + if orig_axis == -1: + if indices_dim == (ndim - 1): + return ndim - 1, indices_dim + 1 + return orig_axis, indices_dim + return orig_axis, indices_dim + 1 @vmap_rules_getters.register(P.OneHot) diff --git a/tests/st/ops/ascend/test_one_hot.py b/tests/st/ops/ascend/test_one_hot.py index 5b49ad290d9..2b180a40107 100644 --- a/tests/st/ops/ascend/test_one_hot.py +++ b/tests/st/ops/ascend/test_one_hot.py @@ -103,14 +103,14 @@ def one_hot_vmap(in_type, value_type): outputs = vmap(cal_onehot, in_axes=(0, None, None, None), out_axes=0)(indices, depth, on_value, off_value) expect = np.array([[[0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0]], - [[0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0]], - [[0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 1.0]], - [[1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 0.0, 0.0]]]).astype(in_type) assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)