set abstract of switch

This commit is contained in:
chenfei 2022-01-19 10:58:40 +08:00
parent bfdd040728
commit fa7cb34c8a
7 changed files with 49 additions and 38 deletions

View File

@ -1696,7 +1696,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
} }
auto partial_cnode = partial->cast<CNodePtr>(); auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode); MS_EXCEPTION_IF_NULL(partial_cnode);
auto graph_node = partial_cnode->input(kCallKernelGraphIndex); auto graph_node = partial_cnode->input(kPartialGraphIndex);
MS_EXCEPTION_IF_NULL(graph_node); MS_EXCEPTION_IF_NULL(graph_node);
auto graph_value_node = graph_node->cast<ValueNodePtr>(); auto graph_value_node = graph_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_value_node); MS_EXCEPTION_IF_NULL(graph_value_node);
@ -1706,7 +1706,7 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
return child_graph; return child_graph;
}; };
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
auto input1 = cnode->input(kCallKernelGraphIndex); auto input1 = cnode->input(kPartialGraphIndex);
MS_EXCEPTION_IF_NULL(input1); MS_EXCEPTION_IF_NULL(input1);
auto value_node = input1->cast<ValueNodePtr>(); auto value_node = input1->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
@ -1714,11 +1714,10 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallSwitchKernelGraph(const
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
return {kernel_graph->cast<KernelGraphPtr>()}; return {kernel_graph->cast<KernelGraphPtr>()};
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
return {get_switch_kernel_graph(kSwitchTrueKernelGraphIndex), return {get_switch_kernel_graph(kSwitchTrueBranchIndex), get_switch_kernel_graph(kSwitchFalseBranchIndex)};
get_switch_kernel_graph(kSwitchFalseKernelGraphIndex)};
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) { } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
std::vector<KernelGraphPtr> child_graphs; std::vector<KernelGraphPtr> child_graphs;
for (size_t idx = kMakeTupleInSwitchLayerIndex; idx < cnode->inputs().size(); idx++) { for (size_t idx = kSwitchLayerBranchesIndex; idx < cnode->inputs().size(); idx++) {
auto child_graph = get_switch_kernel_graph(idx); auto child_graph = get_switch_kernel_graph(idx);
child_graphs.emplace_back(child_graph); child_graphs.emplace_back(child_graph);
} }

View File

@ -642,7 +642,7 @@ class CallInfoFinder {
} }
CallBranch GetCallBranch(const CNodePtr &cnode) { CallBranch GetCallBranch(const CNodePtr &cnode) {
auto input_graph = cnode->input(kCallKernelGraphIndex); auto input_graph = cnode->input(kPartialGraphIndex);
MS_EXCEPTION_IF_NULL(input_graph); MS_EXCEPTION_IF_NULL(input_graph);
auto kg = GetValueNode<KernelGraphPtr>(input_graph); auto kg = GetValueNode<KernelGraphPtr>(input_graph);
MS_EXCEPTION_IF_NULL(kg); MS_EXCEPTION_IF_NULL(kg);

View File

@ -826,7 +826,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
} }
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex), std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
switch_cnode->input(kFirstDataInputIndex)}; switch_cnode->input(kFirstDataInputIndex)};
for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) { for (size_t index = kSwitchTrueBranchIndex; index < switch_cnode->inputs().size(); index++) {
auto node = switch_cnode->input(index); auto node = switch_cnode->input(index);
// there is real input in call, should put it to true and false branch in switch // there is real input in call, should put it to true and false branch in switch
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) { if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
@ -919,7 +919,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
MS_EXCEPTION_IF_NULL(switch_layer_cnode); MS_EXCEPTION_IF_NULL(switch_layer_cnode);
std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex), std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
switch_layer_cnode->input(kFirstDataInputIndex)}; switch_layer_cnode->input(kFirstDataInputIndex)};
auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex); auto make_tuple_node = switch_layer_cnode->input(kSwitchLayerBranchesIndex);
MS_EXCEPTION_IF_NULL(make_tuple_node); MS_EXCEPTION_IF_NULL(make_tuple_node);
auto node = make_tuple_node->cast<CNodePtr>(); auto node = make_tuple_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -1040,7 +1040,7 @@ void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
(void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) { for (size_t index = kSwitchTrueBranchIndex; index < cnode->inputs().size(); index++) {
auto node_input = cnode->input(index); auto node_input = cnode->input(index);
auto switch_input = CreateSwitchInput(cnode, node_input, graph); auto switch_input = CreateSwitchInput(cnode, node_input, graph);
(void)cnode_inputs->emplace_back(switch_input); (void)cnode_inputs->emplace_back(switch_input);

View File

@ -296,8 +296,8 @@ class InlinerBase : public AnfVisitor {
} }
bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) { bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
auto true_branch_abstract = sw_inputs[kSwitchTrueKernelGraphIndex]->abstract(); auto true_branch_abstract = sw_inputs[kSwitchTrueBranchIndex]->abstract();
auto false_branch_abstract = sw_inputs[kSwitchFalseKernelGraphIndex]->abstract(); auto false_branch_abstract = sw_inputs[kSwitchFalseBranchIndex]->abstract();
// When branch has dead node or poly node, do not perform inline. // When branch has dead node or poly node, do not perform inline.
if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) { if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) {
return true; return true;

View File

@ -31,6 +31,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
const auto kMinInputSizeOfCallWithArgs = 2;
// {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs} // {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs}
class PartialEliminater : public AnfVisitor { class PartialEliminater : public AnfVisitor {
public: public:
@ -98,7 +99,7 @@ class PartialEliminater : public AnfVisitor {
auto &inputs = node->cast<CNodePtr>()->inputs(); auto &inputs = node->cast<CNodePtr>()->inputs();
// {prim::kPrimPartial, X, Xs} // {prim::kPrimPartial, X, Xs}
if (inputs.size() < 2) { if (inputs.size() <= 1) {
return; return;
} }
@ -130,8 +131,8 @@ class ChoicePartialEliminater : public AnfVisitor {
} }
auto &inputs = node->cast<CNodePtr>()->inputs(); auto &inputs = node->cast<CNodePtr>()->inputs();
// {prim::kPrimPartial, G, Xs} // {prim::kPrimPartial, G}
if (inputs.size() < 3) { if (inputs.size() < kPartialMinInputSize) {
MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString(); MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString();
return; return;
} }
@ -218,6 +219,9 @@ class ChoicePartialEliminater : public AnfVisitor {
if (new_params[j] == nullptr) { if (new_params[j] == nullptr) {
TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info())); TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(another_fg); ParameterPtr param = std::make_shared<Parameter>(another_fg);
auto new_abs =
anchor_fg_params[j]->abstract() == nullptr ? nullptr : anchor_fg_params[j]->abstract()->Clone();
param->set_abstract(new_abs);
new_params[j] = param; new_params[j] = param;
} }
} }
@ -226,6 +230,8 @@ class ChoicePartialEliminater : public AnfVisitor {
if (new_params[anchor_args_size + j] == nullptr) { if (new_params[anchor_args_size + j] == nullptr) {
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info())); TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(another_fg); ParameterPtr param = std::make_shared<Parameter>(another_fg);
auto new_abs = extra_inputs[j]->abstract() == nullptr ? nullptr : extra_inputs[j]->abstract()->Clone();
param->set_abstract(new_abs);
new_params[anchor_args_size + j] = param; new_params[anchor_args_size + j] = param;
} }
} }
@ -242,6 +248,8 @@ class ChoicePartialEliminater : public AnfVisitor {
for (size_t i = 0; i < extra_inputs.size(); ++i) { for (size_t i = 0; i < extra_inputs.size(); ++i) {
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info())); TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(anchor_fg); ParameterPtr param = std::make_shared<Parameter>(anchor_fg);
auto new_abs = extra_inputs[i]->abstract() == nullptr ? nullptr : extra_inputs[i]->abstract()->Clone();
param->set_abstract(new_abs);
new_params.push_back(param); new_params.push_back(param);
} }
// Reorder Zs_ to last; // Reorder Zs_ to last;
@ -312,19 +320,19 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
return nullptr; return nullptr;
} }
auto input0_cnode = cnode->input(0)->cast<CNodePtr>(); auto input0_cnode = cnode->input(0)->cast<CNodePtr>();
if (input0_cnode->size() != 4) { if (input0_cnode->size() != kSwitchInputSize) {
return nullptr; return nullptr;
} }
fg_list_.clear(); fg_list_.clear();
args_list_.clear(); args_list_.clear();
auto &maybe_partial_1 = input0_cnode->input(2); auto &maybe_partial_1 = input0_cnode->input(kSwitchTrueBranchIndex);
Visit(maybe_partial_1); Visit(maybe_partial_1);
auto &maybe_partial_2 = input0_cnode->input(3); auto &maybe_partial_2 = input0_cnode->input(kSwitchFalseBranchIndex);
Visit(maybe_partial_2); Visit(maybe_partial_2);
// Either one should be {Partial, G, X} // Either one should be {Partial, G, X}
if (fg_list_.size() != 2 && args_list_.size() != 2) { if (fg_list_.size() != kSwitchBranchesNum && args_list_.size() != kSwitchBranchesNum) {
return nullptr; return nullptr;
} }
// Should not continue; // Should not continue;
@ -359,11 +367,12 @@ class SwitchPartialEliminater : public ChoicePartialEliminater {
TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info())); TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info()));
// {Switch, cond, G1, G2} // {Switch, cond, G1, G2}
auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2}); auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2});
switch_cnode->set_abstract(input0_cnode->abstract());
AnfNodePtrList args{switch_cnode}; AnfNodePtrList args{switch_cnode};
(void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args)); (void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args));
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args)); (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
// Zs // Zs
if (old_cnode->size() >= 2) { if (old_cnode->size() >= kMinInputSizeOfCallWithArgs) {
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args)); (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
} }
TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info())); TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info()));
@ -392,14 +401,14 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
} }
auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>(); auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>();
// {SwitchLayer, cond, MakeTuple{}} // {SwitchLayer, cond, MakeTuple{}}
if (switch_layer_cnode->size() != 3) { if (switch_layer_cnode->size() != kSwitchLayerInputSize) {
return nullptr; return nullptr;
} }
if (!IsPrimitiveCNode(switch_layer_cnode->input(2), prim::kPrimMakeTuple)) { if (!IsPrimitiveCNode(switch_layer_cnode->input(kSwitchLayerBranchesIndex), prim::kPrimMakeTuple)) {
return nullptr; return nullptr;
} }
auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>(); auto make_tuple_cnode = switch_layer_cnode->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>();
if (make_tuple_cnode->size() < 2) { if (make_tuple_cnode->size() <= 1) {
return nullptr; return nullptr;
} }
@ -437,7 +446,7 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
private: private:
AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode, AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode,
const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) { const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) {
auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>(); auto make_tuple_cnode = switch_layer_cnode->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>();
AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)}; AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)};
make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end()); make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end());
TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info())); TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info()));
@ -452,7 +461,7 @@ class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
(void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args)); (void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args));
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args)); (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
// Zs // Zs
if (old_cnode->size() >= 2) { if (old_cnode->size() >= kMinInputSizeOfCallWithArgs) {
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args)); (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
} }
TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info())); TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info()));

View File

@ -563,18 +563,21 @@ constexpr auto kFirstDataInputIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
constexpr auto kTupleGetItemInputSize = 3; constexpr auto kTupleGetItemInputSize = 3;
// index define of partial
constexpr auto kPartialMinInputSize = 2;
constexpr auto kPartialGraphIndex = 1;
// index define of switch
constexpr auto kSwitchInputSize = 4; constexpr auto kSwitchInputSize = 4;
constexpr auto kFirstBranchInSwitch = 2; constexpr auto kSwitchTrueBranchIndex = 2;
constexpr auto kCallKernelGraphIndex = 1; constexpr auto kSwitchFalseBranchIndex = 3;
constexpr auto kSwitchTrueKernelGraphIndex = 2; constexpr auto kSwitchBranchesNum = 2;
constexpr auto kSwitchFalseKernelGraphIndex = 3;
constexpr auto kMakeTupleInSwitchLayerIndex = 2; // index define of switch_layer
constexpr auto kSwitchLayerInputSize = 3; constexpr auto kSwitchLayerInputSize = 3;
// index define of control depend constexpr auto kSwitchLayerSelectIndex = 1;
constexpr auto kControlDependPriorIndex = 1; constexpr auto kSwitchLayerBranchesIndex = 2;
constexpr auto kControlDependBehindIndex = 2;
constexpr auto kControlDependInputSize = 3;
constexpr auto kControlDependMode = "depend_mode";
// index define of depend // index define of depend
constexpr auto kRealInputIndexInDepend = 1; constexpr auto kRealInputIndexInDepend = 1;
constexpr auto kDependAttachNodeIndex = 2; constexpr auto kDependAttachNodeIndex = 2;

View File

@ -330,9 +330,9 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4"; MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
} }
VectorRef args; VectorRef args;
args.emplace_back(Ref(inputs[kCallKernelGraphIndex])); args.emplace_back(Ref(inputs[kPartialGraphIndex]));
args.emplace_back(Ref(inputs[kSwitchTrueKernelGraphIndex])); args.emplace_back(Ref(inputs[kSwitchTrueBranchIndex]));
args.emplace_back(Ref(inputs[kSwitchFalseKernelGraphIndex])); args.emplace_back(Ref(inputs[kSwitchFalseBranchIndex]));
AddInst(Instruction::kSwitch, args); AddInst(Instruction::kSwitch, args);
} }