!19759 fix example result of graphcell

Merge pull request !19759 from wangnan39/code_docs_fix_result_of_graphcell_example
This commit is contained in:
i-robot 2021-07-09 02:02:44 +00:00 committed by Gitee
commit b45a455c70
2 changed files with 9 additions and 5 deletions

View File

@ -1405,12 +1405,16 @@ class GraphCell(Cell):
>>> from mindspore import Tensor
>>> from mindspore.train import export, load
>>>
>>> net = nn.Conv2d(1, 1, kernel_size=3)
>>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> export(net, input, file_name="net", file_format="MINDIR")
>>> graph = load("net.mindir")
>>> net = nn.GraphCell(graph)
>>> output = net(input)
>>> print(output)
[[[[4. 6. 4.]
[6. 9. 6.]
[4. 6. 4.]]]]
"""
def __init__(self, graph):
super(GraphCell, self).__init__(auto_prefix=True)

View File

@ -331,16 +331,16 @@ def load(file_name, **kwargs):
>>> from mindspore import Tensor
>>> from mindspore.train import export, load
>>>
>>> net = nn.Conv2d(1, 1, kernel_size=3)
>>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> export(net, input, file_name="net", file_format="MINDIR")
>>> graph = load("net.mindir")
>>> net = nn.GraphCell(graph)
>>> output = net(input)
>>> print(output)
[[[[0.03204346 0.04455566 0.03509521]
[0.02406311 0.04125977 0.02404785]
[0.02018738 0.0292511 0.00889587]]]]
[[[[4. 6. 4.]
[6. 9. 6.]
[4. 6. 4.]]]]
"""
if not isinstance(file_name, str):
raise ValueError("The file name must be string.")