diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index 0c7cbb37a83..bdf964182df 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -2243,6 +2243,8 @@ class GraphCell(Cell): >>> import mindspore as ms >>> import mindspore.nn as nn >>> from mindspore import Tensor + >>> from mindspore import context + >>> context.set_context(mode=context.GRAPH_MODE) >>> >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones") >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)) diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index 004c46743d6..57cd4e253ca 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -412,6 +412,8 @@ def load(file_name, **kwargs): >>> import mindspore as ms >>> import mindspore.nn as nn >>> from mindspore import Tensor + >>> from mindspore import context + >>> context.set_context(mode=context.GRAPH_MODE) >>> >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones") >>> input_tensor = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))