!38894 Fix OneHot Vmap

Merge pull request !38894 from jiaoy1224/onehot
This commit is contained in:
i-robot 2022-07-28 08:42:23 +00:00 committed by Gitee
commit dcd74aacb0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 13 additions and 13 deletions

View File

@ -904,10 +904,10 @@ def get_tensor_shape_vmap_rule(prim, axis_size):
def _get_one_hot_vmap_axis(orig_axis, ndim, indices_dim):
"""Find vmap axis for OneHot."""
if orig_axis >= 0 and indices_dim <= orig_axis:
return orig_axis + 1
return (orig_axis + 1, indices_dim)
if indices_dim == (ndim - 1) and orig_axis in (-1, (ndim - 1)):
return ndim - 1
return orig_axis
return (ndim - 1, indices_dim + 1)
return (orig_axis, indices_dim + 1)
@vmap_rules_getters.register(P.OneHot)
@ -943,10 +943,10 @@ def get_one_hot_vmap_rule(prim, axis_size):
_raise_value_error(
"The source axis of `off_value` in {} must be None, but got {}.".format(prim_name, off_value_dim))
ndim = F.rank(indices)
new_axis = _get_one_hot_vmap_axis(axis, ndim, indices_dim)
new_axis, new_bd = _get_one_hot_vmap_axis(axis, ndim, indices_dim)
out = P.OneHot(new_axis)(indices, depth, on_value, off_value)
return (out, indices_dim)
return (out, new_bd)
return vmap_rule

View File

@ -103,14 +103,14 @@ def one_hot_vmap(in_type, value_type):
outputs = vmap(cal_onehot, in_axes=(0, None, None, None), out_axes=0)(indices, depth, on_value, off_value)
expect = np.array([[[0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[1.0, 0.0, 0.0, 0.0, 0.0]],
[[0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 1.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0]],
[[0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 1.0]],
[[1.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0, 0.0]]]).astype(in_type)
assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)