fix example of categorical and rnntloss

This commit is contained in:
yanzhenxiang2020 2020-11-06 17:51:51 +08:00
parent 1080e1d7a7
commit dca109c9a5
3 changed files with 15 additions and 5 deletions

View File

@ -4151,8 +4151,8 @@ class Meshgrid(PrimitiveWithInfer):
Args: Args:
indexing (str): Either 'xy' or 'ij'. Default: 'xy'. indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
When the indexing argument is set to 'xy' (the default), When the indexing argument is set to 'xy' (the default), the broadcasting
the broadcasting instructions for the first two dimensions are swapped. instructions for the first two dimensions are swapped.
Inputs: Inputs:
- **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects. - **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
@ -4167,7 +4167,8 @@ class Meshgrid(PrimitiveWithInfer):
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32) >>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
>>> inputs = (x, y, z) >>> inputs = (x, y, z)
>>> meshgrid = ops.Meshgrid(indexing="xy") >>> meshgrid = ops.Meshgrid(indexing="xy")
>>> meshgrid(inputs) >>> output = meshgrid(inputs)
>>> print(output)
(Tensor(shape=[3, 4, 6], dtype=UInt32, value= (Tensor(shape=[3, 4, 6], dtype=UInt32, value=
[[[1, 1, 1, 1, 1], [[[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2], [2, 2, 2, 2, 2],

View File

@ -2261,8 +2261,17 @@ class RNNTLoss(PrimitiveWithInfer):
>>> labels = np.array([[1, 2]]).astype(np.int32) >>> labels = np.array([[1, 2]]).astype(np.int32)
>>> input_length = np.array([T] * B).astype(np.int32) >>> input_length = np.array([T] * B).astype(np.int32)
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32) >>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
>>> rnnt_loss = ops.RNNTLoss(blank_label=blank) >>> rnnt_loss = ops.RNNTLoss(blank_label=0)
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) >>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
>>> print(costs)
[-3.5036912]
>>> print(grads)
[[[[-0.35275543 -0.64724463 0. 0. 0. ]
[-0.19174816 0. -0.45549652 0. 0. ]
[-0.45549664 0. 0. 0. 0. ]]
[[0. -0.35275543 0. 0. 0. ]
[0. 0. -0.5445037 0. 0. ]
[-1.00000002 0. 0. 0. 0. ]]]]
""" """
@prim_attr_register @prim_attr_register

View File

@ -52,7 +52,7 @@ def test_net_assert():
out_expect0 = np.array([0, 0, 0, 1, 1, 0]).reshape(3, 2) out_expect0 = np.array([0, 0, 0, 1, 1, 0]).reshape(3, 2)
out_expect1 = np.array([0, 1, 1]) out_expect1 = np.array([0, 1, 1])
out_expect2 = np.array([2, 2]) out_expect2 = np.array([2, 2])
out_expect3 = np.array([-0.7443749, 0.18251707]).reshape(2, 1) out_expect3 = np.array([-0.7443749, 0.18251707]).astype(np.float32).reshape(2, 1)
assert np.array_equal(output[0].asnumpy(), out_expect0) assert np.array_equal(output[0].asnumpy(), out_expect0)
assert np.array_equal(output[1].asnumpy(), out_expect1) assert np.array_equal(output[1].asnumpy(), out_expect1)
assert np.array_equal(output[2].asnumpy(), out_expect2) assert np.array_equal(output[2].asnumpy(), out_expect2)