forked from mindspore-Ecosystem/mindspore
!41548 Fix graph mode of some ut cases
Merge pull request !41548 from chenfei_mindspore/code-self-check
This commit is contained in:
commit
497ca52445
|
@ -63,6 +63,7 @@ def test_apply_adam_with_amsgrad_compile():
|
|||
_cell_graph_executor.compile(train_network, inputs, label)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
_cell_graph_executor.compile(train_network, inputs, label)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_apply_adam_with_amsgrad_group1():
|
||||
|
@ -95,6 +96,7 @@ def test_apply_adam_with_amsgrad_group1():
|
|||
_cell_graph_executor.compile(train_network, inputs, label)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
_cell_graph_executor.compile(train_network, inputs, label)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_apply_adam_with_amsgrad_group2():
|
||||
|
@ -125,6 +127,7 @@ def test_apply_adam_with_amsgrad_group2():
|
|||
_cell_graph_executor.compile(train_network, inputs, label)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
_cell_graph_executor.compile(train_network, inputs, label)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class NetWithSparseGatherV2(nn.Cell):
|
||||
|
@ -167,3 +170,4 @@ def test_sparse_apply_adam_with_amsgrad():
|
|||
with pytest.raises(Exception):
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
_cell_graph_executor.compile(train_network, indices, label)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
|
|
@ -20,6 +20,7 @@ import mindspore as ms
|
|||
from mindspore import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.nn.optim import Optimizer, SGD, Adam, AdamWeightDecay
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class IterableObjc:
|
||||
|
@ -110,6 +111,7 @@ def test_not_flattened_params():
|
|||
Description: Optimizer with not flattened parameters.
|
||||
Expectation: The Optimizer works as expected.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
p1 = Parameter(Tensor([1], ms.float32), name="p1")
|
||||
p2 = Parameter(Tensor([2], ms.float32), name="p2")
|
||||
p3 = Parameter(Tensor([3], ms.float32), name="p3")
|
||||
|
@ -127,6 +129,7 @@ def test_with_flattened_params():
|
|||
Description: Optimizer with flattened parameters.
|
||||
Expectation: The Optimizer works as expected.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
p1 = Parameter(Tensor([1], ms.float32), name="p1")
|
||||
p2 = Parameter(Tensor([2], ms.float32), name="p2")
|
||||
p3 = Parameter(Tensor([3], ms.float32), name="p3")
|
||||
|
@ -154,6 +157,7 @@ def test_adam_with_flattened_params():
|
|||
Description: Adam optimizer with flattened parameters.
|
||||
Expectation: It is ok to compile the optimizer.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
p1 = Parameter(Tensor([1], ms.float32), name="p1")
|
||||
p2 = Parameter(Tensor([2], ms.float32), name="p2")
|
||||
p3 = Parameter(Tensor([3], ms.float32), name="p3")
|
||||
|
@ -164,8 +168,7 @@ def test_adam_with_flattened_params():
|
|||
g2 = Tensor([0.2], ms.float32)
|
||||
g3 = Tensor([0.3], ms.float32)
|
||||
grads = (g1, g2, g3)
|
||||
with pytest.raises(NotImplementedError):
|
||||
adam(grads)
|
||||
adam(grads)
|
||||
|
||||
|
||||
def test_adam_with_flattened_params_fusion_size():
|
||||
|
@ -174,6 +177,7 @@ def test_adam_with_flattened_params_fusion_size():
|
|||
Description: Adam optimizer with flattened parameters and fusion size.
|
||||
Expectation: It is ok to compile the optimizer.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
p1 = Parameter(Tensor([1], ms.float32), name="p1")
|
||||
p2 = Parameter(Tensor([2], ms.float32), name="p2")
|
||||
p3 = Parameter(Tensor([3], ms.float32), name="p3")
|
||||
|
@ -194,5 +198,4 @@ def test_adam_with_flattened_params_fusion_size():
|
|||
g4 = Tensor([0.4], ms.float32)
|
||||
g5 = Tensor([0.5], ms.float32)
|
||||
grads = (g1, g2, g3, g4, g5)
|
||||
with pytest.raises(NotImplementedError):
|
||||
adam(grads)
|
||||
adam(grads)
|
||||
|
|
Loading…
Reference in New Issue