remove redundant phi

This commit is contained in:
Wei Luning 2020-07-15 13:04:34 +08:00
parent bbcefa731d
commit acbccea644
14 changed files with 136 additions and 48 deletions

View File

@ -52,9 +52,10 @@ std::string GetNodeRepr(AnfNodePtr node) {
void ResolveFuncGraph_(const FuncGraphPtr &fg) { void ResolveFuncGraph_(const FuncGraphPtr &fg) {
auto manager = Manage(fg, false); auto manager = Manage(fg, false);
auto use_sig = parse::python_adapter::UseSignatureInResolve();
parse::python_adapter::set_use_signature_in_resolve(false); parse::python_adapter::set_use_signature_in_resolve(false);
parse::ResolveAll(manager); parse::ResolveAll(manager);
parse::python_adapter::set_use_signature_in_resolve(true); parse::python_adapter::set_use_signature_in_resolve(use_sig);
} }
bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {

View File

@ -145,6 +145,12 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
std::string var = phi_nodes_[phi]; std::string var = phi_nodes_[phi];
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
auto removable = CollectRemovablePhi(phi);
// If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
if (removable) {
MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var;
return;
}
for (auto &pred : prev_blocks_) { for (auto &pred : prev_blocks_) {
MS_EXCEPTION_IF_NULL(pred); MS_EXCEPTION_IF_NULL(pred);
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
@ -152,16 +158,6 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
CNodePtr jump = pred->jumps_[this]; CNodePtr jump = pred->jumps_[this];
jump->add_input(arg_node); jump->add_input(arg_node);
} }
// If the phi node in the body part of a for/while loop is being removed,
// then the closure convert phase will generate a cycle in graph if the
// loop is kept after specialization. This should be investigate further.
// Just now user has to set a flag on a function to indicate the for loop
// will definitely can be unroll as the sequence in for statement is fixed
// size in compile time.
if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) ||
parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) {
CollectRemovablePhi(phi);
}
} }
AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
@ -207,13 +203,13 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
// 2. it's costly to iterate the graph to replace the phi for each phi. // 2. it's costly to iterate the graph to replace the phi for each phi.
// Args : // Args :
// phi : This parameter node is functioning as a phi node. // phi : This parameter node is functioning as a phi node.
void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
MS_EXCEPTION_IF_NULL(phi); MS_EXCEPTION_IF_NULL(phi);
std::string var = phi_nodes_[phi]; std::string var = phi_nodes_[phi];
MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
if (prev_blocks_.size() == 0) { if (prev_blocks_.size() == 0) {
MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
return; return false;
} }
AnfNodePtr arg_node = SearchReplaceNode(var, phi); AnfNodePtr arg_node = SearchReplaceNode(var, phi);
if (arg_node != nullptr) { if (arg_node != nullptr) {
@ -235,13 +231,16 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
const auto &param = phi_iter.second->cast<ParameterPtr>(); const auto &param = phi_iter.second->cast<ParameterPtr>();
if (param == phi) { if (param == phi) {
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString()
<< " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString()
<< " in graph " << arg_node->func_graph()->ToString();
prev->removable_phis_[phi_iter.first] = arg_node; prev->removable_phis_[phi_iter.first] = arg_node;
} }
} }
} }
} }
return true;
} }
return false;
} }
// A block should be marked matured if its predecessor blocks have been processed // A block should be marked matured if its predecessor blocks have been processed

View File

@ -52,7 +52,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
AnfNodePtr ReadVariable(const std::string &var_name); AnfNodePtr ReadVariable(const std::string &var_name);
void AddPrevBlock(const FunctionBlockPtr &block); void AddPrevBlock(const FunctionBlockPtr &block);
void SetPhiArgument(const ParameterPtr &phi); void SetPhiArgument(const ParameterPtr &phi);
void CollectRemovablePhi(const ParameterPtr &phi); bool CollectRemovablePhi(const ParameterPtr &phi);
// A block is matured if all its predecessors is generated // A block is matured if all its predecessors is generated
void Mature(); void Mature();
CNodePtr ForceToBoolNode(const AnfNodePtr &cond); CNodePtr ForceToBoolNode(const AnfNodePtr &cond);

View File

@ -1436,6 +1436,15 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje
return block; return block;
} }
AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
const auto &inp = node->cast<ParameterPtr>();
const auto &iter = removable_phis.find(inp);
if (iter == removable_phis.end()) {
return node;
}
return FindPhis(removable_phis, iter->second);
}
void Parser::RemoveUnnecessaryPhis() { void Parser::RemoveUnnecessaryPhis() {
// merge all removable phis to one map; // merge all removable phis to one map;
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis; std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
@ -1443,28 +1452,39 @@ void Parser::RemoveUnnecessaryPhis() {
MS_EXCEPTION_IF_NULL(block); MS_EXCEPTION_IF_NULL(block);
removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
} }
if (removable_phis.size() == 0) { if (removable_phis.size() == 0) {
return; return;
} }
for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) {
if (node->isa<CNode>()) { auto fg_name = func_graph_->ToString();
const auto &cnode = node->cast<CNodePtr>(); auto mng = Manage(func_graph_, false);
auto &inputs = cnode->inputs(); // replace the nodes
for (std::size_t i = 0; i < inputs.size(); i++) { for (auto iter : removable_phis) {
if (inputs[i]->isa<Parameter>()) { auto new_node = FindPhis(removable_phis, iter.first);
const auto &inp = inputs[i]->cast<ParameterPtr>(); MS_LOG(DEBUG) << "phi " << iter.first->DebugString() << " to " << new_node->DebugString();
const auto &iter = removable_phis.find(inp); mng->Replace(iter.first, new_node);
if (iter == removable_phis.end()) { }
continue; // remove the parameter
} for (FunctionBlockPtr &block : func_block_list_) {
auto &argNode = iter->second; MS_EXCEPTION_IF_NULL(block);
MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in " auto &local_removable_phis = block->removable_phis();
<< cnode->DebugString() << " with " << argNode->DebugString(); if (local_removable_phis.size() == 0) {
cnode->set_input(i, argNode); continue;
}
}
} }
auto func_graph = block->func_graph();
auto &parameters = func_graph->parameters();
std::vector<AnfNodePtr> new_parameters(parameters.size());
auto it = std::copy_if(
parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) {
return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
});
// shrink container to new size
new_parameters.resize(std::distance(new_parameters.begin(), it));
func_graph->set_parameters(new_parameters);
}
for (auto fg : mng->func_graphs()) {
fg->ClearAllManagerInfo();
} }
} }

View File

@ -111,6 +111,27 @@ std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) {
return sorted_nodes; return sorted_nodes;
} }
std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) {
std::deque<FuncGraphPtr> todo;
todo.push_back(root);
std::vector<FuncGraphPtr> sorted;
auto seen = NewSeenGeneration();
while (!todo.empty()) {
FuncGraphPtr top = todo.front();
todo.pop_front();
sorted.push_back(top);
auto used = top->func_graphs_used();
for (auto &item : used) {
if (item.first->seen_ == seen) {
continue;
}
todo.push_back(item.first);
item.first->seen_ = seen;
}
}
return sorted;
}
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
std::vector<AnfNodePtr> vecs; std::vector<AnfNodePtr> vecs;
if (node == nullptr) { if (node == nullptr) {

View File

@ -70,6 +70,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ =
const IncludeFunc &include = AlwaysInclude); const IncludeFunc &include = AlwaysInclude);
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret); std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret);
std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root);
class FuncGraphIndex { class FuncGraphIndex {
public: public:
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,

View File

@ -77,6 +77,17 @@ std::string CNode::DebugString(int recursive_level) const {
return buffer.str(); return buffer.str();
} }
std::string Parameter::DebugString(int recursive_level) const {
std::ostringstream buffer;
if (recursive_level > 0) {
if (func_graph() != nullptr) {
buffer << func_graph()->ToString() << ":";
}
}
buffer << ToString();
return buffer.str();
}
std::string ValueNode::ToString() const { std::string ValueNode::ToString() const {
MS_EXCEPTION_IF_NULL(value_); MS_EXCEPTION_IF_NULL(value_);
if (value_->isa<FuncGraph>()) { if (value_->isa<FuncGraph>()) {

View File

@ -249,7 +249,7 @@ class Parameter : public ANode {
MS_DECLARE_PARENT(Parameter, ANode); MS_DECLARE_PARENT(Parameter, ANode);
void accept(AnfVisitor *v) override; void accept(AnfVisitor *v) override;
std::string DebugString(int recursive_level = 1) const override;
std::string name() const { return name_; } std::string name() const { return name_; }
void set_name(const std::string &name) { name_ = name; } void set_name(const std::string &name) { name_ = name; }
std::string fullname_with_scope() override { return name(); }; std::string fullname_with_scope() override { return name(); };

View File

@ -417,6 +417,15 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
return mng->recursive_graphs(shared_from_base<FuncGraph>()); return mng->recursive_graphs(shared_from_base<FuncGraph>());
} }
void FuncGraph::ClearAllManagerInfo() {
ClearNodes();
ClearValueNodes();
ClearFuncGraphCNodesIndex();
ClearFreeVariables();
ClearFuncGraphsUsed();
ClearJFuncGraphs();
}
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
auto itr = this->parameter_default_value_.find(name); auto itr = this->parameter_default_value_.find(name);
if (itr == parameter_default_value_.end()) { if (itr == parameter_default_value_.end()) {

View File

@ -229,7 +229,8 @@ class FuncGraph : public FuncGraphBase {
} }
this->debug_info_ = info; this->debug_info_ = info;
} }
// clear all info from manager
void ClearAllManagerInfo();
// get all nodes belonging to this func graph // get all nodes belonging to this func graph
const AnfNodeSet &nodes(); const AnfNodeSet &nodes();
void CopyNodes(const FuncGraphPtr &source); void CopyNodes(const FuncGraphPtr &source);

View File

@ -25,6 +25,7 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/profile.h" #include "utils/profile.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "utils/graph_utils.h"
// namespace to support intermediate representation definition // namespace to support intermediate representation definition
namespace mindspore { namespace mindspore {
@ -400,11 +401,16 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
} }
void Cloner::Lift() { void Cloner::Lift() {
for (auto &func_graph_params : repl_func_graph_params_) { // lift inner graph first
auto &func_graph = func_graph_params.first; auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin()));
auto &params = func_graph_params.second; for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
for (auto &cnode : func_graph->func_graph_cnodes_index()) { auto func_graph = *r_iter;
LiftParameters(cnode.first->first->func_graph(), func_graph, params); auto iter = repl_func_graph_params_.find(func_graph);
if (iter != repl_func_graph_params_.end()) {
auto &params = iter->second;
for (auto &cnode : func_graph->func_graph_cnodes_index()) {
LiftParameters(cnode.first->first->func_graph(), func_graph, params);
}
} }
} }
} }

View File

@ -520,12 +520,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
target->CopyFuncGraphsUsed(source); target->CopyFuncGraphsUsed(source);
target->CopyJFuncGraphs(source); target->CopyJFuncGraphs(source);
signals_->InvalidateComputer(); signals_->InvalidateComputer();
source->ClearNodes(); source->ClearAllManagerInfo();
source->ClearValueNodes();
source->ClearFuncGraphCNodesIndex();
source->ClearFreeVariables();
source->ClearFuncGraphsUsed();
source->ClearJFuncGraphs();
} }
FuncGraphTransaction FuncGraphManager::Transact() { FuncGraphTransaction FuncGraphManager::Transact() {

View File

@ -72,6 +72,7 @@ class PyFuncGraphFetcher {
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
if (doResolve_) { if (doResolve_) {
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false); std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
mindspore::parse::ResolveAll(manager); mindspore::parse::ResolveAll(manager);
} }
return func_graph; return func_graph;

View File

@ -131,3 +131,26 @@ def test_while_in_while():
output = while_in_while(c1, c2, c3) output = while_in_while(c1, c2, c3)
expect = Tensor([1274], mstype.int32) expect = Tensor([1274], mstype.int32)
assert output == expect assert output == expect
@ms_function
def while_by_while_in_while(x, y, z):
out = c4
while x < c2:
y = c4 + c4
while y < c2:
y = y + 1
out = out + y
z = c4 + c4
while z < c2:
z = z + 1
out = out + z
x = x + 1
out = out + x
return out
def test_while_by_while_in_while():
output = while_by_while_in_while(c1, c2, c3)
expect = Tensor([350], mstype.int32)
assert output == expect