split sort visit switch partial first

This commit is contained in:
kswang 2020-12-02 10:52:03 +08:00
parent 8bb8ea1e14
commit 367a31fa04
3 changed files with 9 additions and 5 deletions

View File

@ -605,7 +605,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
MS_EXCEPTION_IF_NULL(other_graph_cnode);
MS_EXCEPTION_IF_NULL(cnode_inputs);
auto origin_inputs = cnode->inputs();
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3;
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3;
bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3;
// if has multiple depends,only select first depend as parameter
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
@ -615,7 +615,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
} else if (optimize_depend && input_idx > 1) {
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
continue;
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {

View File

@ -214,7 +214,9 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto node_inputs = cnode->inputs();
std::reverse(node_inputs.begin(), node_inputs.end());
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
std::reverse(node_inputs.begin(), node_inputs.end());
}
auto ctrl_inputs = control_edges.find(node);
if (ctrl_inputs != control_edges.end()) {
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());

View File

@ -139,9 +139,11 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
}
auto fn = inps[0];
std::vector<AnfNodePtr> args{fn};
if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv));
args.emplace_back(NewValueNode(MakeValue(0)));
for (size_t i = 2; i < inps.size(); ++i) {
args.emplace_back(NewValueNode(MakeValue(0)));
}
} else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) {
for (size_t i = 1; i < inps.size(); ++i) {
if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {