!19759 fix example result of graphcell
Merge pull request !19759 from wangnan39/code_docs_fix_result_of_graphcell_example
This commit is contained in:
commit
b45a455c70
|
@ -1405,12 +1405,16 @@ class GraphCell(Cell):
|
|||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.train import export, load
|
||||
>>>
|
||||
>>> net = nn.Conv2d(1, 1, kernel_size=3)
|
||||
>>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
|
||||
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
||||
>>> export(net, input, file_name="net", file_format="MINDIR")
|
||||
>>> graph = load("net.mindir")
|
||||
>>> net = nn.GraphCell(graph)
|
||||
>>> output = net(input)
|
||||
>>> print(output)
|
||||
[[[[4. 6. 4.]
|
||||
[6. 9. 6.]
|
||||
[4. 6. 4.]]]]
|
||||
"""
|
||||
def __init__(self, graph):
|
||||
super(GraphCell, self).__init__(auto_prefix=True)
|
||||
|
|
|
@ -331,16 +331,16 @@ def load(file_name, **kwargs):
|
|||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.train import export, load
|
||||
>>>
|
||||
>>> net = nn.Conv2d(1, 1, kernel_size=3)
|
||||
>>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
|
||||
>>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
||||
>>> export(net, input, file_name="net", file_format="MINDIR")
|
||||
>>> graph = load("net.mindir")
|
||||
>>> net = nn.GraphCell(graph)
|
||||
>>> output = net(input)
|
||||
>>> print(output)
|
||||
[[[[0.03204346 0.04455566 0.03509521]
|
||||
[0.02406311 0.04125977 0.02404785]
|
||||
[0.02018738 0.0292511 0.00889587]]]]
|
||||
[[[[4. 6. 4.]
|
||||
[6. 9. 6.]
|
||||
[4. 6. 4.]]]]
|
||||
"""
|
||||
if not isinstance(file_name, str):
|
||||
raise ValueError("The file name must be string.")
|
||||
|
|
Loading…
Reference in New Issue