!41548 Fix graph mode of some ut cases

Merge pull request !41548 from chenfei_mindspore/code-self-check
This commit is contained in:
i-robot 2022-09-06 10:57:59 +00:00 committed by Gitee
commit 497ca52445
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 11 additions and 4 deletions

View File

@ -63,6 +63,7 @@ def test_apply_adam_with_amsgrad_compile():
_cell_graph_executor.compile(train_network, inputs, label) _cell_graph_executor.compile(train_network, inputs, label)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
_cell_graph_executor.compile(train_network, inputs, label) _cell_graph_executor.compile(train_network, inputs, label)
context.set_context(mode=context.GRAPH_MODE)
def test_apply_adam_with_amsgrad_group1(): def test_apply_adam_with_amsgrad_group1():
@ -95,6 +96,7 @@ def test_apply_adam_with_amsgrad_group1():
_cell_graph_executor.compile(train_network, inputs, label) _cell_graph_executor.compile(train_network, inputs, label)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
_cell_graph_executor.compile(train_network, inputs, label) _cell_graph_executor.compile(train_network, inputs, label)
context.set_context(mode=context.GRAPH_MODE)
def test_apply_adam_with_amsgrad_group2(): def test_apply_adam_with_amsgrad_group2():
@ -125,6 +127,7 @@ def test_apply_adam_with_amsgrad_group2():
_cell_graph_executor.compile(train_network, inputs, label) _cell_graph_executor.compile(train_network, inputs, label)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
_cell_graph_executor.compile(train_network, inputs, label) _cell_graph_executor.compile(train_network, inputs, label)
context.set_context(mode=context.GRAPH_MODE)
class NetWithSparseGatherV2(nn.Cell): class NetWithSparseGatherV2(nn.Cell):
@ -167,3 +170,4 @@ def test_sparse_apply_adam_with_amsgrad():
with pytest.raises(Exception): with pytest.raises(Exception):
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
_cell_graph_executor.compile(train_network, indices, label) _cell_graph_executor.compile(train_network, indices, label)
context.set_context(mode=context.GRAPH_MODE)

View File

@ -20,6 +20,7 @@ import mindspore as ms
from mindspore import Tensor from mindspore import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay
from mindspore import context
class IterableObjc: class IterableObjc:
@ -110,6 +111,7 @@ def test_not_flattened_params():
Description: Optimizer with not flattened parameters. Description: Optimizer with not flattened parameters.
Expectation: The Optimizer works as expected. Expectation: The Optimizer works as expected.
""" """
context.set_context(mode=context.GRAPH_MODE)
p1 = Parameter(Tensor([1], ms.float32), name="p1") p1 = Parameter(Tensor([1], ms.float32), name="p1")
p2 = Parameter(Tensor([2], ms.float32), name="p2") p2 = Parameter(Tensor([2], ms.float32), name="p2")
p3 = Parameter(Tensor([3], ms.float32), name="p3") p3 = Parameter(Tensor([3], ms.float32), name="p3")
@ -127,6 +129,7 @@ def test_with_flattened_params():
Description: Optimizer with flattened parameters. Description: Optimizer with flattened parameters.
Expectation: The Optimizer works as expected. Expectation: The Optimizer works as expected.
""" """
context.set_context(mode=context.GRAPH_MODE)
p1 = Parameter(Tensor([1], ms.float32), name="p1") p1 = Parameter(Tensor([1], ms.float32), name="p1")
p2 = Parameter(Tensor([2], ms.float32), name="p2") p2 = Parameter(Tensor([2], ms.float32), name="p2")
p3 = Parameter(Tensor([3], ms.float32), name="p3") p3 = Parameter(Tensor([3], ms.float32), name="p3")
@ -154,6 +157,7 @@ def test_adam_with_flattened_params():
Description: Adam optimizer with flattened parameters. Description: Adam optimizer with flattened parameters.
Expectation: It is ok to compile the optimizer. Expectation: It is ok to compile the optimizer.
""" """
context.set_context(mode=context.GRAPH_MODE)
p1 = Parameter(Tensor([1], ms.float32), name="p1") p1 = Parameter(Tensor([1], ms.float32), name="p1")
p2 = Parameter(Tensor([2], ms.float32), name="p2") p2 = Parameter(Tensor([2], ms.float32), name="p2")
p3 = Parameter(Tensor([3], ms.float32), name="p3") p3 = Parameter(Tensor([3], ms.float32), name="p3")
@ -164,7 +168,6 @@ def test_adam_with_flattened_params():
g2 = Tensor([0.2], ms.float32) g2 = Tensor([0.2], ms.float32)
g3 = Tensor([0.3], ms.float32) g3 = Tensor([0.3], ms.float32)
grads = (g1, g2, g3) grads = (g1, g2, g3)
with pytest.raises(NotImplementedError):
adam(grads) adam(grads)
@ -174,6 +177,7 @@ def test_adam_with_flattened_params_fusion_size():
Description: Adam optimizer with flattened parameters and fusion size. Description: Adam optimizer with flattened parameters and fusion size.
Expectation: It is ok to compile the optimizer. Expectation: It is ok to compile the optimizer.
""" """
context.set_context(mode=context.GRAPH_MODE)
p1 = Parameter(Tensor([1], ms.float32), name="p1") p1 = Parameter(Tensor([1], ms.float32), name="p1")
p2 = Parameter(Tensor([2], ms.float32), name="p2") p2 = Parameter(Tensor([2], ms.float32), name="p2")
p3 = Parameter(Tensor([3], ms.float32), name="p3") p3 = Parameter(Tensor([3], ms.float32), name="p3")
@ -194,5 +198,4 @@ def test_adam_with_flattened_params_fusion_size():
g4 = Tensor([0.4], ms.float32) g4 = Tensor([0.4], ms.float32)
g5 = Tensor([0.5], ms.float32) g5 = Tensor([0.5], ms.float32)
grads = (g1, g2, g3, g4, g5) grads = (g1, g2, g3, g4, g5)
with pytest.raises(NotImplementedError):
adam(grads) adam(grads)