forked from mindspore-Ecosystem/mindspore
1.fix bug of control flow pass bug
2.add invaild transpose op remove
This commit is contained in:
parent
26e34dec80
commit
6c63a0e917
|
@ -55,45 +55,143 @@ void ControlFlowPass::ReplaceNode(const FuncGraphPtr &fg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ControlFlowPass::FunGraphInputsOnlyUsedByAfterParts(const FuncGraphPtr &fg, const CNodePtr &aim_cnode,
|
void ControlFlowPass::VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &visited_nodes,
|
||||||
std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_fg) {
|
const std::vector<AnfNodePtr> &remain_nodes,
|
||||||
auto fg_inputs = fg->get_inputs();
|
std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg) {
|
||||||
fg_inputs_only_used_by_after_fg->assign(fg_inputs.begin(), fg_inputs.end());
|
std::deque<AnfNodePtr> nodes{};
|
||||||
auto nodes = TopoSort(aim_cnode);
|
std::set<AnfNodePtr> visited_nodes_used_by_after_fg_set{};
|
||||||
for (auto it = fg_inputs_only_used_by_after_fg->begin(); it != fg_inputs_only_used_by_after_fg->end();) {
|
std::set<FuncGraphPtr> visited_fg_set{};
|
||||||
if (lite::IsContain(nodes, *it)) {
|
std::set<AnfNodePtr> remain_nodes_set{};
|
||||||
it = fg_inputs_only_used_by_after_fg->erase(it);
|
nodes.assign(remain_nodes.begin(), remain_nodes.end());
|
||||||
} else {
|
while (!nodes.empty()) {
|
||||||
++it;
|
auto node = nodes.front();
|
||||||
|
nodes.pop_front();
|
||||||
|
remain_nodes_set.insert(node);
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
for (auto &input : cnode->inputs()) {
|
||||||
|
if (visited_nodes.find(input) != visited_nodes.end() &&
|
||||||
|
visited_nodes_used_by_after_fg_set.find(input) == visited_nodes_used_by_after_fg_set.end()) {
|
||||||
|
visited_nodes_used_by_after_fg->push_back(input);
|
||||||
|
visited_nodes_used_by_after_fg_set.insert(input);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, const PrimitivePtr &aim_prim, AnfNodePtr *aim_prim_type_node,
|
size_t ControlFlowPass::GetItemVisitedNums(const std::set<AnfNodePtr> &visited_nodes, const AnfNodePtr &tuple_node) {
|
||||||
std::vector<AnfNodePtr> *remain_nodes) {
|
size_t count = 0;
|
||||||
|
for (auto &node : visited_nodes) {
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto get_item_cnode = node->cast<CNodePtr>();
|
||||||
|
if (get_item_cnode->inputs()[kCNodeFirstInputIndex] == tuple_node) {
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ControlFlowPass::MoveGetItemToVisited(const size_t &need_size, const AnfNodePtr &tuple_node,
|
||||||
|
std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
|
||||||
|
size_t i = 0;
|
||||||
|
for (auto it = remain_nodes->begin(); it != remain_nodes->end();) {
|
||||||
|
if (!utils::isa<CNodePtr>(*it)) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!CheckPrimitiveType(*it, prim::kPrimTupleGetItem)) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto get_item_cnode = (*it)->cast<CNodePtr>();
|
||||||
|
if (get_item_cnode->inputs()[kCNodeFirstInputIndex] != tuple_node) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
visited_nodes->insert(*it);
|
||||||
|
it = remain_nodes->erase(it);
|
||||||
|
if (need_size == i) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << tuple_node->fullname_with_scope() << " not found enough get item, size: " << need_size - i;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ControlFlowPass::BindGetItemNodes(std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
|
||||||
|
std::deque<AnfNodePtr> multi_output_nodes{};
|
||||||
|
for (auto &node : *visited_nodes) {
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (utils::isa<abstract::AbstractTuple>(node->abstract())) {
|
||||||
|
multi_output_nodes.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!multi_output_nodes.empty()) {
|
||||||
|
auto cur_node = multi_output_nodes.front();
|
||||||
|
multi_output_nodes.pop_front();
|
||||||
|
size_t total_getitem_size = cur_node->abstract()->cast<abstract::AbstractTuplePtr>()->size();
|
||||||
|
size_t visited_getitem_size = GetItemVisitedNums(*visited_nodes, cur_node);
|
||||||
|
if (total_getitem_size == visited_getitem_size) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t need_getitem_size = total_getitem_size - visited_getitem_size;
|
||||||
|
MoveGetItemToVisited(need_getitem_size, cur_node, visited_nodes, remain_nodes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow_node,
|
||||||
|
std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes) {
|
||||||
auto inputs = fg->get_inputs();
|
auto inputs = fg->get_inputs();
|
||||||
std::vector<AnfNodePtr> visited_nodes{};
|
|
||||||
visited_nodes.assign(inputs.begin(), inputs.end());
|
|
||||||
// notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created.
|
// notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created.
|
||||||
auto node_list = TopoSort(fg->get_return());
|
auto node_list = TopoSort(fg->get_return());
|
||||||
for (auto &node : node_list) {
|
for (auto &node : node_list) {
|
||||||
if (utils::isa<CNodePtr>(node) && CheckPrimitiveType(node, aim_prim)) {
|
if (utils::isa<CNodePtr>(node) &&
|
||||||
*aim_prim_type_node = node;
|
(CheckPrimitiveType(node, prim::kPrimWhile) || CheckPrimitiveType(node, prim::kPrimIf))) {
|
||||||
|
*control_flow_node = node;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (!utils::isa<CNodePtr>(node) && !utils::isa<ParameterPtr>(node)) {
|
}
|
||||||
|
|
||||||
|
std::deque<AnfNodePtr> q;
|
||||||
|
visited_nodes->insert(inputs.begin(), inputs.end());
|
||||||
|
q.push_back(*control_flow_node);
|
||||||
|
while (!q.empty()) {
|
||||||
|
auto node = q.front();
|
||||||
|
q.pop_front();
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (!lite::IsContain(visited_nodes, node)) {
|
visited_nodes->insert(node);
|
||||||
visited_nodes.push_back(node);
|
auto cnode = utils::cast<CNodePtr>(node);
|
||||||
|
for (size_t i = 0; i < cnode->inputs().size(); i++) {
|
||||||
|
auto input = cnode->input(i);
|
||||||
|
if (visited_nodes->find(input) == visited_nodes->end()) {
|
||||||
|
q.push_back(input);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &node : node_list) {
|
for (auto &node : node_list) {
|
||||||
if (!lite::IsContain(visited_nodes, node) && node != *aim_prim_type_node) {
|
if (visited_nodes->find(node) == visited_nodes->end()) {
|
||||||
remain_nodes->push_back(node);
|
remain_nodes->push_back(node);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
visited_nodes->erase(*control_flow_node);
|
||||||
|
|
||||||
|
BindGetItemNodes(visited_nodes, remain_nodes);
|
||||||
|
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,9 +222,9 @@ int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::ve
|
||||||
}
|
}
|
||||||
|
|
||||||
int ControlFlowPass::CreateWhileCondCallNode(
|
int ControlFlowPass::CreateWhileCondCallNode(
|
||||||
const FuncGraphPtr &fg, const CNodePtr &while_cnode, std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_fg,
|
const FuncGraphPtr &fg, const CNodePtr &while_cnode, std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
|
||||||
CNodePtr *cond_call_cnode,
|
CNodePtr *cond_call_cnode,
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *fg_inputs_and_after_partial_inputs_replace_pairs) {
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *visited_nodes_and_cond_fg_inputs_replace_pairs) {
|
||||||
auto cond_vnode = while_cnode->input(kWhileCondIndex);
|
auto cond_vnode = while_cnode->input(kWhileCondIndex);
|
||||||
MS_ASSERT(cond_vnode != nullptr);
|
MS_ASSERT(cond_vnode != nullptr);
|
||||||
auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
|
auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
|
||||||
|
@ -134,30 +232,54 @@ int ControlFlowPass::CreateWhileCondCallNode(
|
||||||
MS_LOG(ERROR) << "Get value as func graph failed.";
|
MS_LOG(ERROR) << "Get value as func graph failed.";
|
||||||
return RET_FAILED;
|
return RET_FAILED;
|
||||||
}
|
}
|
||||||
// get fg input which is not used by cond fg
|
|
||||||
FunGraphInputsOnlyUsedByAfterParts(fg, while_cnode, fg_inputs_only_used_by_after_fg);
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> cond_call_cnode_inputs{cond_vnode};
|
// create after partial node
|
||||||
cond_call_cnode_inputs.insert(cond_call_cnode_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
|
ValueNodePtr cond_partial_anf_primitive = GetPartialAnfPrim();
|
||||||
while_cnode->inputs().end());
|
if (cond_partial_anf_primitive == nullptr) {
|
||||||
// set after fg inputs to cond_call_cnode inputs
|
MS_LOG(ERROR) << "GetPartialAnfPrim failed.";
|
||||||
cond_call_cnode_inputs.insert(cond_call_cnode_inputs.end(), fg_inputs_only_used_by_after_fg->begin(),
|
return RET_FAILED;
|
||||||
fg_inputs_only_used_by_after_fg->end());
|
|
||||||
|
|
||||||
*cond_call_cnode = fg->NewCNode(cond_call_cnode_inputs);
|
|
||||||
(*cond_call_cnode)->set_fullname_with_scope("CNode_" + cond_fg->get_attr("graph_name")->ToString());
|
|
||||||
|
|
||||||
for (auto &node : *fg_inputs_only_used_by_after_fg) {
|
|
||||||
if (!utils::isa<ParameterPtr>(node)) {
|
|
||||||
MS_LOG(ERROR) << "fg is not right.";
|
|
||||||
return RET_FAILED;
|
|
||||||
}
|
|
||||||
auto new_parameter = cond_fg->add_parameter();
|
|
||||||
new_parameter->set_name(node->fullname_with_scope() + "_cond_fg_parameter");
|
|
||||||
new_parameter->set_abstract(node->abstract());
|
|
||||||
(*fg_inputs_and_after_partial_inputs_replace_pairs)[node] = new_parameter;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> cond_partial_cnode_inputs{cond_partial_anf_primitive, cond_vnode};
|
||||||
|
cond_partial_cnode_inputs.insert(cond_partial_cnode_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
|
||||||
|
while_cnode->inputs().end());
|
||||||
|
|
||||||
|
auto origin_cond_fg_inputs = cond_fg->get_inputs();
|
||||||
|
for (auto it = visited_nodes_used_by_after_fg->begin(); it != visited_nodes_used_by_after_fg->end();) {
|
||||||
|
bool found = false;
|
||||||
|
size_t index = -1;
|
||||||
|
for (size_t i = kPartialFirstInputSize; i < cond_partial_cnode_inputs.size(); ++i) {
|
||||||
|
if (cond_partial_cnode_inputs[i] == *it) {
|
||||||
|
found = true;
|
||||||
|
index = i - kPartialFirstInputSize;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (found) {
|
||||||
|
it = visited_nodes_used_by_after_fg->erase(it);
|
||||||
|
(*visited_nodes_and_cond_fg_inputs_replace_pairs)[*it] = origin_cond_fg_inputs.at(index);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set after fg inputs to cond_partial_cnode inputs
|
||||||
|
cond_partial_cnode_inputs.push_back(*it);
|
||||||
|
auto new_parameter = cond_fg->add_parameter();
|
||||||
|
new_parameter->set_name((*it)->fullname_with_scope() + "_cond_fg_parameter");
|
||||||
|
new_parameter->set_abstract((*it)->abstract());
|
||||||
|
(*visited_nodes_and_cond_fg_inputs_replace_pairs)[*it] = new_parameter;
|
||||||
|
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cond_partial_cnode = fg->NewCNode(cond_partial_cnode_inputs);
|
||||||
|
cond_partial_cnode->set_fullname_with_scope("partial_" + cond_fg->get_attr("graph_name")->ToString());
|
||||||
|
|
||||||
|
// insert call node
|
||||||
|
std::vector<AnfNodePtr> call_node_inputs{cond_partial_cnode};
|
||||||
|
*cond_call_cnode = fg->NewCNode(call_node_inputs);
|
||||||
|
(*cond_call_cnode)->set_fullname_with_scope("call_" + cond_partial_cnode->fullname_with_scope());
|
||||||
|
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,10 +291,7 @@ int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, con
|
||||||
MS_LOG(ERROR) << "Get value as func_graph failed.";
|
MS_LOG(ERROR) << "Get value as func_graph failed.";
|
||||||
return RET_FAILED;
|
return RET_FAILED;
|
||||||
}
|
}
|
||||||
if (ProcessWhileOp(body_fg) != RET_SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << "ProcessWhileOp failed.";
|
|
||||||
return RET_FAILED;
|
|
||||||
}
|
|
||||||
ValueNodePtr partial_anf_primitive = GetPartialAnfPrim();
|
ValueNodePtr partial_anf_primitive = GetPartialAnfPrim();
|
||||||
if (partial_anf_primitive == nullptr) {
|
if (partial_anf_primitive == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetPartialAnfPrim failed.";
|
MS_LOG(ERROR) << "GetPartialAnfPrim failed.";
|
||||||
|
@ -223,13 +342,15 @@ int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, con
|
||||||
auto cond_call_cnode = body_fg->NewCNode(cond_call_cnode_inputs);
|
auto cond_call_cnode = body_fg->NewCNode(cond_call_cnode_inputs);
|
||||||
cond_call_cnode->set_fullname_with_scope(body_fg->get_attr("graph_name")->ToString() + "_call_cond_fg");
|
cond_call_cnode->set_fullname_with_scope(body_fg->get_attr("graph_name")->ToString() + "_call_cond_fg");
|
||||||
body_fg->set_output(cond_call_cnode);
|
body_fg->set_output(cond_call_cnode);
|
||||||
|
|
||||||
|
to_process_q.push_back(body_fg);
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ControlFlowPass::CreateWhileAfterPartialNode(
|
int ControlFlowPass::CreateWhileAfterPartialNode(
|
||||||
const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
|
const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
|
||||||
const std::vector<AnfNodePtr> &fg_inputs_only_used_by_after_fg,
|
const std::vector<AnfNodePtr> &visited_nodes_used_by_after_fg,
|
||||||
const std::unordered_map<AnfNodePtr, AnfNodePtr> &fg_inputs_and_after_partial_inputs_replace_pairs,
|
const std::unordered_map<AnfNodePtr, AnfNodePtr> &visited_nodes_and_cond_fg_inputs_replace_pairs,
|
||||||
CNodePtr *while_cnode, CNodePtr *after_partial_cnode) {
|
CNodePtr *while_cnode, CNodePtr *after_partial_cnode) {
|
||||||
// create after_fg
|
// create after_fg
|
||||||
FuncGraphPtr after_fg = nullptr;
|
FuncGraphPtr after_fg = nullptr;
|
||||||
|
@ -245,7 +366,7 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
|
||||||
return RET_FAILED;
|
return RET_FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> while_output_replace_pairs{};
|
std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_inputs_and_after_fg_inputs_replace_pairs{};
|
||||||
std::vector<AnfNodePtr> after_partial_cnode_inputs{partial_anf_primitive, after_value_node};
|
std::vector<AnfNodePtr> after_partial_cnode_inputs{partial_anf_primitive, after_value_node};
|
||||||
auto cond_fg_inputs = cond_fg->get_inputs();
|
auto cond_fg_inputs = cond_fg->get_inputs();
|
||||||
for (const auto &node : after_fg->nodes()) {
|
for (const auto &node : after_fg->nodes()) {
|
||||||
|
@ -273,46 +394,33 @@ int ControlFlowPass::CreateWhileAfterPartialNode(
|
||||||
auto new_parameter = after_fg->add_parameter();
|
auto new_parameter = after_fg->add_parameter();
|
||||||
new_parameter->set_name(node->fullname_with_scope() + "_after_partial_parameter");
|
new_parameter->set_name(node->fullname_with_scope() + "_after_partial_parameter");
|
||||||
new_parameter->set_abstract(node->abstract());
|
new_parameter->set_abstract(node->abstract());
|
||||||
while_output_replace_pairs[node] = new_parameter;
|
after_partial_inputs_and_after_fg_inputs_replace_pairs[node] = new_parameter;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &pair : while_output_replace_pairs) {
|
for (auto &pair : after_partial_inputs_and_after_fg_inputs_replace_pairs) {
|
||||||
// get all nodes in after_fg
|
|
||||||
after_fg->manager()->Replace(pair.first, pair.second);
|
after_fg->manager()->Replace(pair.first, pair.second);
|
||||||
after_fg->DropNode(pair.first);
|
after_fg->DropNode(pair.first);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_replace_pairs{};
|
std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_after_fg_replace_pair{};
|
||||||
for (auto &input : fg_inputs_only_used_by_after_fg) {
|
for (auto &input : visited_nodes_used_by_after_fg) {
|
||||||
after_partial_cnode_inputs.push_back(fg_inputs_and_after_partial_inputs_replace_pairs.at(input));
|
after_partial_cnode_inputs.push_back(visited_nodes_and_cond_fg_inputs_replace_pairs.at(input));
|
||||||
auto new_parameter = after_fg->add_parameter();
|
auto new_parameter = after_fg->add_parameter();
|
||||||
new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
|
new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
|
||||||
new_parameter->set_abstract(input->abstract());
|
new_parameter->set_abstract(input->abstract());
|
||||||
after_partial_replace_pairs[input] = new_parameter;
|
visited_nodes_after_fg_replace_pair[input] = new_parameter;
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplaceNode(after_fg, after_partial_replace_pairs);
|
ReplaceNode(after_fg, visited_nodes_after_fg_replace_pair);
|
||||||
*after_partial_cnode = cond_fg->NewCNode(after_partial_cnode_inputs);
|
*after_partial_cnode = cond_fg->NewCNode(after_partial_cnode_inputs);
|
||||||
(*after_partial_cnode)->set_fullname_with_scope("CNode_" + after_fg->get_attr("graph_name")->ToString());
|
(*after_partial_cnode)->set_fullname_with_scope("CNode_" + after_fg->get_attr("graph_name")->ToString());
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg) {
|
int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
|
||||||
if (fg == nullptr) {
|
const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &while_node) {
|
||||||
MS_LOG(ERROR) << "fg is nullptr.";
|
|
||||||
return RET_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtr while_node = nullptr;
|
|
||||||
std::vector<AnfNodePtr> remain_nodes{};
|
|
||||||
int ret = SplitGraph(fg, prim::kPrimWhile, &while_node, &remain_nodes);
|
|
||||||
if (ret != RET_SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << "SplitGraph failed, ret: " << ret;
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (while_node == nullptr) {
|
if (while_node == nullptr) {
|
||||||
MS_LOG(INFO) << "not found while, not need to process.";
|
MS_LOG(INFO) << "not found while, no need to process.";
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -323,17 +431,19 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg) {
|
||||||
return RET_FAILED;
|
return RET_FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
|
||||||
|
VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
|
||||||
|
|
||||||
CNodePtr cond_call_cnode = nullptr;
|
CNodePtr cond_call_cnode = nullptr;
|
||||||
std::vector<AnfNodePtr> fg_inputs_only_used_by_after_fg{};
|
std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_cond_fg_inputs_replace_pairs{};
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> fg_inputs_and_after_partial_inputs_replace_pairs{};
|
int ret = CreateWhileCondCallNode(fg, while_cnode, &visited_nodes_used_by_after_fg, &cond_call_cnode,
|
||||||
ret = CreateWhileCondCallNode(fg, while_cnode, &fg_inputs_only_used_by_after_fg, &cond_call_cnode,
|
&visited_nodes_and_cond_fg_inputs_replace_pairs);
|
||||||
&fg_inputs_and_after_partial_inputs_replace_pairs);
|
|
||||||
if (ret != RET_SUCCESS) {
|
if (ret != RET_SUCCESS) {
|
||||||
MS_LOG(ERROR) << "while create cond call cnode failed, ret: " << ret;
|
MS_LOG(ERROR) << "while create cond call cnode failed, ret: " << ret;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr cond_fg_vnode = cond_call_cnode->input(kCNodePrimIndex);
|
AnfNodePtr cond_fg_vnode = cond_call_cnode->input(kCNodePrimIndex)->cast<CNodePtr>()->input(kCNodeFirstInputIndex);
|
||||||
MS_ASSERT(cond_fg_vnode != nullptr);
|
MS_ASSERT(cond_fg_vnode != nullptr);
|
||||||
auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_fg_vnode);
|
auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_fg_vnode);
|
||||||
if (cond_fg == nullptr) {
|
if (cond_fg == nullptr) {
|
||||||
|
@ -349,9 +459,8 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg) {
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr after_partial_cnode = nullptr;
|
CNodePtr after_partial_cnode = nullptr;
|
||||||
ret =
|
ret = CreateWhileAfterPartialNode(fg, cond_fg, remain_nodes, visited_nodes_used_by_after_fg,
|
||||||
CreateWhileAfterPartialNode(fg, cond_fg, remain_nodes, fg_inputs_only_used_by_after_fg,
|
visited_nodes_and_cond_fg_inputs_replace_pairs, &while_cnode, &after_partial_cnode);
|
||||||
fg_inputs_and_after_partial_inputs_replace_pairs, &while_cnode, &after_partial_cnode);
|
|
||||||
if (ret != RET_SUCCESS) {
|
if (ret != RET_SUCCESS) {
|
||||||
MS_LOG(ERROR) << "while create after partial cnode failed, ret: " << ret;
|
MS_LOG(ERROR) << "while create after partial cnode failed, ret: " << ret;
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -368,34 +477,33 @@ int ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg) {
|
||||||
std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node,
|
std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node,
|
||||||
after_partial_cnode};
|
after_partial_cnode};
|
||||||
auto switch_cnode = cond_fg->NewCNode(switch_node_inputs);
|
auto switch_cnode = cond_fg->NewCNode(switch_node_inputs);
|
||||||
switch_cnode->set_fullname_with_scope("Switch-" + cond_fg->get_attr("graph_name")->ToString());
|
switch_cnode->set_fullname_with_scope("while-Switch-" + cond_fg->get_attr("graph_name")->ToString());
|
||||||
|
|
||||||
// insert call node
|
// insert call node
|
||||||
std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
|
std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
|
||||||
auto call_node = cond_fg->NewCNode(call_node_inputs);
|
auto call_node = cond_fg->NewCNode(call_node_inputs);
|
||||||
call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
|
call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
|
||||||
cond_fg->set_output(call_node);
|
cond_fg->set_output(call_node);
|
||||||
|
|
||||||
fg->DropNode(while_cnode);
|
fg->DropNode(while_cnode);
|
||||||
fg->set_output(cond_call_cnode);
|
fg->set_output(cond_call_cnode);
|
||||||
|
|
||||||
FuncGraphPtr after_fg =
|
auto after_fg =
|
||||||
after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
|
after_partial_cnode->input(kCNodeFirstInputIndex)->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
|
||||||
if (after_fg == nullptr) {
|
if (after_fg == nullptr) {
|
||||||
MS_LOG(ERROR) << "after_fg is nullptr.";
|
MS_LOG(ERROR) << "after_fg is nullptr.";
|
||||||
return RET_FAILED;
|
return RET_FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!Run(after_fg)) {
|
to_process_q.push_back(cond_fg);
|
||||||
MS_LOG(ERROR) << "process control flow for after fg failed.";
|
to_process_q.push_back(after_fg);
|
||||||
return RET_FAILED;
|
|
||||||
}
|
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg,
|
int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index,
|
||||||
const std::vector<AnfNodePtr> &fg_inputs_only_used_by_after_partial,
|
std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg, CNodePtr *if_cnode,
|
||||||
const size_t &index, CNodePtr *if_cnode, FuncGraphPtr *after_fg,
|
FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode) {
|
||||||
CNodePtr *then_partial_cnode) {
|
|
||||||
auto then_vnode = (*if_cnode)->input(index);
|
auto then_vnode = (*if_cnode)->input(index);
|
||||||
MS_ASSERT(then_vnode != nullptr);
|
MS_ASSERT(then_vnode != nullptr);
|
||||||
auto then_fg = GetValueNode<std::shared_ptr<FuncGraph>>(then_vnode);
|
auto then_fg = GetValueNode<std::shared_ptr<FuncGraph>>(then_vnode);
|
||||||
|
@ -404,7 +512,7 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg,
|
||||||
return RET_FAILED;
|
return RET_FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
// create after partial node
|
// create then partial node
|
||||||
ValueNodePtr then_partial_anf_primitive = GetPartialAnfPrim();
|
ValueNodePtr then_partial_anf_primitive = GetPartialAnfPrim();
|
||||||
if (then_partial_anf_primitive == nullptr) {
|
if (then_partial_anf_primitive == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetPartialAnfPrim failed.";
|
MS_LOG(ERROR) << "GetPartialAnfPrim failed.";
|
||||||
|
@ -414,28 +522,46 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg,
|
||||||
then_partial_cnode_inputs.insert(then_partial_cnode_inputs.end(), (*if_cnode)->inputs().begin() + kIfMinInputSize,
|
then_partial_cnode_inputs.insert(then_partial_cnode_inputs.end(), (*if_cnode)->inputs().begin() + kIfMinInputSize,
|
||||||
(*if_cnode)->inputs().end());
|
(*if_cnode)->inputs().end());
|
||||||
|
|
||||||
|
auto if_cond_input = (*if_cnode)->inputs()[kIfCondIndex];
|
||||||
|
|
||||||
|
std::unordered_map<AnfNodePtr, AnfNodePtr> visited_nodes_and_after_partial_inputs_replace_pairs{};
|
||||||
|
std::vector<AnfNodePtr> then_nodes_used_by_after_partial{};
|
||||||
// set fg inputs to then_partial_cnode inputs
|
// set fg inputs to then_partial_cnode inputs
|
||||||
then_partial_cnode_inputs.insert(then_partial_cnode_inputs.end(), fg_inputs_only_used_by_after_partial.begin(),
|
auto origin_then_fg_inputs = then_fg->get_inputs();
|
||||||
fg_inputs_only_used_by_after_partial.end());
|
for (auto &item : *visited_nodes_used_by_after_fg) {
|
||||||
|
bool found = false;
|
||||||
|
size_t input_index = -1;
|
||||||
|
for (size_t i = kPartialFirstInputSize; i < then_partial_cnode_inputs.size(); ++i) {
|
||||||
|
if (then_partial_cnode_inputs[i] == item) {
|
||||||
|
found = true;
|
||||||
|
input_index = i - kPartialFirstInputSize;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (found) {
|
||||||
|
visited_nodes_and_after_partial_inputs_replace_pairs[item] = origin_then_fg_inputs.at(input_index);
|
||||||
|
then_nodes_used_by_after_partial.push_back(origin_then_fg_inputs.at(input_index));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set after fg inputs to cond_partial_cnode inputs
|
||||||
|
then_partial_cnode_inputs.push_back(item);
|
||||||
|
auto new_parameter = then_fg->add_parameter();
|
||||||
|
if (index == kIfThenIndex) {
|
||||||
|
new_parameter->set_name(item->fullname_with_scope() + "_then_fg_parameter");
|
||||||
|
} else {
|
||||||
|
new_parameter->set_name(item->fullname_with_scope() + "_else_fg_parameter");
|
||||||
|
}
|
||||||
|
new_parameter->set_abstract(item->abstract());
|
||||||
|
visited_nodes_and_after_partial_inputs_replace_pairs[item] = new_parameter;
|
||||||
|
then_nodes_used_by_after_partial.push_back(new_parameter);
|
||||||
|
}
|
||||||
|
|
||||||
*then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs);
|
*then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs);
|
||||||
auto then_fg_name = then_fg->get_attr("graph_name")->ToString();
|
auto then_fg_name = then_fg->get_attr("graph_name")->ToString();
|
||||||
(*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name);
|
(*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name);
|
||||||
|
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> then_fg_inputs_and_fg_inputs_replace_pairs{};
|
|
||||||
std::vector<AnfNodePtr> new_parameters{};
|
|
||||||
for (auto &node : fg_inputs_only_used_by_after_partial) {
|
|
||||||
if (!utils::isa<ParameterPtr>(node)) {
|
|
||||||
MS_LOG(ERROR) << "fg is not right.";
|
|
||||||
return RET_FAILED;
|
|
||||||
}
|
|
||||||
auto new_parameter = then_fg->add_parameter();
|
|
||||||
new_parameter->set_name(node->fullname_with_scope() + "_" + then_fg_name + "_parameter");
|
|
||||||
new_parameter->set_abstract(node->abstract());
|
|
||||||
then_fg_inputs_and_fg_inputs_replace_pairs[node] = new_parameter;
|
|
||||||
new_parameters.push_back(new_parameter);
|
|
||||||
}
|
|
||||||
|
|
||||||
// create after partial node
|
// create after partial node
|
||||||
ValueNodePtr after_partial_anf_primitive = GetPartialAnfPrim();
|
ValueNodePtr after_partial_anf_primitive = GetPartialAnfPrim();
|
||||||
if (after_partial_anf_primitive == nullptr) {
|
if (after_partial_anf_primitive == nullptr) {
|
||||||
|
@ -445,17 +571,21 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg,
|
||||||
auto after_value_node = NewValueNode(*after_fg);
|
auto after_value_node = NewValueNode(*after_fg);
|
||||||
// make the right after partial input
|
// make the right after partial input
|
||||||
std::vector<AnfNodePtr> after_partial_cnode_inputs{after_partial_anf_primitive, after_value_node};
|
std::vector<AnfNodePtr> after_partial_cnode_inputs{after_partial_anf_primitive, after_value_node};
|
||||||
auto then_fg_output = then_fg->output()->cast<CNodePtr>();
|
if (!CheckPrimitiveType(then_fg->output(), prim::kPrimMakeTuple)) {
|
||||||
if (!CheckPrimitiveType(then_fg_output, prim::kPrimMakeTuple)) {
|
after_partial_cnode_inputs.push_back(then_fg->output());
|
||||||
after_partial_cnode_inputs.push_back(then_fg_output);
|
|
||||||
} else {
|
} else {
|
||||||
|
auto then_fg_output = then_fg->output()->cast<CNodePtr>();
|
||||||
for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->inputs().size(); ++i) {
|
for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->inputs().size(); ++i) {
|
||||||
after_partial_cnode_inputs.push_back(then_fg_output->input(i));
|
after_partial_cnode_inputs.push_back(then_fg_output->input(i));
|
||||||
}
|
}
|
||||||
|
then_fg->DropNode(then_fg_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t if_output_size = after_partial_cnode_inputs.size() - 2;
|
||||||
|
|
||||||
// add after fg inputs to partial node
|
// add after fg inputs to partial node
|
||||||
std::copy(new_parameters.begin(), new_parameters.end(), std::back_inserter(after_partial_cnode_inputs));
|
std::copy(then_nodes_used_by_after_partial.begin(), then_nodes_used_by_after_partial.end(),
|
||||||
|
std::back_inserter(after_partial_cnode_inputs));
|
||||||
|
|
||||||
// insert partial node
|
// insert partial node
|
||||||
auto after_partial_cnode = then_fg->NewCNode(after_partial_cnode_inputs);
|
auto after_partial_cnode = then_fg->NewCNode(after_partial_cnode_inputs);
|
||||||
|
@ -467,61 +597,53 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg,
|
||||||
auto call_node = then_fg->NewCNode(call_node_inputs);
|
auto call_node = then_fg->NewCNode(call_node_inputs);
|
||||||
call_node->set_fullname_with_scope("call_" + after_partial_cnode->fullname_with_scope());
|
call_node->set_fullname_with_scope("call_" + after_partial_cnode->fullname_with_scope());
|
||||||
then_fg->set_output(call_node);
|
then_fg->set_output(call_node);
|
||||||
then_fg->DropNode(then_fg_output);
|
to_process_q.push_back(then_fg);
|
||||||
|
|
||||||
|
ReplaceNode(*after_fg, visited_nodes_and_after_partial_inputs_replace_pairs);
|
||||||
|
|
||||||
// check the inputs of after fg
|
// check the inputs of after fg
|
||||||
auto after_fg_inputs_size = (*after_fg)->get_inputs().size();
|
auto after_fg_inputs_size = (*after_fg)->get_inputs().size();
|
||||||
if (after_fg_inputs_size == after_partial_cnode_inputs.size() - 2) {
|
if (after_fg_inputs_size == after_partial_cnode_inputs.size() - kPartialFirstInputSize) {
|
||||||
MS_LOG(INFO) << "not need add after fg input parameters.";
|
MS_LOG(INFO) << "not need add after fg input parameters.";
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// make the inputs of the after fg
|
// make the inputs of the after fg
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_after_fg_replace_pairs{};
|
std::unordered_map<AnfNodePtr, AnfNodePtr> after_partial_after_fg_replace_pairs{};
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> if_cnode_after_fg_replace_pairs{};
|
|
||||||
for (size_t i = kPartialFirstInputSize; i < after_partial_cnode_inputs.size(); ++i) {
|
for (size_t i = kPartialFirstInputSize; i < after_partial_cnode_inputs.size(); ++i) {
|
||||||
auto &input = after_partial_cnode_inputs[i];
|
auto &input = after_partial_cnode_inputs[i];
|
||||||
auto new_parameter = (*after_fg)->add_parameter();
|
auto new_parameter = (*after_fg)->add_parameter();
|
||||||
new_parameter->set_name(input->fullname_with_scope() + "_after_fg_parameter");
|
new_parameter->set_name(std::to_string(i - kPartialFirstInputSize) + "_" + input->fullname_with_scope());
|
||||||
new_parameter->set_abstract(input->abstract());
|
new_parameter->set_abstract(input->abstract());
|
||||||
after_partial_after_fg_replace_pairs[input] = new_parameter;
|
if (i < kPartialFirstInputSize + if_output_size) {
|
||||||
if (i < kPartialFirstInputSize + (*if_cnode)->size() - kIfMinInputSize) {
|
|
||||||
after_partial_after_fg_replace_pairs[*if_cnode] = new_parameter;
|
after_partial_after_fg_replace_pairs[*if_cnode] = new_parameter;
|
||||||
|
} else {
|
||||||
|
after_partial_after_fg_replace_pairs[input] = new_parameter;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ReplaceNode(*after_fg, then_fg_inputs_and_fg_inputs_replace_pairs);
|
|
||||||
ReplaceNode(*after_fg, after_partial_after_fg_replace_pairs);
|
ReplaceNode(*after_fg, after_partial_after_fg_replace_pairs);
|
||||||
|
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ControlFlowPass::CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
|
int ControlFlowPass::CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
|
||||||
const std::vector<AnfNodePtr> &fg_inputs_only_used_by_after_partial,
|
std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
|
||||||
CNodePtr *if_cnode, FuncGraphPtr *after_fg, CNodePtr *else_partial_cnode) {
|
CNodePtr *if_cnode, FuncGraphPtr *after_fg, CNodePtr *else_partial_cnode) {
|
||||||
return CreateIfPartialNode(main_fg, fg_inputs_only_used_by_after_partial, kIfElseIndex, if_cnode, after_fg,
|
return CreateIfPartialNode(main_fg, kIfElseIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
|
||||||
else_partial_cnode);
|
else_partial_cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
int ControlFlowPass::CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
|
int ControlFlowPass::CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
|
||||||
const std::vector<AnfNodePtr> &fg_inputs_only_used_by_after_partial,
|
std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
|
||||||
CNodePtr *if_cnode, FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode) {
|
CNodePtr *if_cnode, FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode) {
|
||||||
return CreateIfPartialNode(main_fg, fg_inputs_only_used_by_after_partial, kIfThenIndex, if_cnode, after_fg,
|
return CreateIfPartialNode(main_fg, kIfThenIndex, visited_nodes_used_by_after_fg, if_cnode, after_fg,
|
||||||
then_partial_cnode);
|
then_partial_cnode);
|
||||||
}
|
}
|
||||||
|
|
||||||
int ControlFlowPass::ProcessIfOp(const FuncGraphPtr &fg) {
|
int ControlFlowPass::ProcessIfOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
|
||||||
if (fg == nullptr) {
|
const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &if_node) {
|
||||||
MS_LOG(ERROR) << "fg is nullptr.";
|
|
||||||
return RET_FAILED;
|
|
||||||
}
|
|
||||||
AnfNodePtr if_node = nullptr;
|
|
||||||
std::vector<AnfNodePtr> remain_nodes{};
|
|
||||||
int ret = SplitGraph(fg, prim::kPrimIf, &if_node, &remain_nodes);
|
|
||||||
if (ret != RET_SUCCESS) {
|
|
||||||
MS_LOG(ERROR) << "SplitGraph failed, ret: " << ret;
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (if_node == nullptr) {
|
if (if_node == nullptr) {
|
||||||
MS_LOG(INFO) << "not found if, not need to process.";
|
MS_LOG(INFO) << "not found if, no need to process.";
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -540,18 +662,18 @@ int ControlFlowPass::ProcessIfOp(const FuncGraphPtr &fg) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// get fg input which is not used by after_parts
|
// get fg input which is not used by after_parts
|
||||||
std::vector<AnfNodePtr> fg_inputs_only_used_by_after_partial{};
|
std::vector<AnfNodePtr> visited_nodes_used_by_after_fg{};
|
||||||
FunGraphInputsOnlyUsedByAfterParts(fg, if_cnode, &fg_inputs_only_used_by_after_partial);
|
VisitedNodesUsedByAfterParts(visited_nodes, remain_nodes, &visited_nodes_used_by_after_fg);
|
||||||
|
|
||||||
CNodePtr then_partial_cnode = nullptr;
|
CNodePtr then_partial_cnode = nullptr;
|
||||||
ret = CreateIfThenPartialNode(fg, fg_inputs_only_used_by_after_partial, &if_cnode, &after_fg, &then_partial_cnode);
|
int ret = CreateIfThenPartialNode(fg, &visited_nodes_used_by_after_fg, &if_cnode, &after_fg, &then_partial_cnode);
|
||||||
if (ret != RET_SUCCESS) {
|
if (ret != RET_SUCCESS) {
|
||||||
MS_LOG(ERROR) << "if create then partial cnode failed, ret: " << ret;
|
MS_LOG(ERROR) << "if create then partial cnode failed, ret: " << ret;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr else_partial_cnode = nullptr;
|
CNodePtr else_partial_cnode = nullptr;
|
||||||
ret = CreateIfElsePartialNode(fg, fg_inputs_only_used_by_after_partial, &if_cnode, &after_fg, &else_partial_cnode);
|
ret = CreateIfElsePartialNode(fg, &visited_nodes_used_by_after_fg, &if_cnode, &after_fg, &else_partial_cnode);
|
||||||
if (ret != RET_SUCCESS) {
|
if (ret != RET_SUCCESS) {
|
||||||
MS_LOG(ERROR) << "if create else partial cnode failed, ret: " << ret;
|
MS_LOG(ERROR) << "if create else partial cnode failed, ret: " << ret;
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -568,35 +690,70 @@ int ControlFlowPass::ProcessIfOp(const FuncGraphPtr &fg) {
|
||||||
std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, if_cnode->input(kIfCondIndex), then_partial_cnode,
|
std::vector<AnfNodePtr> switch_node_inputs = {switch_anf_primitive, if_cnode->input(kIfCondIndex), then_partial_cnode,
|
||||||
else_partial_cnode};
|
else_partial_cnode};
|
||||||
auto switch_cnode = fg->NewCNode(switch_node_inputs);
|
auto switch_cnode = fg->NewCNode(switch_node_inputs);
|
||||||
switch_cnode->set_fullname_with_scope("Switch-" + fg->get_attr("graph_name")->ToString());
|
switch_cnode->set_fullname_with_scope("if-Switch-" + fg->get_attr("graph_name")->ToString());
|
||||||
|
|
||||||
// insert call node
|
// insert call node
|
||||||
std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
|
std::vector<AnfNodePtr> call_node_inputs{switch_cnode};
|
||||||
auto call_node = fg->NewCNode(call_node_inputs);
|
auto call_node = fg->NewCNode(call_node_inputs);
|
||||||
call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
|
call_node->set_fullname_with_scope("call_" + switch_cnode->fullname_with_scope());
|
||||||
fg->DropNode(if_cnode);
|
fg->DropNode(if_cnode);
|
||||||
fg->set_output(call_node);
|
fg->set_output(call_node, true);
|
||||||
|
|
||||||
if (!Run(after_fg)) {
|
to_process_q.push_back(after_fg);
|
||||||
MS_LOG(ERROR) << "process control flow for after fg failed.";
|
|
||||||
return RET_FAILED;
|
|
||||||
}
|
|
||||||
|
|
||||||
return RET_SUCCESS;
|
return RET_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ControlFlowPass::Run(const FuncGraphPtr &fg) {
|
int ControlFlowPass::ProcessControlOp(const FuncGraphPtr &fg) {
|
||||||
int ret = ProcessWhileOp(fg);
|
if (fg == nullptr) {
|
||||||
if (ret != RET_SUCCESS) {
|
MS_LOG(ERROR) << "fg is nullptr.";
|
||||||
MS_LOG(ERROR) << "ProcessWhileOp failed.";
|
return RET_FAILED;
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
ret = ProcessIfOp(fg);
|
|
||||||
|
AnfNodePtr control_flow_node = nullptr;
|
||||||
|
std::vector<AnfNodePtr> remain_nodes{};
|
||||||
|
std::set<AnfNodePtr> visited_nodes{};
|
||||||
|
int ret = SplitGraph(fg, &control_flow_node, &visited_nodes, &remain_nodes);
|
||||||
if (ret != RET_SUCCESS) {
|
if (ret != RET_SUCCESS) {
|
||||||
MS_LOG(ERROR) << "ProcessIfOp failed.";
|
MS_LOG(ERROR) << "SplitGraph failed, ret: " << ret;
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
return ret;
|
||||||
return false;
|
}
|
||||||
|
|
||||||
|
if (control_flow_node == nullptr) {
|
||||||
|
MS_LOG(INFO) << "not found control flow op, no need to process.";
|
||||||
|
return RET_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (CheckPrimitiveType(control_flow_node, prim::kPrimWhile)) {
|
||||||
|
ret = ProcessWhileOp(fg, visited_nodes, remain_nodes, control_flow_node);
|
||||||
|
if (ret != RET_SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "ProcessWhileOp failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (CheckPrimitiveType(control_flow_node, prim::kPrimIf)) {
|
||||||
|
ret = ProcessIfOp(fg, visited_nodes, remain_nodes, control_flow_node);
|
||||||
|
if (ret != RET_SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "ProcessIfOp failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ControlFlowPass::Run(const FuncGraphPtr &fg) {
|
||||||
|
to_process_q.push_back(fg);
|
||||||
|
while (!to_process_q.empty()) {
|
||||||
|
auto cur_fg = to_process_q.front();
|
||||||
|
auto cur_fg_name = cur_fg->get_attr("graph_name")->ToString();
|
||||||
|
int ret = ProcessControlOp(cur_fg);
|
||||||
|
if (ret != RET_SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "ProcessControlOp for graph: " << cur_fg_name << " failed.";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
to_process_q.pop_front();
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,8 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <deque>
|
||||||
|
#include <set>
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
#include "backend/optimizer/common/pass.h"
|
#include "backend/optimizer/common/pass.h"
|
||||||
|
@ -35,37 +37,46 @@ class ControlFlowPass : public Pass {
|
||||||
static ValueNodePtr GetSwitchAnfPrim();
|
static ValueNodePtr GetSwitchAnfPrim();
|
||||||
static ValueNodePtr GetPartialAnfPrim();
|
static ValueNodePtr GetPartialAnfPrim();
|
||||||
void ReplaceNode(const FuncGraphPtr &fg, const std::unordered_map<AnfNodePtr, AnfNodePtr> &replace_pairs);
|
void ReplaceNode(const FuncGraphPtr &fg, const std::unordered_map<AnfNodePtr, AnfNodePtr> &replace_pairs);
|
||||||
void FunGraphInputsOnlyUsedByAfterParts(const FuncGraphPtr &fg, const CNodePtr &aim_cnode,
|
void VisitedNodesUsedByAfterParts(const std::set<AnfNodePtr> &visited_nodes,
|
||||||
std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_fg);
|
const std::vector<AnfNodePtr> &remain_nodes,
|
||||||
int SplitGraph(const FuncGraphPtr &fg, const PrimitivePtr &aim_prim, AnfNodePtr *aim_prim_type_node,
|
std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg);
|
||||||
|
int SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow_node, std::set<AnfNodePtr> *visited_nodes,
|
||||||
std::vector<AnfNodePtr> *remain_nodes);
|
std::vector<AnfNodePtr> *remain_nodes);
|
||||||
|
size_t GetItemVisitedNums(const std::set<AnfNodePtr> &visited_nodes, const AnfNodePtr &tuple_node);
|
||||||
|
void MoveGetItemToVisited(const size_t &need_size, const AnfNodePtr &tuple_node, std::set<AnfNodePtr> *visited_nodes,
|
||||||
|
std::vector<AnfNodePtr> *remain_nodes);
|
||||||
|
void BindGetItemNodes(std::set<AnfNodePtr> *visited_nodes, std::vector<AnfNodePtr> *remain_nodes);
|
||||||
int CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
|
int CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector<AnfNodePtr> &remain_nodes,
|
||||||
const CNodePtr &aim_cnode, FuncGraphPtr *after_fg);
|
const CNodePtr &aim_cnode, FuncGraphPtr *after_fg);
|
||||||
|
|
||||||
// process while
|
// process while
|
||||||
int CreateWhileCondCallNode(
|
int CreateWhileCondCallNode(
|
||||||
const FuncGraphPtr &fg, const CNodePtr &while_cnode, std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_fg,
|
const FuncGraphPtr &fg, const CNodePtr &while_cnode, std::vector<AnfNodePtr> *visited_nodes_used_by_after_fg,
|
||||||
CNodePtr *cond_partial_cnode,
|
CNodePtr *cond_partial_cnode,
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *fg_inputs_and_after_partial_inputs_replace_pairs);
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *visited_nodes_and_cond_fg_inputs_replace_pairs);
|
||||||
int CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode, CNodePtr *body_partial_node);
|
int CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode, CNodePtr *body_partial_node);
|
||||||
int CreateWhileAfterPartialNode(
|
int CreateWhileAfterPartialNode(
|
||||||
const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
|
const FuncGraphPtr &main_fg, const FuncGraphPtr &cond_fg, const std::vector<AnfNodePtr> &remain_nodes,
|
||||||
const std::vector<AnfNodePtr> &fg_inputs_only_used_by_after_fg,
|
const std::vector<AnfNodePtr> &visited_nodes_used_by_after_fg,
|
||||||
const std::unordered_map<AnfNodePtr, AnfNodePtr> &fg_inputs_and_after_partial_inputs_replace_pairs,
|
const std::unordered_map<AnfNodePtr, AnfNodePtr> &visited_nodes_and_cond_fg_inputs_replace_pairs,
|
||||||
CNodePtr *while_cnode, CNodePtr *after_partial_cnode);
|
CNodePtr *while_cnode, CNodePtr *after_partial_cnode);
|
||||||
int ProcessWhileOp(const FuncGraphPtr &fg);
|
int ProcessWhileOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
|
||||||
|
const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &while_node);
|
||||||
|
|
||||||
// process if
|
// process if
|
||||||
int CreateIfPartialNode(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &fg_inputs_only_used_by_after_partial,
|
int CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index,
|
||||||
const size_t &index, CNodePtr *if_cnode, FuncGraphPtr *after_fg,
|
std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, CNodePtr *if_cnode,
|
||||||
CNodePtr *then_partial_cnode);
|
FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode);
|
||||||
int CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
|
int CreateIfThenPartialNode(const FuncGraphPtr &main_fg,
|
||||||
const std::vector<AnfNodePtr> &fg_inputs_only_used_by_after_partial, CNodePtr *if_cnode,
|
std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, CNodePtr *if_cnode,
|
||||||
FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode);
|
FuncGraphPtr *after_fg, CNodePtr *then_partial_cnode);
|
||||||
int CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
|
int CreateIfElsePartialNode(const FuncGraphPtr &main_fg,
|
||||||
const std::vector<AnfNodePtr> &fg_inputs_only_used_by_after_partial, CNodePtr *if_cnode,
|
std::vector<AnfNodePtr> *fg_inputs_only_used_by_after_partial, CNodePtr *if_cnode,
|
||||||
FuncGraphPtr *after_fg, CNodePtr *else_partial_cnode);
|
FuncGraphPtr *after_fg, CNodePtr *else_partial_cnode);
|
||||||
int ProcessIfOp(const FuncGraphPtr &fg);
|
int ProcessIfOp(const FuncGraphPtr &fg, const std::set<AnfNodePtr> &visited_nodes,
|
||||||
|
const std::vector<AnfNodePtr> &remain_nodes, const AnfNodePtr &if_node);
|
||||||
|
|
||||||
|
int ProcessControlOp(const FuncGraphPtr &fg);
|
||||||
|
|
||||||
const size_t kCNodePrimIndex = 0;
|
const size_t kCNodePrimIndex = 0;
|
||||||
const size_t kCNodeFirstInputIndex = 1;
|
const size_t kCNodeFirstInputIndex = 1;
|
||||||
|
@ -82,6 +93,8 @@ class ControlFlowPass : public Pass {
|
||||||
const size_t kIfThenIndex = 1;
|
const size_t kIfThenIndex = 1;
|
||||||
const size_t kIfElseIndex = 2;
|
const size_t kIfElseIndex = 2;
|
||||||
const size_t kIfCondIndex = 3;
|
const size_t kIfCondIndex = 3;
|
||||||
|
|
||||||
|
std::deque<FuncGraphPtr> to_process_q{};
|
||||||
};
|
};
|
||||||
} // namespace mindspore::opt
|
} // namespace mindspore::opt
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -127,6 +127,13 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (CheckPrimitiveType(anf_node, prim::kPrimTranspose)) {
|
||||||
|
if (cnode->size() != kInputTripleNum) {
|
||||||
|
MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
|
||||||
|
remove_cnode_.insert(anf_node);
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool replace_succ = manager->Replace(anf_node, cnode->input(1));
|
bool replace_succ = manager->Replace(anf_node, cnode->input(1));
|
||||||
if (!replace_succ) {
|
if (!replace_succ) {
|
||||||
|
@ -287,6 +294,23 @@ int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int RemoveRedundantOpPass::RemoveInvalidTransposeOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
|
||||||
|
auto cnode = anf_node->cast<CNodePtr>();
|
||||||
|
if (cnode->size() != kInputTripleNum) {
|
||||||
|
MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
|
||||||
|
return lite::RET_NO_CHANGE;
|
||||||
|
}
|
||||||
|
auto index_node = cnode->inputs()[2]->cast<ParameterPtr>();
|
||||||
|
if (index_node == nullptr) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(index_node->default_param());
|
||||||
|
if (tensor_info->Size() != 0) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
return ReplaceOp(anf_node, manager);
|
||||||
|
}
|
||||||
|
|
||||||
bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
|
bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||||
MS_ASSERT(func_graph != nullptr);
|
MS_ASSERT(func_graph != nullptr);
|
||||||
auto manager = func_graph->manager();
|
auto manager = func_graph->manager();
|
||||||
|
@ -315,6 +339,9 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
|
||||||
if (CheckPrimitiveType(node, prim::kPrimPadFusion)) {
|
if (CheckPrimitiveType(node, prim::kPrimPadFusion)) {
|
||||||
status = RemoveInvalidPadOp(node, manager);
|
status = RemoveInvalidPadOp(node, manager);
|
||||||
}
|
}
|
||||||
|
if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
|
||||||
|
status = RemoveInvalidTransposeOp(node, manager);
|
||||||
|
}
|
||||||
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
|
||||||
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
|
||||||
if (sub_func_graph == nullptr) {
|
if (sub_func_graph == nullptr) {
|
||||||
|
|
|
@ -33,6 +33,7 @@ class RemoveRedundantOpPass : public Pass {
|
||||||
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||||
int RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
int RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||||
int RemoveInvalidPadOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
int RemoveInvalidPadOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||||
|
int RemoveInvalidTransposeOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager);
|
||||||
bool Run(const FuncGraphPtr &graph) override;
|
bool Run(const FuncGraphPtr &graph) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
Loading…
Reference in New Issue