fix onehot

This commit is contained in:
Yang Jiao 2022-08-16 17:22:21 +08:00
parent 1cc940133e
commit 1bda84d34c
2 changed files with 14 additions and 12 deletions

View File

@ -984,10 +984,12 @@ def get_tensor_shape_vmap_rule(prim, axis_size):
def _get_one_hot_vmap_axis(orig_axis, ndim, indices_dim):
"""Find vmap axis for OneHot."""
if orig_axis >= 0 and indices_dim <= orig_axis:
return (orig_axis + 1, indices_dim)
if indices_dim == (ndim - 1) and orig_axis in (-1, (ndim - 1)):
return (ndim - 1, indices_dim + 1)
return (orig_axis, indices_dim + 1)
return orig_axis + 1, indices_dim
if orig_axis == -1:
if indices_dim == (ndim - 1):
return ndim - 1, indices_dim + 1
return orig_axis, indices_dim
return orig_axis, indices_dim + 1
@vmap_rules_getters.register(P.OneHot)

View File

@ -103,14 +103,14 @@ def one_hot_vmap(in_type, value_type):
outputs = vmap(cal_onehot, in_axes=(0, None, None, None), out_axes=0)(indices, depth, on_value, off_value)
expect = np.array([[[0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0]],
[[0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 1.0]],
[[1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[1.0, 0.0, 0.0, 0.0, 0.0]],
[[0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0],
[1.0, 0.0, 0.0, 0.0, 0.0]]]).astype(in_type)
assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)