forked from mindspore-Ecosystem/mindspore
commit
dcd74aacb0
|
@ -904,10 +904,10 @@ def get_tensor_shape_vmap_rule(prim, axis_size):
|
|||
def _get_one_hot_vmap_axis(orig_axis, ndim, indices_dim):
|
||||
"""Find vmap axis for OneHot."""
|
||||
if orig_axis >= 0 and indices_dim <= orig_axis:
|
||||
return orig_axis + 1
|
||||
return (orig_axis + 1, indices_dim)
|
||||
if indices_dim == (ndim - 1) and orig_axis in (-1, (ndim - 1)):
|
||||
return ndim - 1
|
||||
return orig_axis
|
||||
return (ndim - 1, indices_dim + 1)
|
||||
return (orig_axis, indices_dim + 1)
|
||||
|
||||
|
||||
@vmap_rules_getters.register(P.OneHot)
|
||||
|
@ -943,10 +943,10 @@ def get_one_hot_vmap_rule(prim, axis_size):
|
|||
_raise_value_error(
|
||||
"The source axis of `off_value` in {} must be None, but got {}.".format(prim_name, off_value_dim))
|
||||
ndim = F.rank(indices)
|
||||
new_axis = _get_one_hot_vmap_axis(axis, ndim, indices_dim)
|
||||
new_axis, new_bd = _get_one_hot_vmap_axis(axis, ndim, indices_dim)
|
||||
out = P.OneHot(new_axis)(indices, depth, on_value, off_value)
|
||||
|
||||
return (out, indices_dim)
|
||||
return (out, new_bd)
|
||||
|
||||
return vmap_rule
|
||||
|
||||
|
|
|
@ -103,14 +103,14 @@ def one_hot_vmap(in_type, value_type):
|
|||
outputs = vmap(cal_onehot, in_axes=(0, None, None, None), out_axes=0)(indices, depth, on_value, off_value)
|
||||
|
||||
expect = np.array([[[0.0, 1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 1.0],
|
||||
[1.0, 0.0, 0.0, 0.0, 0.0]],
|
||||
[[0.0, 1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 1.0],
|
||||
[0.0, 1.0, 0.0, 0.0, 0.0]],
|
||||
[[0.0, 0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0, 0.0]],
|
||||
[[0.0, 0.0, 1.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 0.0]],
|
||||
[[0.0, 0.0, 0.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 1.0]],
|
||||
[[1.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[1.0, 0.0, 0.0, 0.0, 0.0]]]).astype(in_type)
|
||||
assert np.allclose(expect, outputs.asnumpy(), 1.e-4, 1.e-7)
|
||||
|
||||
|
|
Loading…
Reference in New Issue