!21880 Check ref of update parameters

Merge pull request !21880 from chenfei_mindspore/ascend-control-use-vm
This commit is contained in:
i-robot 2021-08-19 03:49:12 +00:00 committed by Gitee
commit 8ec16bbab9
51 changed files with 1089 additions and 263 deletions

View File

@ -1359,8 +1359,12 @@ void KernelGraph::SetOptimizerFlag() {
continue;
}
auto param = real_node->cast<ParameterPtr>();
has_optimizer_ = true;
(void)updated_parameters_.insert(param);
auto abstract = param->abstract();
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<abstract::AbstractRef>()) {
has_optimizer_ = true;
(void)updated_parameters_.insert(param);
}
}
}
}

View File

@ -626,6 +626,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
MS_LOG(INFO) << "Run graph mode with vm.";
bc_ptr->set_is_multi_graph_sink(false);
context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
}
}

View File

@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
class SingleIfNet(nn.Cell):
@ -62,26 +62,38 @@ def control_flow_single_if(input_net, x, y):
context.set_context(mode=context.GRAPH_MODE)
net = input_net()
grad_net = GradNet(net)
graph_forward_res = net(x, y)
forward_net = input_net()
graph_forward_res = forward_net(x, y)
graph_backward_res = grad_net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
net = input_net()
grad_net = GradNet(net)
pynative_forward_res = net(x, y)
forward_net = input_net()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = grad_net(x, y)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if():
x = Tensor(2, mstype.int32)
y = Tensor(5, mstype.int32)
control_flow_single_if(SingleIfNet, x, y)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if_01():
x = Tensor(2, mstype.int32)
y = Tensor(5, mstype.int32)

View File

@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class ForwardNet(nn.Cell):
@ -41,7 +42,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
c1 = Tensor([0], mstype.int32)
c2 = Tensor([0], mstype.int32)
@ -50,7 +55,11 @@ def test_forward():
output = forward_net(c1, c2)
assert expect == output
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
c1 = Tensor([0], mstype.int32)
c2 = Tensor([0], mstype.int32)

View File

@ -19,11 +19,17 @@ from mindspore import Tensor, nn
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_for_01():
class SingleForNet(nn.Cell):
def __init__(self):
@ -72,6 +78,11 @@ def test_single_for_01():
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_for_02():
class SingleForNet(nn.Cell):
def __init__(self):
@ -102,20 +113,29 @@ def test_single_for_02():
context.set_context(mode=context.GRAPH_MODE)
for_net = SingleForNet()
net = GradNet(for_net)
graph_forward_res = for_net(x, y, z)
for_net_forward = SingleForNet()
graph_forward_res = for_net_forward(x, y, z)
graph_backward_res = net(x, y, z)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_net = SingleForNet()
net = GradNet(for_net)
pynative_forward_res = for_net(x, y, z)
for_net_forward = SingleForNet()
pynative_forward_res = for_net_forward(x, y, z)
pynative_backward_res = net(x, y, z)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_for_03():
class SingleForNet(nn.Cell):
def __init__(self):
@ -157,20 +177,29 @@ def test_single_for_03():
context.set_context(mode=context.GRAPH_MODE)
single_for_net = SingleForNet()
net = GradNet(single_for_net)
graph_forward_res = single_for_net(x, y)
for_net_forward = SingleForNet()
graph_forward_res = for_net_forward(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
single_for_net = SingleForNet()
net = GradNet(single_for_net)
pynative_forward_res = single_for_net(x, y)
for_net_forward = SingleForNet()
pynative_forward_res = for_net_forward(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_for_04():
class SingleForNet(nn.Cell):
def __init__(self):
@ -187,7 +216,7 @@ def test_single_for_04():
def construct(self, x):
self.assign(self.param_a, x + self.param_a)
for _ in range(1):
self.param_b = x - self.param_a
F.assign(self.param_b, x - self.param_a)
return self.param_b
class GradNet(nn.Cell):
@ -204,20 +233,29 @@ def test_single_for_04():
context.set_context(mode=context.GRAPH_MODE)
single_for_net = SingleForNet()
net = GradNet(single_for_net)
graph_forward_res = single_for_net(x)
for_net_forward = SingleForNet()
graph_forward_res = for_net_forward(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
single_for_net = SingleForNet()
net = GradNet(single_for_net)
pynative_forward_res = single_for_net(x)
for_net_forward = SingleForNet()
pynative_forward_res = for_net_forward(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_for_05():
class SingleForNet(nn.Cell):
def __init__(self):
@ -249,14 +287,18 @@ def test_single_for_05():
context.set_context(mode=context.GRAPH_MODE)
single_for_net = SingleForNet()
net = GradNet(single_for_net)
graph_forward_res = single_for_net(x)
for_net_forward = SingleForNet()
graph_forward_res = for_net_forward(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
single_for_net = SingleForNet()
net = GradNet(single_for_net)
pynative_forward_res = single_for_net(x)
for_net_forward = SingleForNet()
pynative_forward_res = for_net_forward(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
class IfInIfNet(nn.Cell):
@ -111,6 +110,23 @@ class IfInIfNet3(nn.Cell):
return x
# add a while to test if_in_if run with vm.Only should run in ascend.
class IfInIfNet4(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
def construct(self, x):
while x < 1:
x = x + 1
if self.param_a > self.param_b:
out = self.func(x)
else:
out = self.func(self.param_a)
out += self.param_b
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
@ -125,37 +141,65 @@ def control_flow_if_in_if(input_net, x):
context.set_context(mode=context.GRAPH_MODE)
net = input_net()
grad_net = GradNet(net)
graph_forward_res = net(x)
forward_net = input_net()
graph_forward_res = forward_net(x)
graph_backward_res = grad_net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
net = input_net()
grad_net = GradNet(net)
pynative_forward_res = net(x)
forward_net = input_net()
pynative_forward_res = forward_net(x)
pynative_backward_res = grad_net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if():
x = Tensor(2, mstype.int32)
control_flow_if_in_if(IfInIfNet, x)
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_01():
x = Tensor(2, mstype.int32)
control_flow_if_in_if(IfInIfNet1, x)
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.skip(reason="Ascend compile error in multigraph sink.")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_02():
x = Tensor(2, mstype.int32)
control_flow_if_in_if(IfInIfNet2, x)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_03():
x = Tensor(2, mstype.int32)
control_flow_if_in_if(IfInIfNet3, x)
@pytest.mark.skip(reason="Result not correct in ascend vm")
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_in_if_04():
x = Tensor(2, mstype.int32)
control_flow_if_in_if(IfInIfNet4, x)

View File

@ -22,7 +22,7 @@ from mindspore import context
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -56,7 +56,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
# Graph Mode
context.set_context(mode=context.GRAPH_MODE)
@ -72,6 +76,7 @@ def test_forward():
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

View File

@ -20,7 +20,7 @@ from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -79,12 +79,22 @@ class BackwardNetReplaceBreak(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=10)
out = forward_net(x, y)
print("forward out:", out)
graph_mode_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
pynative_forward_net = ForwardNet(max_cycles=10)
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
# Problem: Exceed function call depth limit 1000.
@ -93,27 +103,58 @@ def test_forward():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=10)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
forward_net = ForwardNet(max_cycles=10)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward_replace_break():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNetReplaceBreak(max_cycles=10)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNetReplaceBreak(max_cycles=10)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
# Problem: Exceed function call depth limit 1000.
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward_replace_break():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNetReplaceBreak(max_cycles=10)
backward_net = BackwardNetReplaceBreak(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNetReplaceBreak(max_cycles=10)
backward_net = BackwardNetReplaceBreak(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -22,7 +22,7 @@ from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -52,8 +52,10 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -62,13 +64,16 @@ def test_forward():
graph_forward_net = ForwardNet(max_cycles=3)
graph_mode_out = graph_forward_net(x, y)
# Pynative Mode
context.set_context(mode=context.PYNATIVE_MODE)
pynative_forward_net = ForwardNet(max_cycles=3)
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
# context.set_context(mode=context.PYNATIVE_MODE)
# pynative_forward_net = ForwardNet(max_cycles=3)
# pynative_mode_out = pynative_forward_net(x, y)
expect = (Tensor(np.array(9), mstype.int32), Tensor(np.array(2), mstype.int32))
assert graph_mode_out == expect
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -78,8 +83,9 @@ def test_backward():
graph_backward_net = BackwardNet(graph_forward_net)
graph_mode_grads = graph_backward_net(x, y)
# Pynative Mode
context.set_context(mode=context.PYNATIVE_MODE)
pynative_forward_net = ForwardNet(max_cycles=3)
pynative_backward_net = BackwardNet(pynative_forward_net)
pynative_mode_grads = pynative_backward_net(x, y)
assert graph_mode_grads == pynative_mode_grads
# context.set_context(mode=context.PYNATIVE_MODE)
# pynative_forward_net = ForwardNet(max_cycles=3)
# pynative_backward_net = BackwardNet(pynative_forward_net)
# pynative_mode_grads = pynative_backward_net(x, y)
expect = (Tensor(np.array(9), mstype.int32), Tensor(np.array(3), mstype.int32))
assert graph_mode_grads == expect

View File

@ -13,13 +13,14 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -48,18 +49,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -22,7 +23,7 @@ from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -56,6 +57,11 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -70,6 +76,11 @@ def test_forward():
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -53,18 +54,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -51,18 +52,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.common.parameter import Parameter
@ -21,9 +22,12 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_in_if_01():
class ForInIfNet(nn.Cell):
def __init__(self):
@ -57,20 +61,28 @@ def test_for_in_if_01():
context.set_context(mode=context.GRAPH_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
graph_forward_res = for_in_if_net(x)
forward_net = ForInIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
pynative_forward_res = for_in_if_net(x)
forward_net = ForInIfNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_in_if_02():
class ForInIfNet(nn.Cell):
def __init__(self):
@ -108,20 +120,28 @@ def test_for_in_if_02():
context.set_context(mode=context.GRAPH_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
graph_forward_res = for_in_if_net(x)
forward_net = ForInIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
pynative_forward_res = for_in_if_net(x)
forward_net = ForInIfNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_in_if_03():
class ForInIfNet(nn.Cell):
def __init__(self):
@ -160,20 +180,29 @@ def test_for_in_if_03():
context.set_context(mode=context.GRAPH_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
graph_forward_res = for_in_if_net(x)
forward_net = ForInIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
pynative_forward_res = for_in_if_net(x)
forward_net = ForInIfNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.skip(reason="Ascend control multi sink result error")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_in_if_04():
class ForInIfNet(nn.Cell):
def __init__(self):
@ -209,20 +238,28 @@ def test_for_in_if_04():
context.set_context(mode=context.GRAPH_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
graph_forward_res = for_in_if_net(x)
forward_net = ForInIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
forward_net = ForInIfNet()
pynative_forward_res = forward_net(x)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
pynative_forward_res = for_in_if_net(x)
pynative_backward_res = net(x)
expect_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
assert graph_backward_res == expect_backward_res
@pytest.mark.skip(reason="Ascend control multi sink result error")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_in_if_05():
class ForInIfNet(nn.Cell):
def __init__(self):
@ -260,15 +297,19 @@ def test_for_in_if_05():
context.set_context(mode=context.GRAPH_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
graph_forward_res = for_in_if_net(x)
forward_net = ForInIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
pynative_forward_res = for_in_if_net(x)
pynative_backward_res = net(x)
for_in_if_net = ForInIfNet()
net = GradNet(for_in_if_net)
expect_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
assert graph_backward_res == expect_backward_res

View File

@ -22,7 +22,6 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.skip(reason="not supported for in while")
def test_for_in_while_01():
@ -62,16 +61,20 @@ def test_for_in_while_01():
# graph mode
context.set_context(mode=context.GRAPH_MODE)
for_in_while_net = ForInWhileNet()
net = GradNet(for_in_while_net)
graph_forward_res = for_in_while_net(x)
graph_backward_res = net(x)
backward_net = GradNet(for_in_while_net)
forward_net = ForInWhileNet()
graph_forward_res = forward_net(x)
graph_backward_res = backward_net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_in_while_net = ForInWhileNet()
net = GradNet(for_in_while_net)
pynative_forward_res = for_in_while_net(x)
pynative_backward_res = net(x)
backward_net = GradNet(for_in_while_net)
forward_net = ForInWhileNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = backward_net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res

View File

@ -22,9 +22,12 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="GPU")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_in_for_01():
class ForInForNet(nn.Cell):
def __init__(self):
@ -64,14 +67,18 @@ def test_for_in_for_01():
context.set_context(mode=context.GRAPH_MODE)
for_in_for_net = ForInForNet()
net = GradNet(for_in_for_net)
graph_forward_res = for_in_for_net(x)
forward_net = ForInForNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_in_for_net = ForInForNet()
net = GradNet(for_in_for_net)
pynative_forward_res = for_in_for_net(x)
forward_net = ForInForNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
@ -79,6 +86,8 @@ def test_for_in_for_01():
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_in_for_02():
class ForInForNet(nn.Cell):
@ -114,14 +123,18 @@ def test_for_in_for_02():
context.set_context(mode=context.GRAPH_MODE)
for_in_for_net = ForInForNet()
net = GradNet(for_in_for_net)
graph_forward_res = for_in_for_net(x)
forward_net = ForInForNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_in_for_net = ForInForNet()
net = GradNet(for_in_for_net)
pynative_forward_res = for_in_for_net(x)
forward_net = ForInForNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
@ -19,7 +20,6 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
class IfAfterIfNet(nn.Cell):
@ -93,6 +93,28 @@ class IfAfterIfNet3(nn.Cell):
return x
# Add a while to run with vm in ascend
class IfAfterIfNet4(nn.Cell):
def __init__(self):
super().__init__()
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
def construct(self, x, y):
while x < 0:
x = x + 1
out = x * y + self.func(self.param_b)
if self.param_a > self.param_b:
out += 5
return out
def func(self, x):
if self.param_a > self.param_b:
x += 5
self.param_b += 4
return x
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
@ -105,19 +127,19 @@ class GradNet(nn.Cell):
def control_flow_if_after_if(input_net, x, y):
# graph mode
context.set_context(mode=context.GRAPH_MODE)
forward_net = input_net()
net = input_net()
grad_net = GradNet(net)
forward_net = input_net()
graph_forward_res = forward_net(x, y)
graph_backward_res = grad_net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
forward_net = input_net()
net = input_net()
grad_net = GradNet(net)
forward_net = input_net()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = grad_net(x, y)
@ -125,25 +147,57 @@ def control_flow_if_after_if(input_net, x, y):
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if():
x = Tensor(2, mstype.int32)
y = Tensor(5, mstype.int32)
control_flow_if_after_if(IfAfterIfNet, x, y)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_01():
x = Tensor(2, mstype.int32)
y = Tensor(5, mstype.int32)
control_flow_if_after_if(IfAfterIfNet1, x, y)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_02():
x = Tensor(2, mstype.int32)
y = Tensor(5, mstype.int32)
control_flow_if_after_if(IfAfterIfNet2, x, y)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
# Now in ascend result is not correct
# @pytest.mark.platform_arm_ascend_training
# @pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_03():
x = Tensor(2, mstype.int32)
y = Tensor(5, mstype.int32)
control_flow_if_after_if(IfAfterIfNet3, x, y)
@pytest.mark.skip(reason="Result is not correct in multigraph sink.")
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_04():
x = Tensor(2, mstype.int32)
y = Tensor(5, mstype.int32)
control_flow_if_after_if(IfAfterIfNet4, x, y)

View File

@ -22,7 +22,7 @@ from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -43,7 +43,6 @@ class ForwardNet(nn.Cell):
i = i + 1
if out >= 20:
F.assign(self.weight, out)
self.weight = out
out = out - 20
return out, self.weight
@ -58,7 +57,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -73,8 +76,11 @@ def test_forward():
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.common.parameter import Parameter
@ -21,9 +22,12 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_for_01():
class IfAfterForNet(nn.Cell):
def __init__(self):
@ -64,20 +68,28 @@ def test_if_after_for_01():
context.set_context(mode=context.GRAPH_MODE)
if_after_for_net = IfAfterForNet()
net = GradNet(if_after_for_net)
graph_forward_res = if_after_for_net(x)
forward_net = IfAfterForNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_for_net = IfAfterForNet()
net = GradNet(if_after_for_net)
pynative_forward_res = if_after_for_net(x)
forward_net = IfAfterForNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_for_02():
class IfAfterForNet(nn.Cell):
def __init__(self):
@ -118,14 +130,18 @@ def test_if_after_for_02():
context.set_context(mode=context.GRAPH_MODE)
if_after_for_net = IfAfterForNet()
net = GradNet(if_after_for_net)
graph_forward_res = if_after_for_net(x)
forward_net = IfAfterForNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_for_net = IfAfterForNet()
net = GradNet(if_after_for_net)
pynative_forward_res = if_after_for_net(x)
forward_net = IfAfterForNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="GPU")
class IfAfterIfInIfNet(nn.Cell):
@ -133,14 +132,18 @@ def control_flow_if_after_if_in_if(input_net, x):
context.set_context(mode=context.GRAPH_MODE)
net = input_net()
grad_net = GradNet(net)
graph_forward_res = net(x)
forward_net = input_net()
graph_forward_res = forward_net(x)
graph_backward_res = grad_net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
net = input_net()
grad_net = GradNet(net)
pynative_forward_res = net(x)
forward_net = input_net()
pynative_forward_res = forward_net(x)
pynative_backward_res = grad_net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -20,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -56,7 +57,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -70,7 +75,11 @@ def test_forward():
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
@ -19,7 +20,6 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
class IfAfterIfInForNet(nn.Cell):
@ -124,35 +124,56 @@ def control_flow_if_after_if_in_for(input_net, x):
context.set_context(mode=context.GRAPH_MODE)
net = input_net()
grad_net = GradNet(net)
graph_forward_res = net(x)
forward_net = input_net()
graph_forward_res = forward_net(x)
graph_backward_res = grad_net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
net = input_net()
grad_net = GradNet(net)
pynative_forward_res = net(x)
forward_net = input_net()
pynative_forward_res = forward_net(x)
pynative_backward_res = grad_net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.skip(reason="ME EvalCNode error")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_in_for():
x = Tensor(2, mstype.int32)
control_flow_if_after_if_in_for(IfAfterIfInForNet, x)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_in_for_01():
x = Tensor(2, mstype.int32)
control_flow_if_after_if_in_for(IfAfterIfInForNet1, x)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_in_for_02():
x = Tensor(2, mstype.int32)
control_flow_if_after_if_in_for(IfAfterIfInForNet2, x)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_in_for_03():
x = Tensor(2, mstype.int32)
control_flow_if_after_if_in_for(IfAfterIfInForNet3, x)

View File

@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -21,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -56,7 +57,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -70,7 +75,11 @@ def test_forward():
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -19,10 +19,11 @@ from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -37,14 +38,14 @@ class ForwardNet(nn.Cell):
out = self.zero
i = self.i
while x < y:
self.weight = x
F.assign(self.weight, out)
while i < self.max_cycles:
out = x * y + out
i = i + 1
self.weight = i
F.assign(self.weight, i)
x = x + 1
if out < 20:
self.weight = out
F.assign(self.weight, out)
out = out - 20
return out, self.weight
@ -59,7 +60,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -74,7 +79,11 @@ def test_forward():
assert graph_mode_out == pynative_mode_out
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -126,6 +135,8 @@ class BackwardNetNoAssign(nn.Cell):
# This test case has a problem of evaluator endless loop.
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward_no_assign():
x = Tensor(np.array(1), mstype.int32)

View File

@ -19,10 +19,11 @@ from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -38,9 +39,9 @@ class ForwardNet(nn.Cell):
while x < y:
out = x * y + out
x = x + 1
self.weight = x
F.assign(self.weight, x)
if out > 20:
self.weight = out
F.assign(self.weight, out)
out = out - 20
return out, self.weight
@ -55,7 +56,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -70,7 +75,11 @@ def test_forward():
assert graph_mode_out == pynative_mode_out
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -84,6 +93,7 @@ def test_backward():
pynative_forward_net = ForwardNet(max_cycles=3)
pynative_backward_net = BackwardNet(pynative_forward_net)
pynative_mode_grads = pynative_backward_net(x, y)
#expect = (Tensor(np.array(6), mstype.int32), Tensor(np.array(3), mstype.int32))
assert graph_mode_grads == pynative_mode_grads
@ -100,7 +110,7 @@ class ForwardNetNoAssign(nn.Cell):
while x < y:
out = x * y + out
x = x + 1
#self.weight = x
# self.weight = x
if out > 20:
self.weight = out
out = out - 20
@ -119,6 +129,8 @@ class BackwardNetNoAssign(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward_no_assign():
x = Tensor(np.array(1), mstype.int32)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
@ -19,8 +20,11 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_for_in_if():
class IfAfterForInIfNet(nn.Cell):
def __init__(self):
@ -53,14 +57,18 @@ def test_if_after_for_in_if():
context.set_context(mode=context.GRAPH_MODE)
if_after_for_in_if_net = IfAfterForInIfNet()
net = GradNet(if_after_for_in_if_net)
graph_forward_res = if_after_for_in_if_net(x)
forward_net = IfAfterForInIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_for_in_if_net = IfAfterForInIfNet()
net = GradNet(if_after_for_in_if_net)
pynative_forward_res = if_after_for_in_if_net(x)
forward_net = IfAfterForInIfNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -20,7 +20,6 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.skip(reason="not supported for in while")
def test_if_after_for_in_while():
@ -55,14 +54,18 @@ def test_if_after_for_in_while():
context.set_context(mode=context.GRAPH_MODE)
if_after_for_in_while_net = IfAfterForInWhileNet()
net = GradNet(if_after_for_in_while_net)
graph_forward_res = if_after_for_in_while_net(x)
forward_net = IfAfterForInWhileNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_for_in_while_net = IfAfterForInWhileNet()
net = GradNet(if_after_for_in_while_net)
pynative_forward_res = if_after_for_in_while_net(x)
forward_net = IfAfterForInWhileNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
@ -19,8 +20,11 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_for_in_for():
class IfAfterForInForNet(nn.Cell):
def __init__(self):
@ -53,14 +57,18 @@ def test_if_after_for_in_for():
context.set_context(mode=context.GRAPH_MODE)
if_after_for_in_for_net = IfAfterForInForNet()
net = GradNet(if_after_for_in_for_net)
graph_forward_res = if_after_for_in_for_net(x)
forward_net = IfAfterForInForNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
if_after_for_in_for_net = IfAfterForInForNet()
net = GradNet(if_after_for_in_for_net)
pynative_forward_res = if_after_for_in_for_net(x)
forward_net = IfAfterForInForNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -20,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -54,7 +55,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -68,7 +73,11 @@ def test_forward():
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -52,18 +53,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.skip(reason="Ascend kernel compiler error!")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -54,18 +55,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -20,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -55,7 +56,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -69,7 +74,11 @@ def test_forward():
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -20,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -58,7 +59,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -72,7 +77,11 @@ def test_forward():
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -20,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -55,7 +56,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(3), mstype.int32)
y = Tensor(np.array(5), mstype.int32)
@ -69,7 +74,11 @@ def test_forward():
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(3), mstype.int32)
y = Tensor(np.array(5), mstype.int32)

View File

@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -21,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -56,7 +57,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -69,7 +74,11 @@ def test_forward():
pynative_mode_out = forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -57,18 +58,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -55,18 +56,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
@ -21,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -53,7 +54,11 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
@ -67,7 +72,11 @@ def test_forward():
pynative_mode_out = pynative_forward_net(x, y)
assert graph_mode_out == pynative_mode_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)

View File

@ -21,7 +21,7 @@ from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -54,18 +55,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.skip(reason="Ascend kernel compiler error!")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
@ -19,8 +20,11 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_if():
class ForAfterIfNet(nn.Cell):
def __init__(self):
@ -52,14 +56,18 @@ def test_for_after_if():
context.set_context(mode=context.GRAPH_MODE)
for_after_if_net = ForAfterIfNet()
net = GradNet(for_after_if_net)
graph_forward_res = for_after_if_net(x)
forward_net = ForAfterIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_if_net = ForAfterIfNet()
net = GradNet(for_after_if_net)
pynative_forward_res = for_after_if_net(x)
forward_net = ForAfterIfNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -52,18 +53,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.common.parameter import Parameter
@ -21,8 +22,11 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_for_01():
class ForAfterForNet(nn.Cell):
def __init__(self):
@ -65,20 +69,28 @@ def test_for_after_for_01():
context.set_context(mode=context.GRAPH_MODE)
for_after_for_net = ForAfterForNet()
net = GradNet(for_after_for_net)
graph_forward_res = for_after_for_net(x)
forward_net = ForAfterForNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_for_net = ForAfterForNet()
net = GradNet(for_after_for_net)
pynative_forward_res = for_after_for_net(x)
forward_net = ForAfterForNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_for_02():
class ForAfterForNet(nn.Cell):
def __init__(self):
@ -118,14 +130,18 @@ def test_for_after_for_02():
context.set_context(mode=context.GRAPH_MODE)
for_after_for_net = ForAfterForNet()
net = GradNet(for_after_for_net)
graph_forward_res = for_after_for_net(x)
forward_net = ForAfterForNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_for_net = ForAfterForNet()
net = GradNet(for_after_for_net)
pynative_forward_res = for_after_for_net(x)
forward_net = ForAfterForNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
@ -19,8 +20,11 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_if_in_if():
class ForAfterIfInIfNet(nn.Cell):
def __init__(self):
@ -55,14 +59,18 @@ def test_for_after_if_in_if():
context.set_context(mode=context.GRAPH_MODE)
for_after_if_in_if_net = ForAfterIfInIfNet()
net = GradNet(for_after_if_in_if_net)
graph_forward_res = for_after_if_in_if_net(x)
forward_net = ForAfterIfInIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_if_in_if_net = ForAfterIfInIfNet()
net = GradNet(for_after_if_in_if_net)
pynative_forward_res = for_after_if_in_if_net(x)
forward_net = ForAfterIfInIfNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -56,18 +57,43 @@ class BackwardNet(nn.Cell):
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.common.parameter import Parameter
@ -21,8 +22,11 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_while_in_if_01():
class ForAfterWhileInIfNet(nn.Cell):
def __init__(self):
@ -78,20 +82,28 @@ def test_for_after_while_in_if_01():
context.set_context(mode=context.GRAPH_MODE)
for_after_while_in_if_net = ForAfterWhileInIfNet()
net = GradNet(for_after_while_in_if_net)
graph_forward_res = for_after_while_in_if_net(x, y)
forward_net = ForAfterWhileInIfNet()
graph_forward_res = forward_net(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_while_in_if_net = ForAfterWhileInIfNet()
net = GradNet(for_after_while_in_if_net)
pynative_forward_res = for_after_while_in_if_net(x, y)
forward_net = ForAfterWhileInIfNet()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_while_in_if_02():
class ForAfterWhileInIfNet(nn.Cell):
def __init__(self):
@ -138,14 +150,18 @@ def test_for_after_while_in_if_02():
context.set_context(mode=context.GRAPH_MODE)
for_after_while_in_if_net = ForAfterWhileInIfNet()
net = GradNet(for_after_while_in_if_net)
graph_forward_res = for_after_while_in_if_net(x, y)
forward_net = ForAfterWhileInIfNet()
graph_forward_res = forward_net(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_while_in_if_net = ForAfterWhileInIfNet()
net = GradNet(for_after_while_in_if_net)
pynative_forward_res = for_after_while_in_if_net(x, y)
forward_net = ForAfterWhileInIfNet()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res

View File

@ -14,13 +14,14 @@
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
class ForwardNet(nn.Cell):
@ -56,19 +57,42 @@ class BackwardNet(nn.Cell):
grads = self.grad(self.forward_net)(*inputs)
return grads
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
out = forward_net(x, y)
print("forward out:", out)
graph_out = forward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
pynative_out = forward_net(x, y)
assert graph_out == pynative_out
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
grads = backward_net(x, y)
print("grads:", grads)
graph_grads = backward_net(x, y)
context.set_context(mode=context.PYNATIVE_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)
backward_net = BackwardNet(forward_net)
pynative_grads = backward_net(x, y)
assert graph_grads == pynative_grads

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.common.parameter import Parameter
@ -21,8 +22,11 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_while_in_for_01():
class ForAfterWhileInForNet(nn.Cell):
def __init__(self):
@ -79,20 +83,28 @@ def test_for_after_while_in_for_01():
context.set_context(mode=context.GRAPH_MODE)
for_after_while_in_for_net = ForAfterWhileInForNet()
net = GradNet(for_after_while_in_for_net)
graph_forward_res = for_after_while_in_for_net(x, y)
forward_net = ForAfterWhileInForNet()
graph_forward_res = forward_net(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_while_in_for_net = ForAfterWhileInForNet()
net = GradNet(for_after_while_in_for_net)
pynative_forward_res = for_after_while_in_for_net(x, y)
forward_net = ForAfterWhileInForNet()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_while_in_for_02():
class ForAfterWhileInForNet(nn.Cell):
def __init__(self):
@ -139,14 +151,18 @@ def test_for_after_while_in_for_02():
context.set_context(mode=context.GRAPH_MODE)
for_after_while_in_for_net = ForAfterWhileInForNet()
net = GradNet(for_after_while_in_for_net)
graph_forward_res = for_after_while_in_for_net(x, y)
forward_net = ForAfterWhileInForNet()
graph_forward_res = forward_net(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_while_in_for_net = ForAfterWhileInForNet()
net = GradNet(for_after_while_in_for_net)
pynative_forward_res = for_after_while_in_for_net(x, y)
forward_net = ForAfterWhileInForNet()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res

View File

@ -20,9 +20,13 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.skip(reason="not supported side effect")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_for_in_if():
class ForAfterForInIfNet(nn.Cell):
def __init__(self):
@ -56,14 +60,18 @@ def test_for_after_for_in_if():
context.set_context(mode=context.GRAPH_MODE)
for_after_for_in_if_net = ForAfterForInIfNet()
net = GradNet(for_after_for_in_if_net)
graph_forward_res = for_after_for_in_if_net(x)
forward_net = ForAfterForInIfNet()
graph_forward_res = forward_net(x)
graph_backward_res = net(x)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_for_in_if_net = ForAfterForInIfNet()
net = GradNet(for_after_for_in_if_net)
pynative_forward_res = for_after_for_in_if_net(x)
forward_net = ForAfterForInIfNet()
pynative_forward_res = forward_net(x)
pynative_backward_res = net(x)
assert graph_forward_res == pynative_forward_res

View File

@ -22,7 +22,6 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.skip(reason="not supported for in while")
def test_for_after_for_in_while_01():
class ForAfterForInWhileNet(nn.Cell):
@ -75,14 +74,18 @@ def test_for_after_for_in_while_01():
context.set_context(mode=context.GRAPH_MODE)
for_after_for_in_while_net = ForAfterForInWhileNet()
net = GradNet(for_after_for_in_while_net)
graph_forward_res = for_after_for_in_while_net(x, y)
forward_net = ForAfterForInWhileNet()
graph_forward_res = forward_net(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_for_in_while_net = ForAfterForInWhileNet()
net = GradNet(for_after_for_in_while_net)
pynative_forward_res = for_after_for_in_while_net(x, y)
forward_net = ForAfterForInWhileNet()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res
@ -128,14 +131,18 @@ def test_for_after_for_in_while_02():
context.set_context(mode=context.GRAPH_MODE)
for_after_for_in_while_net = ForAfterForInWhileNet()
net = GradNet(for_after_for_in_while_net)
graph_forward_res = for_after_for_in_while_net(x, y)
forward_net = ForAfterForInWhileNet()
graph_forward_res = forward_net(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_for_in_while_net = ForAfterForInWhileNet()
net = GradNet(for_after_for_in_while_net)
pynative_forward_res = for_after_for_in_while_net(x, y)
forward_net = ForAfterForInWhileNet()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.common.parameter import Parameter
@ -21,8 +22,12 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_for_in_for_01():
class ForAfterForInForNet(nn.Cell):
def __init__(self):
@ -70,20 +75,28 @@ def test_for_after_for_in_for_01():
context.set_context(mode=context.GRAPH_MODE)
for_after_for_in_for_net = ForAfterForInForNet()
net = GradNet(for_after_for_in_for_net)
graph_forward_res = for_after_for_in_for_net(x, y)
forward_net = ForAfterForInForNet()
graph_forward_res = forward_net(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_for_in_for_net = ForAfterForInForNet()
net = GradNet(for_after_for_in_for_net)
pynative_forward_res = for_after_for_in_for_net(x, y)
forward_net = ForAfterForInForNet()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res
assert graph_backward_res == pynative_backward_res
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_after_for_in_for_02():
class ForAfterForInForNet(nn.Cell):
def __init__(self):
@ -127,14 +140,18 @@ def test_for_after_for_in_for_02():
context.set_context(mode=context.GRAPH_MODE)
for_after_for_in_for_net = ForAfterForInForNet()
net = GradNet(for_after_for_in_for_net)
graph_forward_res = for_after_for_in_for_net(x, y)
forward_net = ForAfterForInForNet()
graph_forward_res = forward_net(x, y)
graph_backward_res = net(x, y)
# pynative mode
context.set_context(mode=context.PYNATIVE_MODE)
for_after_for_in_for_net = ForAfterForInForNet()
net = GradNet(for_after_for_in_for_net)
pynative_forward_res = for_after_for_in_for_net(x, y)
forward_net = ForAfterForInForNet()
pynative_forward_res = forward_net(x, y)
pynative_backward_res = net(x, y)
assert graph_forward_res == pynative_forward_res