forked from mindspore-Ecosystem/mindspore
!27045 Fix wrong return value of nn.Jvp and nn.Vjp examples.
Merge pull request !27045 from LiangZhibo/code_docs_grad
This commit is contained in:
commit
0a05868d07
|
@ -91,7 +91,7 @@ class Jvp(Cell):
|
||||||
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||||
>>> output = Jvp(Net())(x, y, (v, v))
|
>>> output = Jvp(Net())(x, y, (v, v))
|
||||||
>>> print(output[0])
|
>>> print(output[0])
|
||||||
[[2, 10], [20, 68]]
|
[[2, 10], [30, 68]]
|
||||||
>>> print(output[1])
|
>>> print(output[1])
|
||||||
[[4, 13], [28, 49]]
|
[[4, 13], [28, 49]]
|
||||||
"""
|
"""
|
||||||
|
@ -203,7 +203,7 @@ class Vjp(Cell):
|
||||||
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
|
||||||
>>> output = Vjp(Net())(x, y, v)
|
>>> output = Vjp(Net())(x, y, v)
|
||||||
>>> print(output[0])
|
>>> print(output[0])
|
||||||
[[2, 10], [20, 68]]
|
[[2, 10], [30, 68]]
|
||||||
>>> print(output[1])
|
>>> print(output[1])
|
||||||
([[3, 12], [27, 48]], [[1, 1], [1, 1]])
|
([[3, 12], [27, 48]], [[1, 1], [1, 1]])
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue