!27045 Fix wrong return value of nn.Jvp and nn.Vjp examples.

Merge pull request !27045 from LiangZhibo/code_docs_grad
This commit is contained in:
i-robot 2021-12-02 01:01:41 +00:00 committed by Gitee
commit 0a05868d07
1 changed files with 2 additions and 2 deletions

View File

@ -91,7 +91,7 @@ class Jvp(Cell):
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> output = Jvp(Net())(x, y, (v, v)) >>> output = Jvp(Net())(x, y, (v, v))
>>> print(output[0]) >>> print(output[0])
[[2, 10], [20, 68]] [[2, 10], [30, 68]]
>>> print(output[1]) >>> print(output[1])
[[4, 13], [28, 49]] [[4, 13], [28, 49]]
""" """
@ -203,7 +203,7 @@ class Vjp(Cell):
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> output = Vjp(Net())(x, y, v) >>> output = Vjp(Net())(x, y, v)
>>> print(output[0]) >>> print(output[0])
[[2, 10], [20, 68]] [[2, 10], [30, 68]]
>>> print(output[1]) >>> print(output[1])
([[3, 12], [27, 48]], [[1, 1], [1, 1]]) ([[3, 12], [27, 48]], [[1, 1], [1, 1]])
""" """