!1385 support for multi nest switch

Merge pull request !1385 from amongo/SupportMultiSwitch
This commit is contained in:
mindspore-ci-bot 2020-05-28 16:21:10 +08:00 committed by Gitee
commit 1a4abefa9a
6 changed files with 214 additions and 34 deletions
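Scope of the change: the switch-replacement pass learns to handle nested control flow (an if inside another branch, an if inside a for, a while inside a while). Branch graphs that still contain inner graphs are skipped by the replacement, merge nodes are now built in the graph under transformation rather than in the condition node's graph, and kPrimTupleGetItem joins the conversion white list. A minimal sketch of the newly supported pattern, condensed from the tests added below (the class name is illustrative):

    import mindspore as ms
    from mindspore import Tensor, context, nn
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE)


    class NestedIfNet(nn.Cell):
        # Condensed from the added test_if_nested_compile: the else branch
        # contains a second if/else, i.e. a switch nested inside a switch.
        def __init__(self):
            super().__init__()
            self.square = P.Square()
            self.value = Tensor(3, dtype=ms.float32)

        def construct(self, x, y):
            res = self.value
            if x <= y:
                res = x + res
            else:
                if x == y:
                    res = self.square(self.value * y)
                else:
                    res = self.square(self.value)
            return res


    net = NestedIfNet()
    net(Tensor(1.0, dtype=ms.float32), Tensor(2.0, dtype=ms.float32))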


@@ -52,13 +52,17 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
// converted to switch guarded.
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
- {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}}, {prim::kPrimStateSetItem, {1}},
- {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}}, {prim::kPrimReduceSum, {2}},
- {prim::kPrimReduceMean, {2}}, {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
- {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}}, {prim::kPrimGatherV2, {3}},
- {prim::kPrimReshape, {2}}, {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
- {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}}, {prim::kPrimImageSummary, {1}},
- {prim::kPrimScalarSummary, {1}}, {prim::kPrimHistogramSummary, {1}}});
+ {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
+ {prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
+ {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
+ {prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
+ {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
+ {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
+ {prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}},
+ {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
+ {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}},
+ {prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}},
+ {prim::kPrimHistogramSummary, {1}}});
for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
return IsPrimitiveCNode(node, item.first) && idx == index;
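The white list above pins down input positions that hold attribute-like constants (axes, reduction flags, item indices): those operands must be left as-is rather than rerouted through switch-guarded nodes, and this commit adds input 2 of kPrimTupleGetItem (the item index) to the list. A hedged Python illustration of why such positions are exempt (input 0 of a CNode is the primitive itself, so the axis of ReduceSum sits at input 2, matching its {2} entry):

    # Illustration only: under branch conversion, data inputs may be wrapped
    # in switch-guarded nodes, but the axis operand of ReduceSum (CNode
    # input 2) must stay a compile-time constant.
    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn
    from mindspore.ops import operations as P


    class ReduceBranch(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reduce_sum = P.ReduceSum()

        def construct(self, cond, y):
            if cond:  # this branch is subject to switch replacement
                y = self.reduce_sum(y, 1)  # axis=1 stays a constant input
            return y


    net = ReduceBranch()
    net(Tensor(True), Tensor(np.ones((2, 3)), ms.float32))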
@@ -80,7 +84,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
using NodeInputReplMap = std::unordered_map<std::pair<AnfNodePtr, size_t>, AnfNodePtr, PairHasher>;
// replace the nodes which should be changed
void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed,
- std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs) {
+ std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs,
+ const FuncGraphPtr &func_graph) {
for (auto &node_pair : nodes_changed) {
CNodePtr old_node = node_pair.first;
CNodePtr new_node = node_pair.second;
@@ -99,9 +104,11 @@ void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::p
}
for (auto &item : repl_node) {
- if (!manager->Replace(item.first, item.second)) {
- MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString()
- << " to new: " << item.second->DebugString();
+ if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) {
+ func_graph->set_output(item.second->cast<CNodePtr>()->input(1));
+ } else if (!manager->Replace(item.first, item.second)) {
+ MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2)
+ << " to new: " << item.second->DebugString(2);
}
}
}
@@ -154,7 +161,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
nodes_changed.emplace_back(node->cast<CNodePtr>(), new_node);
}
}
- RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs);
+ RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph);
return graph;
}
@@ -508,11 +515,12 @@ bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const Abstrac
AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
const AbstractBasePtr &true_graph_output_abs,
- const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond) {
+ const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph,
+ const AnfNodePtr &cond) {
MS_EXCEPTION_IF_NULL(true_graph_output_abs);
MS_EXCEPTION_IF_NULL(false_graph_output_abs);
MS_EXCEPTION_IF_NULL(cond);
- MS_EXCEPTION_IF_NULL(cond->func_graph());
+ MS_EXCEPTION_IF_NULL(switch_graph);
auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(PrimMerge);
@@ -520,10 +528,10 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
std::vector<AnfNodePtr> merge_nodes;
merge_nodes.push_back(NewValueNode(PrimMerge));
std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node};
- merge_nodes.push_back(cond->func_graph()->NewCNode(make_tuple_nodes));
+ merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes));
std::vector<AnfNodePtr> tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem),
- cond->func_graph()->NewCNode(merge_nodes), NewValueNode(MakeValue(0))};
- return cond->func_graph()->NewCNode(tuple_getitem_nodes);
+ switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))};
+ return switch_graph->NewCNode(tuple_getitem_nodes);
} else {
abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast<abstract::AbstractTuplePtr>();
abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast<abstract::AbstractTuplePtr>();
@@ -533,27 +541,29 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) {
std::vector<AnfNodePtr> true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node,
NewValueNode(MakeValue(SizeToInt(i)))};
- auto true_node = cond->func_graph()->NewCNode(true_getitem_nodes);
+ auto true_node = switch_graph->NewCNode(true_getitem_nodes);
std::vector<AnfNodePtr> false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node,
NewValueNode(MakeValue(SizeToInt(i)))};
- auto false_node = cond->func_graph()->NewCNode(false_getitem_nodes);
+ auto false_node = switch_graph->NewCNode(false_getitem_nodes);
auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i],
- false_branch_tuple->elements()[i], cond);
+ false_branch_tuple->elements()[i], switch_graph, cond);
make_tuple_nodes.push_back(merge_node);
}
- return cond->func_graph()->NewCNode(make_tuple_nodes);
+ return switch_graph->NewCNode(make_tuple_nodes);
}
}
AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
const AbstractBasePtr &true_graph_output_abs,
- const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond) {
+ const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+ const FuncGraphPtr &switch_graph) {
if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) {
MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString()
<< " false:" << false_graph_output_abs->ToString();
}
- return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, cond);
+ return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs,
+ switch_graph, cond);
}
} // namespace internal
} // namespace irpass
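For reference: GenerateMergeNodes now threads an explicit switch_graph through the recursion instead of reaching for cond->func_graph(), so the Merge/TupleGetItem nodes land in the graph being transformed even when the condition node was built elsewhere, which matters once switches nest. For a non-tuple output, the node pattern it emits corresponds roughly to the functional sketch below (merge is the same mindspore.ops.functional op the C++ fetches via GetPythonOps; tuple outputs recurse element-wise):

    # Rough functional equivalent of the emitted node pattern (a sketch, not
    # the pass itself): Merge consumes a tuple of branch outputs and yields a
    # (value, index) pair, so element 0 is taken, mirroring the TupleGetItem
    # with index 0 built in the C++ above.
    from mindspore.ops import functional as F


    def merge_branches(true_out, false_out):
        merged = F.merge((true_out, false_out))
        return merged[0]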


@@ -168,7 +168,8 @@ FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const
FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
const AbstractBasePtr &true_graph_output_abs,
- const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond);
+ const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+ const FuncGraphPtr &func_graph);
} // namespace internal
// {{prim::kPrimSwitch, X, G1, G2}, Xs}
@@ -190,6 +191,20 @@ class ConvertSwitchReplacement : public AnfVisitor {
if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) {
return nullptr;
}
+ // for switch replace method, only graphs without graph inside can be replaced
+ for (auto &item : g1_->value_nodes()) {
+ auto value_node = item.first;
+ if (IsValueNode<FuncGraph>(value_node)) {
+ return nullptr;
+ }
+ }
+ for (auto &item : g2_->value_nodes()) {
+ auto value_node = item.first;
+ if (IsValueNode<FuncGraph>(value_node)) {
+ return nullptr;
+ }
+ }
auto true_output = g1_->output()->abstract();
auto false_output = g2_->output()->abstract();
@@ -200,8 +215,8 @@ class ConvertSwitchReplacement : public AnfVisitor {
auto fg = node->func_graph();
auto cloned_g1 = InlineClone(trans_g1, fg, params);
auto cloned_g2 = InlineClone(trans_g2, fg, params);
- return internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_);
+ auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
+ return nnode;
}
void Visit(const AnfNodePtr &node) override {


@@ -162,7 +162,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
}
OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
- opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_});
+ opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true);
OptPassGroupMap map({
{"control_group", control_group},
{"renormalize", opt::OptPassConfig::Renormalize()},


@@ -346,7 +346,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
if ((*value == *kAnyValue)) {
auto value_desc = abs_base->value_desc();
MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
- << " for python primitive.";
+ << " for python primitive." << abs_base->ToString();
}
MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
<< value->ToString();


@@ -24,6 +24,8 @@ from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
+ from mindspore.common.parameter import Parameter, ParameterTuple
+ from mindspore.common import ms_function
context.set_context(mode=context.GRAPH_MODE)
@@ -371,7 +373,8 @@ def test_switch_layer():
class Layer1(nn.Cell):
def __init__(self):
super(Layer1, self).__init__()
- self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+ self.z1 = Parameter(
+ Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
def construct(self, x):
return x * self.z1
@@ -379,7 +382,8 @@ def test_switch_layer():
class Layer2(nn.Cell):
def __init__(self):
super(Layer2, self).__init__()
- self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+ self.z2 = Parameter(
+ Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
def construct(self, x):
return x * self.z2
@@ -388,7 +392,8 @@ def test_switch_layer():
def __init__(self):
super(SwitchLayerCell, self).__init__()
self.layers = (Layer1(), Layer2())
- self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+ self.z3 = Parameter(
+ Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
def construct(self, index, x):
ret = F.switch_layer(index, self.layers)(x) * self.z3
@@ -406,7 +411,8 @@ def test_index_to_switch_layer():
class Layer1(nn.Cell):
def __init__(self):
super(Layer1, self).__init__()
- self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+ self.z1 = Parameter(
+ Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
def construct(self, x):
return x * self.z1
@@ -414,7 +420,8 @@ def test_index_to_switch_layer():
class Layer2(nn.Cell):
def __init__(self):
super(Layer2, self).__init__()
- self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+ self.z2 = Parameter(
+ Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
def construct(self, x):
return x * self.z2
@@ -423,7 +430,8 @@ def test_index_to_switch_layer():
def __init__(self):
super(SwitchLayerCell, self).__init__()
self.layers = (Layer1(), Layer2())
- self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+ self.z3 = Parameter(
+ Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
def construct(self, index, x):
ret = self.layers[index](x) * self.z3
@@ -444,3 +452,69 @@ def test_control_depend_check():
    depend = P.ControlDepend(2)
    with pytest.raises(TypeError) as e:
        depend = P.ControlDepend((2,))
+
+
+ def test_if_nested_compile():
+     class Net(nn.Cell):
+         def __init__(self, auto_prefix=True):
+             super().__init__(auto_prefix=auto_prefix)
+             self.squre = P.Square()
+             self.value = Tensor(3, dtype=ms.float32)
+
+         def construct(self, x, y):
+             res = self.value
+             if x <= y:
+                 res = x + res
+                 res = y + res
+             else:
+                 if x == y:
+                     res = self.squre(self.value * y)
+                 else:
+                     res = self.squre(self.value)
+             return res
+
+     x = Tensor(1.0, dtype=ms.float32)
+     y = Tensor(2.0, dtype=ms.float32)
+     net = Net()
+     net(x, y)
+
+
+ def test_if_inside_for():
+     class Net(nn.Cell):
+         def __init__(self, auto_prefix=True):
+             super().__init__(auto_prefix=auto_prefix)
+             self.squre = P.Square()
+             self.value = Tensor(3, dtype=ms.float32)
+             self.count = 4
+
+         def construct(self, x, y):
+             res = 0
+             for i in range(self.count):
+                 if i == x:
+                     res = res + x
+                 else:
+                     res = res - y
+             return res
+
+     c1 = Tensor(1, dtype=ms.int32)
+     c2 = Tensor(1, dtype=ms.int32)
+     net = Net()
+     out = net(c1, c2)
+
+
+ def test_while_in_while():
+     c1 = Tensor(1, dtype=ms.int32)
+     c2 = Tensor(2, dtype=ms.int32)
+     c3 = Tensor(3, dtype=ms.int32)
+     c4 = Tensor(4, dtype=ms.int32)
+
+     @ms_function
+     def while_in_while(x, y, z, u):
+         out = c4
+         while x < y:
+             z = c4 + c4
+             while z < y:
+                 z = z + 1
+             out = out + 1
+             x = x + 1
+         out = out + 3
+         return out
+
+     while_in_while(c1, c2, c3, c4)


@@ -0,0 +1,81 @@
+ import numpy as np
+
+ import mindspore
+ from mindspore import nn
+ from mindspore import Tensor
+ from mindspore import context
+ from mindspore.ops import operations as P
+
+ context.set_context(mode=context.GRAPH_MODE)
+
+
+ class Layer1(nn.Cell):
+     def __init__(self):
+         super(Layer1, self).__init__()
+         self.net = nn.Conv2d(3, 1, 3, pad_mode='same')
+         self.pad = nn.Pad(
+             paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
+
+     def construct(self, x):
+         y = self.net(x)
+         return self.pad(y)
+
+
+ class Layer2(nn.Cell):
+     def __init__(self):
+         super(Layer2, self).__init__()
+         self.net = nn.Conv2d(3, 1, 7, pad_mode='same')
+         self.pad = nn.Pad(
+             paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
+
+     def construct(self, x):
+         y = self.net(x)
+         return self.pad(y)
+
+
+ class Layer3(nn.Cell):
+     def __init__(self):
+         super(Layer3, self).__init__()
+         self.net = nn.Conv2d(3, 3, 3, pad_mode='same')
+
+     def construct(self, x):
+         return self.net(x)
+
+
+ class SwitchNet(nn.Cell):
+     def __init__(self):
+         super(SwitchNet, self).__init__()
+         self.layer1 = Layer1()
+         self.layer2 = Layer2()
+         self.layer3 = Layer3()
+         self.layers = (self.layer1, self.layer2, self.layer3)
+         self.fill = P.Fill()
+
+     def construct(self, x, index):
+         y = self.layers[index](x)
+         return y
+
+
+ class MySwitchNet(nn.Cell):
+     def __init__(self):
+         super(MySwitchNet, self).__init__()
+         self.layer1 = Layer1()
+         self.layer2 = Layer2()
+         self.layer3 = Layer3()
+         self.layers = (self.layer1, self.layer2, self.layer3)
+         self.fill = P.Fill()
+
+     def construct(self, x, index):
+         y = self.layers[0](x)
+         for i in range(len(self.layers)):
+             if i == index:
+                 y = self.layers[i](x)
+         return y
+
+
+ def test_layer_switch():
+     net = MySwitchNet()
+     x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
+     index = Tensor(0, dtype=mindspore.int32)
+     y = net(x, index)
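The committed test drives MySwitchNet, which emulates layer selection with a for/if scan over the layer tuple; the SwitchNet class defined above, which indexes the tuple directly, could be exercised the same way. A hypothetical usage, mirroring the call pattern of test_layer_switch (not part of the committed file):

    net = SwitchNet()
    x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
    index = Tensor(1, dtype=mindspore.int32)
    y = net(x, index)  # selects Layer2 via direct tuple indexing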