forked from mindspore-Ecosystem/mindspore
!8310 fix example of categorical and rnntloss
From: @yanzhenxiang2020 Reviewed-by: @c_34 Signed-off-by: @c_34
This commit is contained in:
commit
d6f6269ff1
|
@ -4154,8 +4154,8 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
|
||||
When the indexing argument is set to 'xy' (the default),
|
||||
the broadcasting instructions for the first two dimensions are swapped.
|
||||
When the indexing argument is set to 'xy' (the default), the broadcasting
|
||||
instructions for the first two dimensions are swapped.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
|
||||
|
@ -4170,7 +4170,8 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
|
||||
>>> inputs = (x, y, z)
|
||||
>>> meshgrid = ops.Meshgrid(indexing="xy")
|
||||
>>> meshgrid(inputs)
|
||||
>>> output = meshgrid(inputs)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3, 4, 6], dtype=UInt32, value=
|
||||
[[[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
|
|
|
@ -2261,8 +2261,17 @@ class RNNTLoss(PrimitiveWithInfer):
|
|||
>>> labels = np.array([[1, 2]]).astype(np.int32)
|
||||
>>> input_length = np.array([T] * B).astype(np.int32)
|
||||
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
|
||||
>>> rnnt_loss = ops.RNNTLoss(blank_label=blank)
|
||||
>>> rnnt_loss = ops.RNNTLoss(blank_label=0)
|
||||
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
|
||||
>>> print(costs)
|
||||
[-3.5036912]
|
||||
>>> print(grads)
|
||||
[[[[-0.35275543 -0.64724463 0. 0. 0. ]
|
||||
[-0.19174816 0. -0.45549652 0. 0. ]
|
||||
[-0.45549664 0. 0. 0. 0. ]]
|
||||
[[0. -0.35275543 0. 0. 0. ]
|
||||
[0. 0. -0.5445037 0. 0. ]
|
||||
[-1.00000002 0. 0. 0. 0. ]]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -52,7 +52,7 @@ def test_net_assert():
|
|||
out_expect0 = np.array([0, 0, 0, 1, 1, 0]).reshape(3, 2)
|
||||
out_expect1 = np.array([0, 1, 1])
|
||||
out_expect2 = np.array([2, 2])
|
||||
out_expect3 = np.array([-0.7443749, 0.18251707]).reshape(2, 1)
|
||||
out_expect3 = np.array([-0.7443749, 0.18251707]).astype(np.float32).reshape(2, 1)
|
||||
assert np.array_equal(output[0].asnumpy(), out_expect0)
|
||||
assert np.array_equal(output[1].asnumpy(), out_expect1)
|
||||
assert np.array_equal(output[2].asnumpy(), out_expect2)
|
||||
|
|
Loading…
Reference in New Issue