forked from mindspore-Ecosystem/mindspore
fix onehot
This commit is contained in:
parent
1cc940133e
commit
1bda84d34c
|
@ -984,10 +984,12 @@ def get_tensor_shape_vmap_rule(prim, axis_size):
|
|||
def _get_one_hot_vmap_axis(orig_axis, ndim, indices_dim):
|
||||
"""Find vmap axis for OneHot."""
|
||||
if orig_axis >= 0 and indices_dim <= orig_axis:
|
||||
return (orig_axis + 1, indices_dim)
|
||||
if indices_dim == (ndim - 1) and orig_axis in (-1, (ndim - 1)):
|
||||
return (ndim - 1, indices_dim + 1)
|
||||
return (orig_axis, indices_dim + 1)
|
||||
return orig_axis + 1, indices_dim
|
||||
if orig_axis == -1:
|
||||
if indices_dim == (ndim - 1):
|
||||
return ndim - 1, indices_dim + 1
|
||||
return orig_axis, indices_dim
|
||||
return orig_axis, indices_dim + 1
|
||||
|
||||
|
||||
@vmap_rules_getters.register(P.OneHot)
|
||||
|
|
|
@ -103,14 +103,14 @@ def one_hot_vmap(in_type, value_type):
|
|||
outputs = vmap(cal_onehot, in_axes=(0, None, None, None), out_axes=0)(indices, depth, on_value, off_value)
|
||||
|
||||
expect = np.array([[[0.0, 1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0, 0.0, 0.0]],
|
||||
[[0.0, 0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0, 0.0]],
|
||||
[[0.0, 0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 0.0]],
|
||||
[[0.0, 0.0, 0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 1.0]],
|
||||
[[1.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 1.0],
|
||||
[1.0, 0.0, 0.0, 0.0, 0.0]],
|
||||
[[0.0, 1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 1.0],
|
||||
[1.0, 0.0, 0.0, 0.0, 0.0]]]).astype(in_type)
|
||||
assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)
|
||||
|
||||
|
|
Loading…
Reference in New Issue