!8310 fix example of categorical and rnntloss

From: @yanzhenxiang2020
Reviewed-by: @c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2020-11-26 16:28:15 +08:00 committed by Gitee
commit d6f6269ff1
3 changed files with 15 additions and 5 deletions

View File

@ -4154,8 +4154,8 @@ class Meshgrid(PrimitiveWithInfer):
Args:
indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
When the indexing argument is set to 'xy' (the default),
the broadcasting instructions for the first two dimensions are swapped.
When the indexing argument is set to 'xy' (the default), the broadcasting
instructions for the first two dimensions are swapped.
Inputs:
- **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
@ -4170,7 +4170,8 @@ class Meshgrid(PrimitiveWithInfer):
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
>>> inputs = (x, y, z)
>>> meshgrid = ops.Meshgrid(indexing="xy")
>>> meshgrid(inputs)
>>> output = meshgrid(inputs)
>>> print(output)
(Tensor(shape=[3, 4, 6], dtype=UInt32, value=
[[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],

View File

@ -2261,8 +2261,17 @@ class RNNTLoss(PrimitiveWithInfer):
>>> labels = np.array([[1, 2]]).astype(np.int32)
>>> input_length = np.array([T] * B).astype(np.int32)
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
>>> rnnt_loss = ops.RNNTLoss(blank_label=blank)
>>> rnnt_loss = ops.RNNTLoss(blank_label=0)
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
>>> print(costs)
[-3.5036912]
>>> print(grads)
[[[[-0.35275543 -0.64724463 0. 0. 0. ]
[-0.19174816 0. -0.45549652 0. 0. ]
[-0.45549664 0. 0. 0. 0. ]]
[[0. -0.35275543 0. 0. 0. ]
[0. 0. -0.5445037 0. 0. ]
[-1.00000002 0. 0. 0. 0. ]]]]
"""
@prim_attr_register

View File

@ -52,7 +52,7 @@ def test_net_assert():
out_expect0 = np.array([0, 0, 0, 1, 1, 0]).reshape(3, 2)
out_expect1 = np.array([0, 1, 1])
out_expect2 = np.array([2, 2])
out_expect3 = np.array([-0.7443749, 0.18251707]).reshape(2, 1)
out_expect3 = np.array([-0.7443749, 0.18251707]).astype(np.float32).reshape(2, 1)
assert np.array_equal(output[0].asnumpy(), out_expect0)
assert np.array_equal(output[1].asnumpy(), out_expect1)
assert np.array_equal(output[2].asnumpy(), out_expect2)