forked from mindspore-Ecosystem/mindspore
split sort visit switch partial first
This commit is contained in:
parent
8bb8ea1e14
commit
367a31fa04
|
@ -605,7 +605,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
|||
MS_EXCEPTION_IF_NULL(other_graph_cnode);
|
||||
MS_EXCEPTION_IF_NULL(cnode_inputs);
|
||||
auto origin_inputs = cnode->inputs();
|
||||
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3;
|
||||
bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3;
|
||||
bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3;
|
||||
// if has multiple depends,only select first depend as parameter
|
||||
for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
|
||||
|
@ -615,7 +615,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
|||
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
|
||||
cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
|
||||
continue;
|
||||
} else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
|
||||
} else if (optimize_depend && input_idx > 1) {
|
||||
cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
|
||||
continue;
|
||||
} else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
|
||||
|
|
|
@ -214,7 +214,9 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto node_inputs = cnode->inputs();
|
||||
std::reverse(node_inputs.begin(), node_inputs.end());
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
|
||||
std::reverse(node_inputs.begin(), node_inputs.end());
|
||||
}
|
||||
auto ctrl_inputs = control_edges.find(node);
|
||||
if (ctrl_inputs != control_edges.end()) {
|
||||
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
|
||||
|
|
|
@ -139,9 +139,11 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
}
|
||||
auto fn = inps[0];
|
||||
std::vector<AnfNodePtr> args{fn};
|
||||
if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() == 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
|
||||
if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= 3 && eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
|
||||
args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv));
|
||||
args.emplace_back(NewValueNode(MakeValue(0)));
|
||||
for (size_t i = 2; i < inps.size(); ++i) {
|
||||
args.emplace_back(NewValueNode(MakeValue(0)));
|
||||
}
|
||||
} else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) {
|
||||
for (size_t i = 1; i < inps.size(); ++i) {
|
||||
if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
|
||||
|
|
Loading…
Reference in New Issue