forked from mindspore-Ecosystem/mindspore
!15532 add control flow testcases
From: @huangbingjian Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
09d41736e1
|
@ -20,41 +20,69 @@ from mindspore.common import dtype as mstype
|
|||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_single_if():
|
||||
class SingleIfNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
x += 1
|
||||
if x < y:
|
||||
y += x
|
||||
else:
|
||||
y -= x
|
||||
y += 5
|
||||
return y
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
class SingleIfNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
x += 1
|
||||
if x < y:
|
||||
y += x
|
||||
else:
|
||||
y -= x
|
||||
y += 5
|
||||
return y
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
class SingleIfNet1(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
x += 1
|
||||
out = self.func(x, y)
|
||||
out *= 2
|
||||
return out
|
||||
|
||||
def func(self, x, y):
|
||||
if x < y:
|
||||
y += x
|
||||
else:
|
||||
y -= x
|
||||
y += 5
|
||||
return y
|
||||
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
|
||||
def control_flow_single_if(input_net, x, y):
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_net = SingleIfNet()
|
||||
net = GradNet(if_net)
|
||||
graph_forward_res = if_net(x, y)
|
||||
graph_backward_res = net(x, y)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
graph_forward_res = net(x, y)
|
||||
graph_backward_res = grad_net(x, y)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_net = SingleIfNet()
|
||||
net = GradNet(if_net)
|
||||
pynative_forward_res = if_net(x, y)
|
||||
pynative_backward_res = net(x, y)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
pynative_forward_res = net(x, y)
|
||||
pynative_backward_res = grad_net(x, y)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
||||
|
||||
def test_single_if():
|
||||
x = Tensor(2, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
control_flow_single_if(SingleIfNet, x, y)
|
||||
|
||||
|
||||
def test_single_if_01():
|
||||
x = Tensor(2, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
control_flow_single_if(SingleIfNet1, x, y)
|
||||
|
|
|
@ -21,44 +21,138 @@ from mindspore.common.parameter import Parameter
|
|||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_in_if():
|
||||
class IfInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
if self.param_a > self.param_b:
|
||||
x += 10
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
x += self.param_a
|
||||
return x
|
||||
class IfInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
def construct(self, x):
|
||||
if self.param_a > self.param_b:
|
||||
x += 10
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
x += self.param_a
|
||||
return x
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
class IfInIfNet1(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
if self.param_a > self.param_b:
|
||||
out = self.func(x)
|
||||
else:
|
||||
out = self.func(self.param_a)
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
x += 10
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
x += self.param_a
|
||||
return x
|
||||
|
||||
|
||||
class IfInIfNet2(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
if self.check(self.param_a, self.param_b):
|
||||
out = self.func(x)
|
||||
else:
|
||||
out = x
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
x += 10
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
x += self.param_a
|
||||
return x
|
||||
|
||||
def check(self, x, y):
|
||||
if x < y:
|
||||
self.param_b += 1
|
||||
return True
|
||||
self.param_b -= 1
|
||||
return False
|
||||
|
||||
|
||||
class IfInIfNet3(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
if self.func(x) > self.param_a:
|
||||
out = x
|
||||
else:
|
||||
out = self.param_a
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
x += 10
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
x += self.param_a
|
||||
return x
|
||||
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
|
||||
def control_flow_if_in_if(input_net, x):
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_in_if_net = IfInIfNet()
|
||||
net = GradNet(if_in_if_net)
|
||||
graph_forward_res = if_in_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
graph_forward_res = net(x)
|
||||
graph_backward_res = grad_net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_in_if_net = IfInIfNet()
|
||||
net = GradNet(if_in_if_net)
|
||||
pynative_forward_res = if_in_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
pynative_forward_res = net(x)
|
||||
pynative_backward_res = grad_net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
||||
|
||||
def test_if_in_if():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_in_if(IfInIfNet, x)
|
||||
|
||||
|
||||
def test_if_in_if_01():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_in_if(IfInIfNet1, x)
|
||||
|
||||
|
||||
def test_if_in_if_02():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_in_if(IfInIfNet2, x)
|
||||
|
||||
|
||||
def test_if_in_if_03():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_in_if(IfInIfNet3, x)
|
||||
|
|
|
@ -21,45 +21,100 @@ from mindspore.common.parameter import Parameter
|
|||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_if():
|
||||
class IfAfterIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
if self.param_a > self.param_b:
|
||||
x += 5
|
||||
self.param_b += 4
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
class IfAfterIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
def construct(self, x, y):
|
||||
out = y
|
||||
if self.param_a > self.param_b:
|
||||
x += 5
|
||||
self.param_b += 4
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
class IfAfterIfNet1(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x, y):
|
||||
out = y
|
||||
x = self.func(x)
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
if self.param_a > self.param_b:
|
||||
x += 5
|
||||
self.param_b += 4
|
||||
return x
|
||||
|
||||
|
||||
class IfAfterIfNet2(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
x += 1
|
||||
out = self.func(x, y)
|
||||
if out > 10:
|
||||
out += 5
|
||||
return out
|
||||
|
||||
def func(self, x, y):
|
||||
if x < y:
|
||||
y += x
|
||||
else:
|
||||
y -= x
|
||||
return y
|
||||
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
|
||||
def control_flow_if_after_if(input_net, x, y):
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_if_net = IfAfterIfNet()
|
||||
net = GradNet(if_after_if_net)
|
||||
graph_forward_res = if_after_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
graph_forward_res = net(x, y)
|
||||
graph_backward_res = grad_net(x, y)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_if_net = IfAfterIfNet()
|
||||
net = GradNet(if_after_if_net)
|
||||
pynative_forward_res = if_after_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
pynative_forward_res = net(x, y)
|
||||
pynative_backward_res = grad_net(x, y)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
||||
|
||||
def test_if_after_if():
|
||||
x = Tensor(2, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
control_flow_if_after_if(IfAfterIfNet, x, y)
|
||||
|
||||
|
||||
def test_if_after_if_01():
|
||||
x = Tensor(2, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
control_flow_if_after_if(IfAfterIfNet1, x, y)
|
||||
|
||||
|
||||
def test_if_after_if_02():
|
||||
x = Tensor(2, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
control_flow_if_after_if(IfAfterIfNet2, x, y)
|
||||
|
|
|
@ -21,47 +21,146 @@ from mindspore.common.parameter import Parameter
|
|||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_if_in_if():
|
||||
class IfAfterIfInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
if self.param_a > self.param_b:
|
||||
x += 5
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
self.param_b += 3
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
class IfAfterIfInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
if self.param_a > self.param_b:
|
||||
x += 5
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
self.param_b += 3
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
class IfAfterIfInIfNet1(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
x = self.func(x)
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
if self.param_a > self.param_b:
|
||||
x += 5
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
self.param_b += 3
|
||||
return x
|
||||
|
||||
|
||||
class IfAfterIfInIfNet2(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
x = self.func(x)
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
if self.subfunc(x):
|
||||
x += 5
|
||||
self.param_b += 3
|
||||
return x
|
||||
|
||||
def subfunc(self, x):
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
return True
|
||||
self.param_b -= 1
|
||||
return False
|
||||
|
||||
|
||||
class IfAfterIfInIfNet3(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
x = self.func(x)
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
if self.subfunc(x) > self.param_a:
|
||||
x += 5
|
||||
self.param_b += 3
|
||||
return x
|
||||
|
||||
def subfunc(self, x):
|
||||
if x > self.param_a:
|
||||
x -= self.param_b
|
||||
self.param_b += 1
|
||||
else:
|
||||
x += self.param_b
|
||||
self.param_b -= 1
|
||||
return x
|
||||
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
|
||||
def control_flow_if_after_if_in_if(input_net, x):
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_if_in_if_net = IfAfterIfInIfNet()
|
||||
net = GradNet(if_after_if_in_if_net)
|
||||
graph_forward_res = if_after_if_in_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
graph_forward_res = net(x)
|
||||
graph_backward_res = grad_net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_if_in_if_net = IfAfterIfInIfNet()
|
||||
net = GradNet(if_after_if_in_if_net)
|
||||
pynative_forward_res = if_after_if_in_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
pynative_forward_res = net(x)
|
||||
pynative_backward_res = grad_net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
||||
|
||||
def test_if_after_if_in_if():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_if(IfAfterIfInIfNet, x)
|
||||
|
||||
|
||||
def test_if_after_if_in_if_01():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_if(IfAfterIfInIfNet1, x)
|
||||
|
||||
|
||||
def test_if_after_if_in_if_02():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_if(IfAfterIfInIfNet2, x)
|
||||
|
||||
|
||||
def test_if_after_if_in_if_03():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_if(IfAfterIfInIfNet3, x)
|
||||
|
|
|
@ -21,46 +21,138 @@ from mindspore.common.parameter import Parameter
|
|||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_if_in_for():
|
||||
class IfAfterIfInForNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
for _ in range(4):
|
||||
if out <= 20:
|
||||
out += self.param_a
|
||||
self.param_b += 3
|
||||
if x < self.param_b:
|
||||
out -= self.param_b
|
||||
return out
|
||||
class IfAfterIfInForNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
for _ in range(4):
|
||||
if out <= 20:
|
||||
out += self.param_a
|
||||
self.param_b += 3
|
||||
if x < self.param_b:
|
||||
out -= self.param_b
|
||||
return out
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
class IfAfterIfInForNet1(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = self.func(x)
|
||||
if x < self.param_b:
|
||||
out -= self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
out = x + self.param_b
|
||||
for _ in range(4):
|
||||
if out <= 20:
|
||||
out += self.param_a
|
||||
self.param_b += 3
|
||||
return out
|
||||
|
||||
|
||||
class IfAfterIfInForNet2(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = self.func(x)
|
||||
if x < self.param_b:
|
||||
out -= self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
out = x + self.param_b
|
||||
for _ in range(4):
|
||||
out = self.subfunc(out)
|
||||
self.param_b += 3
|
||||
return out
|
||||
|
||||
def subfunc(self, x):
|
||||
if x <= 20:
|
||||
x += self.param_a
|
||||
return x
|
||||
|
||||
|
||||
class IfAfterIfInForNet3(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = self.func(x)
|
||||
if x < self.param_b:
|
||||
out -= self.param_b
|
||||
return out
|
||||
|
||||
def func(self, x):
|
||||
out = x + self.param_b
|
||||
for _ in range(3):
|
||||
out += self.subfunc(x)
|
||||
self.param_b += 3
|
||||
return out
|
||||
|
||||
def subfunc(self, x):
|
||||
if x > 10:
|
||||
return self.param_a
|
||||
return self.param_b
|
||||
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
|
||||
def control_flow_if_after_if_in_for(input_net, x):
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_if_in_for_net = IfAfterIfInForNet()
|
||||
net = GradNet(if_after_if_in_for_net)
|
||||
graph_forward_res = if_after_if_in_for_net(x)
|
||||
graph_backward_res = net(x)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
graph_forward_res = net(x)
|
||||
graph_backward_res = grad_net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_if_in_for_net = IfAfterIfInForNet()
|
||||
net = GradNet(if_after_if_in_for_net)
|
||||
pynative_forward_res = if_after_if_in_for_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
net = input_net()
|
||||
grad_net = GradNet(net)
|
||||
pynative_forward_res = net(x)
|
||||
pynative_backward_res = grad_net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
||||
|
||||
def test_if_after_if_in_for():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_for(IfAfterIfInForNet, x)
|
||||
|
||||
|
||||
def test_if_after_if_in_for_01():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_for(IfAfterIfInForNet1, x)
|
||||
|
||||
|
||||
def test_if_after_if_in_for_02():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_for(IfAfterIfInForNet2, x)
|
||||
|
||||
|
||||
def test_if_after_if_in_for_03():
|
||||
x = Tensor(2, mstype.int32)
|
||||
control_flow_if_after_if_in_for(IfAfterIfInForNet3, x)
|
||||
|
|
Loading…
Reference in New Issue