forked from mindspore-Ecosystem/mindspore
!14888 add testcases related to if, fix codedex check
From: @huangbingjian Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
1d2e37be88
|
@ -103,8 +103,8 @@ static bool isTraversable(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
|
||||
const SubstitutionPtr &substitution) {
|
||||
static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
|
||||
const SubstitutionPtr &substitution) {
|
||||
auto manager = optimizer->manager();
|
||||
bool is_match = substitution->predicate_(node);
|
||||
if (is_match) {
|
||||
|
@ -126,8 +126,8 @@ static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNod
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
|
||||
std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
|
||||
static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo,
|
||||
bool change, size_t seen) {
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
(*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
|
||||
}
|
||||
|
@ -238,6 +238,23 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
|
|||
return changes;
|
||||
}
|
||||
|
||||
void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
|
||||
const OptimizerPtr &optimizer, size_t space) const {
|
||||
std::stringstream ss;
|
||||
ss << std::endl
|
||||
<< "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
|
||||
<< std::endl;
|
||||
for (size_t i = 0; i < list_.size(); i++) {
|
||||
auto name = list_[i]->name_;
|
||||
ss << std::left << std::setw(space + 4) << name << "\t";
|
||||
for (auto change : status.at(name + std::to_string(i))) {
|
||||
ss << change << " ";
|
||||
}
|
||||
ss << std::endl;
|
||||
}
|
||||
MS_LOG(DEBUG) << ss.str();
|
||||
}
|
||||
|
||||
bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
|
||||
// Add for substitution status counting
|
||||
size_t space = 0;
|
||||
|
@ -282,19 +299,7 @@ bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, con
|
|||
|
||||
// Display the status of each substitution
|
||||
if (optimizer->is_on_debug_) {
|
||||
std::stringstream ss;
|
||||
ss << std::endl
|
||||
<< "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
|
||||
<< std::endl;
|
||||
for (size_t i = 0; i < list_.size(); i++) {
|
||||
auto name = list_[i]->name_;
|
||||
ss << std::left << std::setw(space + 4) << name << "\t";
|
||||
for (auto change : status[name + std::to_string(i)]) {
|
||||
ss << change << " ";
|
||||
}
|
||||
ss << std::endl;
|
||||
}
|
||||
MS_LOG(DEBUG) << ss.str();
|
||||
DisplayStatusOfSubstitution(status, optimizer, space);
|
||||
}
|
||||
return changes;
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
@ -74,6 +75,8 @@ class SubstitutionList {
|
|||
bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
|
||||
bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
|
||||
bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
|
||||
void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
|
||||
const OptimizerPtr &optimizer, size_t space) const;
|
||||
|
||||
std::vector<SubstitutionPtr> list_;
|
||||
// a flag to mark this list of Substitution can only be executed only once
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_signle_if():
|
||||
class SignleIfNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
x += 1
|
||||
if x < y:
|
||||
y += x
|
||||
else:
|
||||
y -= x
|
||||
y += 5
|
||||
return y
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_net = SignleIfNet()
|
||||
net = GradNet(if_net)
|
||||
graph_forward_res = if_net(x, y)
|
||||
graph_backward_res = net(x, y)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_net = SignleIfNet()
|
||||
net = GradNet(if_net)
|
||||
pynative_forward_res = if_net(x, y)
|
||||
pynative_backward_res = net(x, y)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_in_if():
|
||||
class IfInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
if self.param_a > self.param_b:
|
||||
x += 10
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
x += self.param_a
|
||||
return x
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_in_if_net = IfInIfNet()
|
||||
net = GradNet(if_in_if_net)
|
||||
graph_forward_res = if_in_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_in_if_net = IfInIfNet()
|
||||
net = GradNet(if_in_if_net)
|
||||
pynative_forward_res = if_in_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_if():
|
||||
class IfAfterIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
if self.param_a > self.param_b:
|
||||
x += 5
|
||||
self.param_b += 4
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_if_net = IfAfterIfNet()
|
||||
net = GradNet(if_after_if_net)
|
||||
graph_forward_res = if_after_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_if_net = IfAfterIfNet()
|
||||
net = GradNet(if_after_if_net)
|
||||
pynative_forward_res = if_after_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_if_in_if():
|
||||
class IfAfterIfInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
if self.param_a > self.param_b:
|
||||
x += 5
|
||||
if x > self.param_a:
|
||||
self.param_b += 1
|
||||
self.param_b += 3
|
||||
if x < self.param_b:
|
||||
out += self.param_b
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_if_in_if_net = IfAfterIfInIfNet()
|
||||
net = GradNet(if_after_if_in_if_net)
|
||||
graph_forward_res = if_after_if_in_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_if_in_if_net = IfAfterIfInIfNet()
|
||||
net = GradNet(if_after_if_in_if_net)
|
||||
pynative_forward_res = if_after_if_in_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_if_in_for():
|
||||
class IfAfterIfInForNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_b
|
||||
for _ in range(4):
|
||||
if out <= 20:
|
||||
out += self.param_a
|
||||
self.param_b += 3
|
||||
if x < self.param_b:
|
||||
out -= self.param_b
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_if_in_for_net = IfAfterIfInForNet()
|
||||
net = GradNet(if_after_if_in_for_net)
|
||||
graph_forward_res = if_after_if_in_for_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_if_in_for_net = IfAfterIfInForNet()
|
||||
net = GradNet(if_after_if_in_for_net)
|
||||
pynative_forward_res = if_after_if_in_for_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_for_in_if():
|
||||
class IfAfterForInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_a
|
||||
if self.param_a > self.param_b:
|
||||
for _ in range(4):
|
||||
self.param_a += 1
|
||||
self.param_b -= 3
|
||||
self.param_b += 15
|
||||
if x < self.param_b:
|
||||
out -= self.param_b
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_for_in_if_net = IfAfterForInIfNet()
|
||||
net = GradNet(if_after_for_in_if_net)
|
||||
graph_forward_res = if_after_for_in_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_for_in_if_net = IfAfterForInIfNet()
|
||||
net = GradNet(if_after_for_in_if_net)
|
||||
pynative_forward_res = if_after_for_in_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_for_in_while():
|
||||
class IfAfterForInWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(2, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_a
|
||||
while self.param_a > self.param_b:
|
||||
self.param_b += 1
|
||||
for _ in range(4):
|
||||
self.param_a += 3
|
||||
self.param_a -= 40
|
||||
if x > self.param_a:
|
||||
out += self.param_a * 10
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_for_in_while_net = IfAfterForInWhileNet()
|
||||
net = GradNet(if_after_for_in_while_net)
|
||||
graph_forward_res = if_after_for_in_while_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_for_in_while_net = IfAfterForInWhileNet()
|
||||
net = GradNet(if_after_for_in_while_net)
|
||||
pynative_forward_res = if_after_for_in_while_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_if_after_for_in_for():
|
||||
class IfAfterForInForNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(2, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = x + self.param_a
|
||||
for _ in range(0, 10):
|
||||
x *= 2
|
||||
for _ in range(0, 5):
|
||||
self.param_a += 1
|
||||
x += self.param_b
|
||||
if self.param_a > self.param_b:
|
||||
out += x
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_after_for_in_for_net = IfAfterForInForNet()
|
||||
net = GradNet(if_after_for_in_for_net)
|
||||
graph_forward_res = if_after_for_in_for_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_after_for_in_for_net = IfAfterForInForNet()
|
||||
net = GradNet(if_after_for_in_for_net)
|
||||
pynative_forward_res = if_after_for_in_for_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_for_after_if():
|
||||
class ForAfterIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = self.param_a
|
||||
if self.param_a > self.param_b:
|
||||
x += 3
|
||||
self.param_b += 1
|
||||
for _ in range(0, 5):
|
||||
x += self.param_b
|
||||
out *= x
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(2, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
for_after_if_net = ForAfterIfNet()
|
||||
net = GradNet(for_after_if_net)
|
||||
graph_forward_res = for_after_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
for_after_if_net = ForAfterIfNet()
|
||||
net = GradNet(for_after_if_net)
|
||||
pynative_forward_res = for_after_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_for_after_if_in_if():
|
||||
class ForAfterIfInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = self.param_a
|
||||
if self.param_a > self.param_b:
|
||||
x += 3
|
||||
if x > self.param_a:
|
||||
self.param_b += 4
|
||||
x += self.param_a
|
||||
self.param_b += 2
|
||||
for _ in range(0, 5):
|
||||
x += self.param_b
|
||||
out *= x
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(5, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
for_after_if_in_if_net = ForAfterIfInIfNet()
|
||||
net = GradNet(for_after_if_in_if_net)
|
||||
graph_forward_res = for_after_if_in_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
for_after_if_in_if_net = ForAfterIfInIfNet()
|
||||
net = GradNet(for_after_if_in_if_net)
|
||||
pynative_forward_res = for_after_if_in_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
|
@ -0,0 +1,68 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_for_after_for_in_if():
|
||||
class ForAfterForInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
|
||||
self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
|
||||
|
||||
def construct(self, x):
|
||||
out = self.param_a
|
||||
if self.param_a > self.param_b:
|
||||
for _ in range(0, 4):
|
||||
self.param_a += 1
|
||||
self.param_b -= 3
|
||||
self.param_b += 10
|
||||
for _ in range(0, 5):
|
||||
x += self.param_b
|
||||
out *= x
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor(5, mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
for_after_for_in_if_net = ForAfterForInIfNet()
|
||||
net = GradNet(for_after_for_in_if_net)
|
||||
graph_forward_res = for_after_for_in_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
for_after_for_in_if_net = ForAfterForInIfNet()
|
||||
net = GradNet(for_after_for_in_if_net)
|
||||
pynative_forward_res = for_after_for_in_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
Loading…
Reference in New Issue