forked from mindspore-Ecosystem/mindspore
!15784 if Switch/SwitchLayer, do not replace Load or remove UpdateState
From: @huangbingjian Reviewed-by: @zh_qh Signed-off-by: @zh_qh
commit 31df54f424
@@ -880,7 +880,7 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
     } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
       // To find J user.
       auto j_user = GetJUser(node_user_map, cnode, index);
-      primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
+      (void)primal_j_pair.emplace_back(std::pair<CNodePtr, CNodePtr>(nullptr, j_user));
     }
   }
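
The only change in this hunk is the (void) cast. Since C++17, std::vector::emplace_back returns a reference to the inserted element, and a silently ignored return value can trip static checks; the cast documents that the result is discarded on purpose. A minimal standalone illustration (plain C++, not MindSpore code):

#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<int *, int>> pairs;
  // emplace_back returns a reference to the new element in C++17;
  // the (void) cast marks the return value as intentionally discarded.
  (void)pairs.emplace_back(nullptr, 42);
  return pairs.size() == 1 ? 0 : 1;
}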
@@ -85,6 +85,11 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
     if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
       return false;
     }
+    // if Call/Switch/SwitchLayer, do not replace load.
+    if (IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) ||
+        IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
+      return true;
+    }
     auto cnode = node->cast<CNodePtr>();
     auto &inputs = cnode->inputs();
     return std::any_of(inputs.begin(), inputs.end(),
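
The hunk is cut off at the std::any_of call, so the recursive step below is an assumption taken from that truncated line. A self-contained sketch of the resulting predicate, using a hypothetical IsSpecialNode name and plain strings in place of MindSpore's IR types: Load never counts as special, Call/Switch/SwitchLayer always do, and otherwise a node is special if any of its inputs is.

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for the real predicate; node ops are modeled as
// plain strings instead of MindSpore primitives.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};

bool IsSpecialNode(const std::shared_ptr<Node> &node) {
  if (node->op == "Load") {
    return false;  // Load itself never blocks replacement.
  }
  // Call/Switch/SwitchLayer: do not replace Load across them.
  if (node->op == "Call" || node->op == "Switch" || node->op == "SwitchLayer") {
    return true;
  }
  // Assumed from the truncated std::any_of call: recurse into the inputs.
  return std::any_of(node->inputs.begin(), node->inputs.end(), IsSpecialNode);
}

int main() {
  auto sw = std::make_shared<Node>(Node{"Switch", {}});
  auto user = std::make_shared<Node>(Node{"TupleGetItem", {sw}});
  return IsSpecialNode(user) ? 0 : 1;  // user is special via its Switch input.
}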
@@ -28,6 +28,7 @@ namespace {
 // data = Load(input, attach)
 // data = Depend(input, attach)
 // monad = UpdateState(input, attach)
+constexpr size_t kFirstInputIndex = 0;
 constexpr size_t kInputIndex = 1;
 constexpr size_t kAttachIndex = 2;
 constexpr size_t kMakeTupleSize = 3;
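
A note on the indices, assuming MindSpore's usual CNode layout: inputs()[0] holds the operator itself, so the new kFirstInputIndex names the slot that contains a Switch/SwitchLayer CNode when the node is an indirect call; kInputIndex and kAttachIndex match the (input, attach) operands spelled out in the comments above; and kMakeTupleSize = 3 fits a MakeTuple node carrying the operator plus two elements.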
@@ -120,6 +121,13 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
       return nullptr;
     }
   }
+  // Skip Call/Switch/SwitchLayer.
+  auto first_input_node = cnode->input(kFirstInputIndex);
+  if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
+      IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
+    return nullptr;
+  }
+
   // Remove UpdateState by replace it with its input monad.
   return update_state->input(kInputIndex);
 }
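
The new guard inspects the operator slot of the node being checked: for a call produced by control flow, inputs()[kFirstInputIndex] is the Switch/SwitchLayer (or Call) CNode itself, and such a call can hide side effects behind whichever branch is taken at runtime. Returning nullptr here keeps the UpdateState in place rather than rewiring its users to the input monad, which is exactly the "do not remove UpdateState" half of the commit title.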
@@ -246,7 +246,7 @@ void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std:
      << std::endl;
   for (size_t i = 0; i < list_.size(); i++) {
     auto name = list_[i]->name_;
-    ss << std::left << std::setw(space + 4) << name << "\t";
+    ss << std::left << std::setw(SizeToInt(space) + 4) << name << "\t";
     for (auto change : status.at(name + std::to_string(i))) {
       ss << change << " ";
     }
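
std::setw takes an int, while space here is presumably a size_t; SizeToInt is MindSpore's checked narrowing helper, so the fix makes the unsigned-to-signed conversion explicit. A rough standalone sketch of the idea (the real helper lives in MindSpore's utility headers; this stand-in only illustrates the intent):

#include <cstddef>
#include <iomanip>
#include <iostream>
#include <limits>
#include <stdexcept>

// Stand-in for SizeToInt: a checked size_t -> int narrowing.
int SizeToInt(std::size_t v) {
  if (v > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
    throw std::out_of_range("size_t value does not fit in int");
  }
  return static_cast<int>(v);
}

int main() {
  std::size_t space = 12;
  // std::setw wants an int; converting explicitly avoids the implicit
  // unsigned -> signed narrowing that static analysis would flag.
  std::cout << std::left << std::setw(SizeToInt(space) + 4) << "name" << '\n';
  return 0;
}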
@@ -393,7 +393,7 @@ class TensorDataImpl : public TensorData {
         pos++;
       }
       size_t len = pos - index;
-      std::string space(max_width - len, ' ');
+      std::string space(max_width - SizeToInt(len), ' ');
       str = str.replace(index, len, space);
       index = str.find('#', index);
     }
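
Same family of fix as the previous hunk: len is a size_t, so if max_width is a signed int (as it appears to be here), max_width - len would be computed in unsigned arithmetic and could wrap to a huge count whenever len exceeded max_width; converting len first keeps the subtraction in signed int math and satisfies the same signedness check.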
@@ -170,3 +170,52 @@ def test_for_in_if_03():
 
     assert graph_forward_res == pynative_forward_res
     assert graph_backward_res == pynative_backward_res
+
+
+def test_for_in_if_04():
+    class ForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            x = self.func(x)
+            out *= x
+            return out
+
+        def func(self, x):
+            if self.param_a > self.param_b:
+                for _ in range(0, 4):
+                    self.param_a += 1
+                    self.param_b -= 3
+            self.param_b += 10
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    graph_forward_res = for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    pynative_forward_res = for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
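
The new test mutates param_a and param_b inside a for loop nested under an if, in a helper method called from construct, and then requires graph mode and PyNative mode to agree on both the forward value and the grad_all gradient. In graph mode the if lowers to a Switch, so the parameter Loads and the UpdateState ordering cross exactly the control-flow boundary that the passes above now refuse to optimize across.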
@@ -74,6 +74,25 @@ class IfAfterIfNet2(nn.Cell):
         return y
 
 
+class IfAfterIfNet3(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+        self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+    def construct(self, x, y):
+        out = x * y + self.func(self.param_b)
+        if self.param_a > self.param_b:
+            out += 5
+        return out
+
+    def func(self, x):
+        if self.param_a > self.param_b:
+            x += 5
+        self.param_b += 4
+        return x
+
+
 class GradNet(nn.Cell):
     def __init__(self, net):
         super(GradNet, self).__init__()
@@ -118,3 +137,9 @@ def test_if_after_if_02():
     x = Tensor(2, mstype.int32)
     y = Tensor(5, mstype.int32)
     control_flow_if_after_if(IfAfterIfNet2, x, y)
+
+
+def test_if_after_if_03():
+    x = Tensor(2, mstype.int32)
+    y = Tensor(5, mstype.int32)
+    control_flow_if_after_if(IfAfterIfNet3, x, y)
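
Like the previous case, IfAfterIfNet3 routes a parameter update (self.param_b += 4 in func) through an if whose result then feeds a second if, so the graph-mode lowering again places Load/UpdateState nodes around a Switch. The shared control_flow_if_after_if helper presumably runs the net in both graph and PyNative modes and compares forward and backward results, as the sibling tests above do.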