forked from mindspore-Ecosystem/mindspore
!1385 support for multi-nested switch
Merge pull request !1385 from amongo/SupportMultiSwitch
commit 1a4abefa9a
@@ -52,13 +52,17 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
   // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
   // converted to switch guarded.
   std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
-    {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}}, {prim::kPrimStateSetItem, {1}},
-     {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}}, {prim::kPrimReduceSum, {2}},
-     {prim::kPrimReduceMean, {2}}, {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
-     {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}}, {prim::kPrimGatherV2, {3}},
-     {prim::kPrimReshape, {2}}, {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
-     {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}}, {prim::kPrimImageSummary, {1}},
-     {prim::kPrimScalarSummary, {1}}, {prim::kPrimHistogramSummary, {1}}});
+    {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
+     {prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
+     {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
+     {prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
+     {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
+     {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
+     {prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}},
+     {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
+     {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}},
+     {prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}},
+     {prim::kPrimHistogramSummary, {1}}});
   for (auto &item : white_list) {
     auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
       return IsPrimitiveCNode(node, item.first) && idx == index;
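The functional change in this hunk is the new {prim::kPrimTupleGetItem, {2}} entry; the rest is a reflow of the initializer. Input 2 of a TupleGetItem CNode is its constant element index, which, like ReduceSum's axis, must not be rewritten into a switch-guarded node. A minimal Python sketch (my own illustration, not part of this commit) of the kind of operand the white list protects:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)

class BranchReduce(nn.Cell):
    def __init__(self):
        super(BranchReduce, self).__init__()
        self.reduce_sum = P.ReduceSum()

    def construct(self, a, b, x):
        if a > b:                          # compiled into a Switch over two branch graphs
            return self.reduce_sum(x, 1)   # axis operand (input index 2) must stay a constant
        return self.reduce_sum(x, 0)

net = BranchReduce()
out = net(Tensor(2.0, dtype=ms.float32), Tensor(1.0, dtype=ms.float32),
          Tensor(np.ones((2, 3)), dtype=ms.float32))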
@@ -80,7 +84,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
 using NodeInputReplMap = std::unordered_map<std::pair<AnfNodePtr, size_t>, AnfNodePtr, PairHasher>;
 // replace the nodes which should be changed
 void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed,
-                          std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs) {
+                          std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs,
+                          const FuncGraphPtr &func_graph) {
   for (auto &node_pair : nodes_changed) {
     CNodePtr old_node = node_pair.first;
     CNodePtr new_node = node_pair.second;
@@ -99,9 +104,11 @@ void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::p
   }

   for (auto &item : repl_node) {
-    if (!manager->Replace(item.first, item.second)) {
-      MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString()
-                        << " to new: " << item.second->DebugString();
+    if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) {
+      func_graph->set_output(item.second->cast<CNodePtr>()->input(1));
+    } else if (!manager->Replace(item.first, item.second)) {
+      MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2)
+                        << " to new: " << item.second->DebugString(2);
     }
   }
 }
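The new func_graph parameter exists for the Return case above: when the replacement value is itself a Return node, splicing it in through manager->Replace would leave a Return nested inside the graph body, so the pass instead rewires the graph's output to the Return's first input (its return value) — presumably the situation that arises once a whole nested branch graph has been flattened. The switch from DebugString() to DebugString(2) also deepens the node dump in the failure log by one recursion level.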
@@ -154,7 +161,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
       nodes_changed.emplace_back(node->cast<CNodePtr>(), new_node);
     }
   }
-  RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs);
+  RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph);
   return graph;
 }

@@ -508,11 +515,12 @@ bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const Abstrac

 AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                               const AbstractBasePtr &true_graph_output_abs,
-                              const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond) {
+                              const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph,
+                              const AnfNodePtr &cond) {
   MS_EXCEPTION_IF_NULL(true_graph_output_abs);
   MS_EXCEPTION_IF_NULL(false_graph_output_abs);
   MS_EXCEPTION_IF_NULL(cond);
-  MS_EXCEPTION_IF_NULL(cond->func_graph());
+  MS_EXCEPTION_IF_NULL(switch_graph);
   auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
   MS_EXCEPTION_IF_NULL(PrimMerge);

@@ -520,10 +528,10 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
     std::vector<AnfNodePtr> merge_nodes;
     merge_nodes.push_back(NewValueNode(PrimMerge));
     std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node};
-    merge_nodes.push_back(cond->func_graph()->NewCNode(make_tuple_nodes));
+    merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes));
     std::vector<AnfNodePtr> tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem),
-                                                cond->func_graph()->NewCNode(merge_nodes), NewValueNode(MakeValue(0))};
-    return cond->func_graph()->NewCNode(tuple_getitem_nodes);
+                                                switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))};
+    return switch_graph->NewCNode(tuple_getitem_nodes);
   } else {
     abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast<abstract::AbstractTuplePtr>();
     abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast<abstract::AbstractTuplePtr>();
@@ -533,27 +541,29 @@ AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodeP
     for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) {
       std::vector<AnfNodePtr> true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node,
                                                  NewValueNode(MakeValue(SizeToInt(i)))};
-      auto true_node = cond->func_graph()->NewCNode(true_getitem_nodes);
+      auto true_node = switch_graph->NewCNode(true_getitem_nodes);
       std::vector<AnfNodePtr> false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node,
                                                   NewValueNode(MakeValue(SizeToInt(i)))};
-      auto false_node = cond->func_graph()->NewCNode(false_getitem_nodes);
+      auto false_node = switch_graph->NewCNode(false_getitem_nodes);

       auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i],
-                                           false_branch_tuple->elements()[i], cond);
+                                           false_branch_tuple->elements()[i], switch_graph, cond);
       make_tuple_nodes.push_back(merge_node);
     }
-    return cond->func_graph()->NewCNode(make_tuple_nodes);
+    return switch_graph->NewCNode(make_tuple_nodes);
   }
 }

 AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                                   const AbstractBasePtr &true_graph_output_abs,
-                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond) {
+                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+                                  const FuncGraphPtr &switch_graph) {
   if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) {
     MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString()
                       << ", false:" << false_graph_output_abs->ToString();
   }
-  return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, cond);
+  return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs,
+                            switch_graph, cond);
 }
 }  // namespace internal
 }  // namespace irpass

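Threading switch_graph through GenerateMergeNodes and TransformMergeBranches means every Merge, MakeTuple, and TupleGetItem node is now created on the graph that actually contains the switch being replaced, instead of on cond->func_graph(); when switches nest, the condition node can belong to an inner graph, so the old choice presumably placed the merge nodes in the wrong graph. For a non-tuple output, the generated structure is, in the same pattern notation the header below uses: {prim::kPrimTupleGetItem, {Merge, {prim::kPrimMakeTuple, true_output, false_output}}, 0} — element 0 of Merge's result is the merged value.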
@@ -168,7 +168,8 @@ FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const
 FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
 AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
                                   const AbstractBasePtr &true_graph_output_abs,
-                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond);
+                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+                                  const FuncGraphPtr &func_graph);
 }  // namespace internal

 // {{prim::kPrimSwitch, X, G1, G2}, Xs}
@@ -190,6 +191,20 @@ class ConvertSwitchReplacement : public AnfVisitor {
     if (g2_ == nullptr || g1_->output() == nullptr || g2_->output() == nullptr) {
       return nullptr;
     }
+    // for switch replace method, only graphs without graph inside can be replaced
+    for (auto &item : g1_->value_nodes()) {
+      auto value_node = item.first;
+      if (IsValueNode<FuncGraph>(value_node)) {
+        return nullptr;
+      }
+    }
+
+    for (auto &item : g2_->value_nodes()) {
+      auto value_node = item.first;
+      if (IsValueNode<FuncGraph>(value_node)) {
+        return nullptr;
+      }
+    }
+
     auto true_output = g1_->output()->abstract();
     auto false_output = g2_->output()->abstract();
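The added guard returns nullptr whenever either branch graph still contains a FuncGraph-valued node, i.e. whenever this switch has another switch (or any subgraph) nested inside it. A single sweep of the pass therefore converts only innermost switches; outer ones become eligible on a later sweep once their branches have been flattened. This inside-out ordering, combined with the pass-scheduling change in GetControlPhases below, is what makes multiply nested switches convertible.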
@@ -200,8 +215,8 @@ class ConvertSwitchReplacement : public AnfVisitor {
     auto fg = node->func_graph();
     auto cloned_g1 = InlineClone(trans_g1, fg, params);
     auto cloned_g2 = InlineClone(trans_g2, fg, params);

-    return internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_);
+    auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg);
+    return nnode;
   }

   void Visit(const AnfNodePtr &node) override {
@@ -162,7 +162,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
 }

 OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
-  opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_});
+  opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true);
   OptPassGroupMap map({
     {"control_group", control_group},
     {"renormalize", opt::OptPassConfig::Renormalize()},
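The only change here is the extra true passed to OptPassConfig. Its constructor is not shown in this diff; in MindSpore sources of this period the second parameter is is_once, which runs the substitution list a single sweep per phase invocation instead of iterating to a fixed point — plausibly so the renormalize step below can re-infer abstracts between sweeps before the next outer switch is converted. Treat that reading as an assumption rather than something this hunk establishes.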
@@ -346,7 +346,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     if ((*value == *kAnyValue)) {
       auto value_desc = abs_base->value_desc();
       MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
-                              << " for python primitive.";
+                              << " for python primitive." << abs_base->ToString();
     }
     MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
                             << value->ToString();
@@ -24,6 +24,8 @@ from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
+from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common import ms_function

 context.set_context(mode=context.GRAPH_MODE)

@@ -371,7 +373,8 @@ def test_switch_layer():
     class Layer1(nn.Cell):
         def __init__(self):
             super(Layer1, self).__init__()
-            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+            self.z1 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

         def construct(self, x):
             return x * self.z1
@@ -379,7 +382,8 @@ def test_switch_layer():
     class Layer2(nn.Cell):
         def __init__(self):
             super(Layer2, self).__init__()
-            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+            self.z2 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

         def construct(self, x):
             return x * self.z2
@@ -388,7 +392,8 @@ def test_switch_layer():
         def __init__(self):
             super(SwitchLayerCell, self).__init__()
             self.layers = (Layer1(), Layer2())
-            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+            self.z3 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

         def construct(self, index, x):
             ret = F.switch_layer(index, self.layers)(x) * self.z3
@@ -406,7 +411,8 @@ def test_index_to_switch_layer():
     class Layer1(nn.Cell):
         def __init__(self):
             super(Layer1, self).__init__()
-            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+            self.z1 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

         def construct(self, x):
             return x * self.z1
@@ -414,7 +420,8 @@ def test_index_to_switch_layer():
     class Layer2(nn.Cell):
         def __init__(self):
             super(Layer2, self).__init__()
-            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+            self.z2 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

         def construct(self, x):
             return x * self.z2
@@ -423,7 +430,8 @@ def test_index_to_switch_layer():
         def __init__(self):
             super(SwitchLayerCell, self).__init__()
             self.layers = (Layer1(), Layer2())
-            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+            self.z3 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

         def construct(self, index, x):
             ret = self.layers[index](x) * self.z3
@@ -444,3 +452,69 @@ def test_control_depend_check():
     depend = P.ControlDepend(2)
     with pytest.raises(TypeError) as e:
         depend = P.ControlDepend((2,))
+
+
+def test_if_nested_compile():
+    class Net(nn.Cell):
+        def __init__(self, auto_prefix=True):
+            super().__init__(auto_prefix=auto_prefix)
+            self.squre = P.Square()
+            self.value = Tensor(3, dtype=ms.float32)
+
+        def construct(self, x, y):
+            res = self.value
+            if x <= y:
+                res = x + res
+                res = y + res
+            else:
+                if x == y:
+                    res = self.squre(self.value * y)
+                else:
+                    res = self.squre(self.value)
+            return res
+    x = Tensor(1.0, dtype=ms.float32)
+    y = Tensor(2.0, dtype=ms.float32)
+    net = Net()
+    net(x, y)
+
+
+def test_if_inside_for():
+    class Net(nn.Cell):
+        def __init__(self, auto_prefix=True):
+            super().__init__(auto_prefix=auto_prefix)
+            self.squre = P.Square()
+            self.value = Tensor(3, dtype=ms.float32)
+            self.count = 4
+
+        def construct(self, x, y):
+            res = 0
+            for i in range(self.count):
+                if i == x:
+                    res = res + x
+                else:
+                    res = res - y
+            return res
+    c1 = Tensor(1, dtype=ms.int32)
+    c2 = Tensor(1, dtype=ms.int32)
+    net = Net()
+    out = net(c1, c2)
+
+
+def test_while_in_while():
+    c1 = Tensor(1, dtype=ms.int32)
+    c2 = Tensor(2, dtype=ms.int32)
+    c3 = Tensor(3, dtype=ms.int32)
+    c4 = Tensor(4, dtype=ms.int32)
+    @ms_function
+    def while_in_while(x, y, z, u):
+        out = c4
+        while x < y:
+            z = c4 + c4
+            while z < y:
+                z = z + 1
+            out = out + 1
+            x = x + 1
+
+        out = out + 3
+        return out
+    while_in_while(c1, c2, c3, c4)
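Together these three new cases cover the nesting shapes named in the commit title: an if nested inside an else branch, an if inside a for loop, and a while inside a while. Assuming the repository's usual pytest layout (the file's path is not visible in this view), each can be exercised on its own, e.g. pytest -k test_if_nested_compile.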
@@ -0,0 +1,81 @@
+import numpy as np
+
+import mindspore
+from mindspore import nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+
+context.set_context(mode=context.GRAPH_MODE)
+
+
+class Layer1(nn.Cell):
+    def __init__(self):
+        super(Layer1, self).__init__()
+        self.net = nn.Conv2d(3, 1, 3, pad_mode='same')
+        self.pad = nn.Pad(
+            paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
+
+    def construct(self, x):
+        y = self.net(x)
+        return self.pad(y)
+
+
+class Layer2(nn.Cell):
+    def __init__(self):
+        super(Layer2, self).__init__()
+        self.net = nn.Conv2d(3, 1, 7, pad_mode='same')
+        self.pad = nn.Pad(
+            paddings=((0, 0), (0, 2), (0, 0), (0, 0)), mode="CONSTANT")
+
+    def construct(self, x):
+        y = self.net(x)
+        return self.pad(y)
+
+
+class Layer3(nn.Cell):
+    def __init__(self):
+        super(Layer3, self).__init__()
+        self.net = nn.Conv2d(3, 3, 3, pad_mode='same')
+
+    def construct(self, x):
+        return self.net(x)
+
+
+class SwitchNet(nn.Cell):
+    def __init__(self):
+        super(SwitchNet, self).__init__()
+        self.layer1 = Layer1()
+        self.layer2 = Layer2()
+        self.layer3 = Layer3()
+        self.layers = (self.layer1, self.layer2, self.layer3)
+        self.fill = P.Fill()
+
+    def construct(self, x, index):
+        y = self.layers[index](x)
+        return y
+
+
+class MySwitchNet(nn.Cell):
+    def __init__(self):
+        super(MySwitchNet, self).__init__()
+        self.layer1 = Layer1()
+        self.layer2 = Layer2()
+        self.layer3 = Layer3()
+        self.layers = (self.layer1, self.layer2, self.layer3)
+        self.fill = P.Fill()
+
+    def construct(self, x, index):
+        y = self.layers[0](x)
+        for i in range(len(self.layers)):
+            if i == index:
+                y = self.layers[i](x)
+        return y
+
+
+def test_layer_switch():
+    net = MySwitchNet()
+    x = Tensor(np.ones((3, 3, 24, 24)), mindspore.float32)
+    index = Tensor(0, dtype=mindspore.int32)
+    y = net(x, index)
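In this added file, MySwitchNet is the interesting case: an if inside a for over a tuple of Cells, i.e. an index-driven layer selection of the kind the nested-switch support targets, while SwitchNet takes the direct layers[index](x) route; test_layer_switch only asserts that the nested form compiles and runs for index 0.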