forked from mindspore-Ecosystem/mindspore
set abstract of switch
This commit is contained in:
parent
bfdd040728
commit
fa7cb34c8a
|
@ -1696,7 +1696,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
|
|||
}
|
||||
auto partial_cnode = partial->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_cnode);
|
||||
auto graph_node = partial_cnode->input(kCallKernelGraphIndex);
|
||||
auto graph_node = partial_cnode->input(kPartialGraphIndex);
|
||||
MS_EXCEPTION_IF_NULL(graph_node);
|
||||
auto graph_value_node = graph_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(graph_value_node);
|
||||
|
@ -1706,7 +1706,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
|
|||
return child_graph;
|
||||
};
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
|
||||
auto input1 = cnode->input(kCallKernelGraphIndex);
|
||||
auto input1 = cnode->input(kPartialGraphIndex);
|
||||
MS_EXCEPTION_IF_NULL(input1);
|
||||
auto value_node = input1->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
|
@ -1714,11 +1714,10 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
|
|||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
return {kernel_graph->cast<KernelGraphPtr>()};
|
||||
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
||||
return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex),
|
||||
get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
|
||||
return {get_switch_kernel_graph(kSwitchTrueBranchIndex), get_switch_kernel_graph(kSwitchFalseBranchIndex)};
|
||||
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
|
||||
std::vector<KernelGraphPtr> child_graphs;
|
||||
for (size_t idx = kMakeTupleInSwitchLayerIndex; idx < cnode->inputs().size(); idx++) {
|
||||
for (size_t idx = kSwitchLayerBranchesIndex; idx < cnode->inputs().size(); idx++) {
|
||||
auto child_graph = get_switch_kernel_graph(idx);
|
||||
child_graphs.emplace_back(child_graph);
|
||||
}
|
||||
|
|
|
@ -642,7 +642,7 @@ class CallInfoFinder {
|
|||
}
|
||||
|
||||
CallBranch GetCallBranch(const CNodePtr &cnode) {
|
||||
auto input_graph = cnode->input(kCallKernelGraphIndex);
|
||||
auto input_graph = cnode->input(kPartialGraphIndex);
|
||||
MS_EXCEPTION_IF_NULL(input_graph);
|
||||
auto kg = GetValueNode<KernelGraphPtr>(input_graph);
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
|
|
|
@ -826,7 +826,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
|
|||
}
|
||||
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
|
||||
switch_cnode->input(kFirstDataInputIndex)};
|
||||
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
|
||||
for (size_t index = kSwitchTrueBranchIndex; index < switch_cnode->inputs().size(); index++) {
|
||||
auto node = switch_cnode->input(index);
|
||||
// there is real input in call, should put it to true and false branch in switch
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
|
@ -919,7 +919,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|||
MS_EXCEPTION_IF_NULL(switch_layer_cnode);
|
||||
std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
|
||||
switch_layer_cnode->input(kFirstDataInputIndex)};
|
||||
auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex);
|
||||
auto make_tuple_node = switch_layer_cnode->input(kSwitchLayerBranchesIndex);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_node);
|
||||
auto node = make_tuple_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -1040,7 +1040,7 @@ void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
||||
(void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
|
||||
for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
|
||||
for (size_t index = kSwitchTrueBranchIndex; index < cnode->inputs().size(); index++) {
|
||||
auto node_input = cnode->input(index);
|
||||
auto switch_input = CreateSwitchInput(cnode, node_input, graph);
|
||||
(void)cnode_inputs->emplace_back(switch_input);
|
||||
|
|
|
@ -296,8 +296,8 @@ class InlinerBase : public AnfVisitor {
|
|||
}
|
||||
|
||||
bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
|
||||
auto true_branch_abstract = sw_inputs[kSwitchTrueKernelGraphIndex]->abstract();
|
||||
auto false_branch_abstract = sw_inputs[kSwitchFalseKernelGraphIndex]->abstract();
|
||||
auto true_branch_abstract = sw_inputs[kSwitchTrueBranchIndex]->abstract();
|
||||
auto false_branch_abstract = sw_inputs[kSwitchFalseBranchIndex]->abstract();
|
||||
// When branch has dead node or poly node, do not perform inline.
|
||||
if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) {
|
||||
return true;
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
const auto kMinInputSizeOfCallWithArgs = 2;
|
||||
// {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs}
|
||||
class PartialEliminater : public AnfVisitor {
|
||||
public:
|
||||
|
@ -98,7 +99,7 @@ class PartialEliminater : public AnfVisitor {
|
|||
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
// {prim::kPrimPartial, X, Xs}
|
||||
if (inputs.size() < 2) {
|
||||
if (inputs.size() <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -130,8 +131,8 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
}
|
||||
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
// {prim::kPrimPartial, G, Xs}
|
||||
if (inputs.size() < 3) {
|
||||
// {prim::kPrimPartial, G}
|
||||
if (inputs.size() < kPartialMinInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString();
|
||||
return;
|
||||
}
|
||||
|
@ -218,6 +219,9 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
if (new_params[j] == nullptr) {
|
||||
TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info()));
|
||||
ParameterPtr param = std::make_shared<Parameter>(another_fg);
|
||||
auto new_abs =
|
||||
anchor_fg_params[j]->abstract() == nullptr ? nullptr : anchor_fg_params[j]->abstract()->Clone();
|
||||
param->set_abstract(new_abs);
|
||||
new_params[j] = param;
|
||||
}
|
||||
}
|
||||
|
@ -226,6 +230,8 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
if (new_params[anchor_args_size + j] == nullptr) {
|
||||
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info()));
|
||||
ParameterPtr param = std::make_shared<Parameter>(another_fg);
|
||||
auto new_abs = extra_inputs[j]->abstract() == nullptr ? nullptr : extra_inputs[j]->abstract()->Clone();
|
||||
param->set_abstract(new_abs);
|
||||
new_params[anchor_args_size + j] = param;
|
||||
}
|
||||
}
|
||||
|
@ -242,6 +248,8 @@ class ChoicePartialEliminater : public AnfVisitor {
|
|||
for (size_t i = 0; i < extra_inputs.size(); ++i) {
|
||||
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info()));
|
||||
ParameterPtr param = std::make_shared<Parameter>(anchor_fg);
|
||||
auto new_abs = extra_inputs[i]->abstract() == nullptr ? nullptr : extra_inputs[i]->abstract()->Clone();
|
||||
param->set_abstract(new_abs);
|
||||
new_params.push_back(param);
|
||||
}
|
||||
// Reorder Zs_ to last;
|
||||
|
@ -312,19 +320,19 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
|
|||
return nullptr;
|
||||
}
|
||||
auto input0_cnode = cnode->input(0)->cast<CNodePtr>();
|
||||
if (input0_cnode->size() != 4) {
|
||||
if (input0_cnode->size() != kSwitchInputSize) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
fg_list_.clear();
|
||||
args_list_.clear();
|
||||
auto &maybe_partial_1 = input0_cnode->input(2);
|
||||
auto &maybe_partial_1 = input0_cnode->input(kSwitchTrueBranchIndex);
|
||||
Visit(maybe_partial_1);
|
||||
auto &maybe_partial_2 = input0_cnode->input(3);
|
||||
auto &maybe_partial_2 = input0_cnode->input(kSwitchFalseBranchIndex);
|
||||
Visit(maybe_partial_2);
|
||||
|
||||
// Either one should be {Partial, G, X}
|
||||
if (fg_list_.size() != 2 && args_list_.size() != 2) {
|
||||
if (fg_list_.size() != kSwitchBranchesNum && args_list_.size() != kSwitchBranchesNum) {
|
||||
return nullptr;
|
||||
}
|
||||
// Should not continue;
|
||||
|
@ -359,11 +367,12 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
|
|||
TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info()));
|
||||
// {Switch, cond, G1, G2}
|
||||
auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2});
|
||||
switch_cnode->set_abstract(input0_cnode->abstract());
|
||||
AnfNodePtrList args{switch_cnode};
|
||||
(void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args));
|
||||
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
|
||||
// Zs
|
||||
if (old_cnode->size() >= 2) {
|
||||
if (old_cnode->size() >= kMinInputSizeOfCallWithArgs) {
|
||||
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
|
||||
}
|
||||
TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info()));
|
||||
|
@ -392,14 +401,14 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
|
|||
}
|
||||
auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>();
|
||||
// {SwitchLayer, cond, MakeTuple{}}
|
||||
if (switch_layer_cnode->size() != 3) {
|
||||
if (switch_layer_cnode->size() != kSwitchLayerInputSize) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsPrimitiveCNode(switch_layer_cnode->input(2), prim::kPrimMakeTuple)) {
|
||||
if (!IsPrimitiveCNode(switch_layer_cnode->input(kSwitchLayerBranchesIndex), prim::kPrimMakeTuple)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
|
||||
if (make_tuple_cnode->size() < 2) {
|
||||
auto make_tuple_cnode = switch_layer_cnode->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>();
|
||||
if (make_tuple_cnode->size() <= 1) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -437,7 +446,7 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
|
|||
private:
|
||||
AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode,
|
||||
const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) {
|
||||
auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
|
||||
auto make_tuple_cnode = switch_layer_cnode->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>();
|
||||
AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)};
|
||||
make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end());
|
||||
TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info()));
|
||||
|
@ -452,7 +461,7 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
|
|||
(void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args));
|
||||
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
|
||||
// Zs
|
||||
if (old_cnode->size() >= 2) {
|
||||
if (old_cnode->size() >= kMinInputSizeOfCallWithArgs) {
|
||||
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
|
||||
}
|
||||
TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info()));
|
||||
|
|
|
@ -563,18 +563,21 @@ constexpr auto kFirstDataInputIndex = 1;
|
|||
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
|
||||
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
|
||||
constexpr auto kTupleGetItemInputSize = 3;
|
||||
// index define of partial
|
||||
constexpr auto kPartialMinInputSize = 2;
|
||||
constexpr auto kPartialGraphIndex = 1;
|
||||
|
||||
// index define of switch
|
||||
constexpr auto kSwitchInputSize = 4;
|
||||
constexpr auto kFirstBranchInSwitch = 2;
|
||||
constexpr auto kCallKernelGraphIndex = 1;
|
||||
constexpr auto kSwitchTrueKernelGraphIndex = 2;
|
||||
constexpr auto kSwitchFalseKernelGraphIndex = 3;
|
||||
constexpr auto kMakeTupleInSwitchLayerIndex = 2;
|
||||
constexpr auto kSwitchTrueBranchIndex = 2;
|
||||
constexpr auto kSwitchFalseBranchIndex = 3;
|
||||
constexpr auto kSwitchBranchesNum = 2;
|
||||
|
||||
// index define of switch_layer
|
||||
constexpr auto kSwitchLayerInputSize = 3;
|
||||
// index define of control depend
|
||||
constexpr auto kControlDependPriorIndex = 1;
|
||||
constexpr auto kControlDependBehindIndex = 2;
|
||||
constexpr auto kControlDependInputSize = 3;
|
||||
constexpr auto kControlDependMode = "depend_mode";
|
||||
constexpr auto kSwitchLayerSelectIndex = 1;
|
||||
constexpr auto kSwitchLayerBranchesIndex = 2;
|
||||
|
||||
// index define of depend
|
||||
constexpr auto kRealInputIndexInDepend = 1;
|
||||
constexpr auto kDependAttachNodeIndex = 2;
|
||||
|
|
|
@ -330,9 +330,9 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
|
|||
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
|
||||
}
|
||||
VectorRef args;
|
||||
args.emplace_back(Ref(inputs[kCallKernelGraphIndex]));
|
||||
args.emplace_back(Ref(inputs[kSwitchTrueKernelGraphIndex]));
|
||||
args.emplace_back(Ref(inputs[kSwitchFalseKernelGraphIndex]));
|
||||
args.emplace_back(Ref(inputs[kPartialGraphIndex]));
|
||||
args.emplace_back(Ref(inputs[kSwitchTrueBranchIndex]));
|
||||
args.emplace_back(Ref(inputs[kSwitchFalseBranchIndex]));
|
||||
AddInst(Instruction::kSwitch, args);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue