1. If Switch/SwitchLayer, do not replace Load or remove UpdateState; 2. add control flow test cases; 3. fix CodeDEX issues

huangbingjian 2021-04-27 17:24:55 +08:00
parent 3bd51c88f4
commit 2a85af5d83
7 changed files with 90 additions and 3 deletions

View File

@@ -880,7 +880,7 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
     } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
       // To find J user.
       auto j_user = GetJUser(node_user_map, cnode, index);
-      primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
+      (void)primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
     }
   }

View File

@@ -85,6 +85,11 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
     if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
      return false;
    }
+    // if Call/Switch/SwitchLayer, do not replace load.
+    if (IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) ||
+        IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
+      return true;
+    }
     auto cnode = node->cast<CNodePtr>();
     auto &inputs = cnode->inputs();
     return std::any_of(inputs.begin(), inputs.end(),
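The guard above treats Call/Switch/SwitchLayer users as barriers when grouping Load nodes: a call or branch may itself write the parameter, so reads on either side of it must not be collapsed into one. A hypothetical MindSpore net illustrating the scenario this protects (written for this note, not part of the commit):

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype

class ReadWriteRead(nn.Cell):
    def __init__(self):
        super().__init__()
        self.p = Parameter(Tensor(1, mstype.int32), name='p')

    def construct(self, x):
        before = self.p + x       # first Load of p
        x = self.maybe_update(x)  # becomes a Call; the if inside becomes a Switch
        after = self.p + x        # second Load of p; must observe the write
        return before, after

    def maybe_update(self, x):
        if x > 0:
            self.p += 1           # side effect hidden behind control flow
        return x

context.set_context(mode=context.GRAPH_MODE)
print(ReadWriteRead()(Tensor(3, mstype.int32)))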

View File

@@ -28,6 +28,7 @@ namespace {
 // data = Load(input, attach)
 // data = Depend(input, attach)
 // monad = UpdateState(input, attach)
+constexpr size_t kFirstInputIndex = 0;
 constexpr size_t kInputIndex = 1;
 constexpr size_t kAttachIndex = 2;
 constexpr size_t kMakeTupleSize = 3;
@@ -120,6 +121,13 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
       return nullptr;
     }
   }
+  // Skip Call/Switch/SwitchLayer.
+  auto first_input_node = cnode->input(kFirstInputIndex);
+  if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
+      IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
+    return nullptr;
+  }
   // Remove UpdateState by replace it with its input monad.
   return update_state->input(kInputIndex);
 }
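Same rationale here: EliminateUpdateStateForPureNode now returns early when the node it is inspecting is a Call/Switch/SwitchLayer, since such a node can carry side effects even though its output looks pure; removing the UpdateState would let later passes reorder or drop the hidden write. A hypothetical repro in the spirit of the new tests (illustrative only):

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype

class AccumulateInBranch(nn.Cell):
    def __init__(self):
        super().__init__()
        self.acc = Parameter(Tensor(0, mstype.int32), name='acc')

    def construct(self, x):
        if x > 0:
            self.acc += x   # write guarded by a Switch in the IR
        return self.acc     # read must stay ordered after the write

context.set_context(mode=context.GRAPH_MODE)
print(AccumulateInBranch()(Tensor(5, mstype.int32)))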

View File

@@ -246,7 +246,7 @@ void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std:
      << std::endl;
   for (size_t i = 0; i < list_.size(); i++) {
     auto name = list_[i]->name_;
-    ss << std::left << std::setw(space + 4) << name << "\t";
+    ss << std::left << std::setw(SizeToInt(space) + 4) << name << "\t";
     for (auto change : status.at(name + std::to_string(i))) {
       ss << change << " ";
     }

View File

@@ -393,7 +393,7 @@ class TensorDataImpl : public TensorData {
         pos++;
       }
       size_t len = pos - index;
-      std::string space(max_width - len, ' ');
+      std::string space(max_width - SizeToInt(len), ' ');
       str = str.replace(index, len, space);
       index = str.find('#', index);
     }

View File

@@ -170,3 +170,52 @@ def test_for_in_if_03():
     assert graph_forward_res == pynative_forward_res
     assert graph_backward_res == pynative_backward_res
+
+
+def test_for_in_if_04():
+    class ForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            x = self.func(x)
+            out *= x
+            return out
+
+        def func(self, x):
+            if self.param_a > self.param_b:
+                for _ in range(0, 4):
+                    self.param_a += 1
+                    self.param_b -= 3
+            self.param_b += 10
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    graph_forward_res = for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    pynative_forward_res = for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
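The test uses names that the hunk does not show; they come from the file's preamble. A typical preamble for these control-flow suites that would supply them (an assumption about the surrounding file, not content of this commit):

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C

grad_all = C.GradOperation(get_all=True)  # gradients w.r.t. all network inputs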

View File

@@ -74,6 +74,25 @@ class IfAfterIfNet2(nn.Cell):
         return y
 
 
+class IfAfterIfNet3(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+        self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+    def construct(self, x, y):
+        out = x * y + self.func(self.param_b)
+        if self.param_a > self.param_b:
+            out += 5
+        return out
+
+    def func(self, x):
+        if self.param_a > self.param_b:
+            x += 5
+        self.param_b += 4
+        return x
+
+
 class GradNet(nn.Cell):
     def __init__(self, net):
         super(GradNet, self).__init__()
@@ -118,3 +137,9 @@ def test_if_after_if_02():
     x = Tensor(2, mstype.int32)
     y = Tensor(5, mstype.int32)
     control_flow_if_after_if(IfAfterIfNet2, x, y)
+
+
+def test_if_after_if_03():
+    x = Tensor(2, mstype.int32)
+    y = Tensor(5, mstype.int32)
+    control_flow_if_after_if(IfAfterIfNet3, x, y)
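control_flow_if_after_if is defined earlier in this file, outside the hunk. Judging from how test_for_in_if_04 above compares the two modes, it presumably runs the net and its GradNet wrapper in both graph and PyNative mode and asserts the results agree; a sketch under that assumption (the real helper may differ):

def control_flow_if_after_if(net_class, x, y):
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    net = net_class()
    grad_net = GradNet(net)
    graph_forward_res = net(x, y)
    graph_backward_res = grad_net(x, y)

    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE)
    net = net_class()
    grad_net = GradNet(net)
    pynative_forward_res = net(x, y)
    pynative_backward_res = grad_net(x, y)

    assert graph_forward_res == pynative_forward_res
    assert graph_backward_res == pynative_backward_res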