forked from mindspore-Ecosystem/mindspore
remove redundant phi
This commit is contained in:
parent
bbcefa731d
commit
acbccea644
|
@ -52,9 +52,10 @@ std::string GetNodeRepr(AnfNodePtr node) {
|
|||
|
||||
void ResolveFuncGraph_(const FuncGraphPtr &fg) {
|
||||
auto manager = Manage(fg, false);
|
||||
auto use_sig = parse::python_adapter::UseSignatureInResolve();
|
||||
parse::python_adapter::set_use_signature_in_resolve(false);
|
||||
parse::ResolveAll(manager);
|
||||
parse::python_adapter::set_use_signature_in_resolve(true);
|
||||
parse::python_adapter::set_use_signature_in_resolve(use_sig);
|
||||
}
|
||||
|
||||
bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {
|
||||
|
|
|
@ -145,6 +145,12 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb
|
|||
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
|
||||
std::string var = phi_nodes_[phi];
|
||||
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
|
||||
auto removable = CollectRemovablePhi(phi);
|
||||
// If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
|
||||
if (removable) {
|
||||
MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var;
|
||||
return;
|
||||
}
|
||||
for (auto &pred : prev_blocks_) {
|
||||
MS_EXCEPTION_IF_NULL(pred);
|
||||
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
|
||||
|
@ -152,16 +158,6 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
|
|||
CNodePtr jump = pred->jumps_[this];
|
||||
jump->add_input(arg_node);
|
||||
}
|
||||
// If the phi node in the body part of a for/while loop is being removed,
|
||||
// then the closure convert phase will generate a cycle in graph if the
|
||||
// loop is kept after specialization. This should be investigate further.
|
||||
// Just now user has to set a flag on a function to indicate the for loop
|
||||
// will definitely can be unroll as the sequence in for statement is fixed
|
||||
// size in compile time.
|
||||
if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) ||
|
||||
parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
||||
CollectRemovablePhi(phi);
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
|
||||
|
@ -207,13 +203,13 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
|
|||
// 2. it's costly to iterate the graph to replace the phi for each phi.
|
||||
// Args :
|
||||
// phi : This parameter node is functioning as a phi node.
|
||||
void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
||||
bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
||||
MS_EXCEPTION_IF_NULL(phi);
|
||||
std::string var = phi_nodes_[phi];
|
||||
MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString();
|
||||
MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
|
||||
if (prev_blocks_.size() == 0) {
|
||||
MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString();
|
||||
return;
|
||||
MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
|
||||
return false;
|
||||
}
|
||||
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
|
||||
if (arg_node != nullptr) {
|
||||
|
@ -235,13 +231,16 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
|||
const auto ¶m = phi_iter.second->cast<ParameterPtr>();
|
||||
if (param == phi) {
|
||||
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString()
|
||||
<< " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString();
|
||||
<< " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString()
|
||||
<< " in graph " << arg_node->func_graph()->ToString();
|
||||
prev->removable_phis_[phi_iter.first] = arg_node;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// A block should be marked matured if its predecessor blocks have been processed
|
||||
|
|
|
@ -52,7 +52,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
AnfNodePtr ReadVariable(const std::string &var_name);
|
||||
void AddPrevBlock(const FunctionBlockPtr &block);
|
||||
void SetPhiArgument(const ParameterPtr &phi);
|
||||
void CollectRemovablePhi(const ParameterPtr &phi);
|
||||
bool CollectRemovablePhi(const ParameterPtr &phi);
|
||||
// A block is matured if all its predecessors is generated
|
||||
void Mature();
|
||||
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
|
||||
|
|
|
@ -1436,6 +1436,15 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje
|
|||
return block;
|
||||
}
|
||||
|
||||
AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
|
||||
const auto &inp = node->cast<ParameterPtr>();
|
||||
const auto &iter = removable_phis.find(inp);
|
||||
if (iter == removable_phis.end()) {
|
||||
return node;
|
||||
}
|
||||
return FindPhis(removable_phis, iter->second);
|
||||
}
|
||||
|
||||
void Parser::RemoveUnnecessaryPhis() {
|
||||
// merge all removable phis to one map;
|
||||
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
|
||||
|
@ -1443,28 +1452,39 @@ void Parser::RemoveUnnecessaryPhis() {
|
|||
MS_EXCEPTION_IF_NULL(block);
|
||||
removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
|
||||
}
|
||||
|
||||
if (removable_phis.size() == 0) {
|
||||
return;
|
||||
}
|
||||
for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) {
|
||||
if (node->isa<CNode>()) {
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
for (std::size_t i = 0; i < inputs.size(); i++) {
|
||||
if (inputs[i]->isa<Parameter>()) {
|
||||
const auto &inp = inputs[i]->cast<ParameterPtr>();
|
||||
const auto &iter = removable_phis.find(inp);
|
||||
if (iter == removable_phis.end()) {
|
||||
continue;
|
||||
}
|
||||
auto &argNode = iter->second;
|
||||
MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in "
|
||||
<< cnode->DebugString() << " with " << argNode->DebugString();
|
||||
cnode->set_input(i, argNode);
|
||||
}
|
||||
}
|
||||
|
||||
auto fg_name = func_graph_->ToString();
|
||||
auto mng = Manage(func_graph_, false);
|
||||
// replace the nodes
|
||||
for (auto iter : removable_phis) {
|
||||
auto new_node = FindPhis(removable_phis, iter.first);
|
||||
MS_LOG(DEBUG) << "phi " << iter.first->DebugString() << " to " << new_node->DebugString();
|
||||
mng->Replace(iter.first, new_node);
|
||||
}
|
||||
// remove the parameter
|
||||
for (FunctionBlockPtr &block : func_block_list_) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
auto &local_removable_phis = block->removable_phis();
|
||||
if (local_removable_phis.size() == 0) {
|
||||
continue;
|
||||
}
|
||||
auto func_graph = block->func_graph();
|
||||
auto ¶meters = func_graph->parameters();
|
||||
std::vector<AnfNodePtr> new_parameters(parameters.size());
|
||||
auto it = std::copy_if(
|
||||
parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) {
|
||||
return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
|
||||
});
|
||||
|
||||
// shrink container to new size
|
||||
new_parameters.resize(std::distance(new_parameters.begin(), it));
|
||||
func_graph->set_parameters(new_parameters);
|
||||
}
|
||||
for (auto fg : mng->func_graphs()) {
|
||||
fg->ClearAllManagerInfo();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -111,6 +111,27 @@ std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) {
|
|||
return sorted_nodes;
|
||||
}
|
||||
|
||||
std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) {
|
||||
std::deque<FuncGraphPtr> todo;
|
||||
todo.push_back(root);
|
||||
std::vector<FuncGraphPtr> sorted;
|
||||
auto seen = NewSeenGeneration();
|
||||
while (!todo.empty()) {
|
||||
FuncGraphPtr top = todo.front();
|
||||
todo.pop_front();
|
||||
sorted.push_back(top);
|
||||
auto used = top->func_graphs_used();
|
||||
for (auto &item : used) {
|
||||
if (item.first->seen_ == seen) {
|
||||
continue;
|
||||
}
|
||||
todo.push_back(item.first);
|
||||
item.first->seen_ = seen;
|
||||
}
|
||||
}
|
||||
return sorted;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
|
||||
std::vector<AnfNodePtr> vecs;
|
||||
if (node == nullptr) {
|
||||
|
|
|
@ -70,6 +70,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ =
|
|||
const IncludeFunc &include = AlwaysInclude);
|
||||
|
||||
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret);
|
||||
std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root);
|
||||
class FuncGraphIndex {
|
||||
public:
|
||||
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
|
||||
|
|
|
@ -77,6 +77,17 @@ std::string CNode::DebugString(int recursive_level) const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string Parameter::DebugString(int recursive_level) const {
|
||||
std::ostringstream buffer;
|
||||
if (recursive_level > 0) {
|
||||
if (func_graph() != nullptr) {
|
||||
buffer << func_graph()->ToString() << ":";
|
||||
}
|
||||
}
|
||||
buffer << ToString();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string ValueNode::ToString() const {
|
||||
MS_EXCEPTION_IF_NULL(value_);
|
||||
if (value_->isa<FuncGraph>()) {
|
||||
|
|
|
@ -249,7 +249,7 @@ class Parameter : public ANode {
|
|||
MS_DECLARE_PARENT(Parameter, ANode);
|
||||
|
||||
void accept(AnfVisitor *v) override;
|
||||
|
||||
std::string DebugString(int recursive_level = 1) const override;
|
||||
std::string name() const { return name_; }
|
||||
void set_name(const std::string &name) { name_ = name; }
|
||||
std::string fullname_with_scope() override { return name(); };
|
||||
|
|
|
@ -417,6 +417,15 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
|
|||
return mng->recursive_graphs(shared_from_base<FuncGraph>());
|
||||
}
|
||||
|
||||
void FuncGraph::ClearAllManagerInfo() {
|
||||
ClearNodes();
|
||||
ClearValueNodes();
|
||||
ClearFuncGraphCNodesIndex();
|
||||
ClearFreeVariables();
|
||||
ClearFuncGraphsUsed();
|
||||
ClearJFuncGraphs();
|
||||
}
|
||||
|
||||
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
|
||||
auto itr = this->parameter_default_value_.find(name);
|
||||
if (itr == parameter_default_value_.end()) {
|
||||
|
|
|
@ -229,7 +229,8 @@ class FuncGraph : public FuncGraphBase {
|
|||
}
|
||||
this->debug_info_ = info;
|
||||
}
|
||||
|
||||
// clear all info from manager
|
||||
void ClearAllManagerInfo();
|
||||
// get all nodes belonging to this func graph
|
||||
const AnfNodeSet &nodes();
|
||||
void CopyNodes(const FuncGraphPtr &source);
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "utils/log_adapter.h"
|
||||
#include "utils/profile.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "utils/graph_utils.h"
|
||||
|
||||
// namespace to support intermediate representation definition
|
||||
namespace mindspore {
|
||||
|
@ -400,11 +401,16 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
|
|||
}
|
||||
|
||||
void Cloner::Lift() {
|
||||
for (auto &func_graph_params : repl_func_graph_params_) {
|
||||
auto &func_graph = func_graph_params.first;
|
||||
auto ¶ms = func_graph_params.second;
|
||||
for (auto &cnode : func_graph->func_graph_cnodes_index()) {
|
||||
LiftParameters(cnode.first->first->func_graph(), func_graph, params);
|
||||
// lift inner graph first
|
||||
auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin()));
|
||||
for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
|
||||
auto func_graph = *r_iter;
|
||||
auto iter = repl_func_graph_params_.find(func_graph);
|
||||
if (iter != repl_func_graph_params_.end()) {
|
||||
auto ¶ms = iter->second;
|
||||
for (auto &cnode : func_graph->func_graph_cnodes_index()) {
|
||||
LiftParameters(cnode.first->first->func_graph(), func_graph, params);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -520,12 +520,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
|
|||
target->CopyFuncGraphsUsed(source);
|
||||
target->CopyJFuncGraphs(source);
|
||||
signals_->InvalidateComputer();
|
||||
source->ClearNodes();
|
||||
source->ClearValueNodes();
|
||||
source->ClearFuncGraphCNodesIndex();
|
||||
source->ClearFreeVariables();
|
||||
source->ClearFuncGraphsUsed();
|
||||
source->ClearJFuncGraphs();
|
||||
source->ClearAllManagerInfo();
|
||||
}
|
||||
|
||||
FuncGraphTransaction FuncGraphManager::Transact() {
|
||||
|
|
|
@ -72,6 +72,7 @@ class PyFuncGraphFetcher {
|
|||
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
|
||||
if (doResolve_) {
|
||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
|
||||
mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
|
||||
mindspore::parse::ResolveAll(manager);
|
||||
}
|
||||
return func_graph;
|
||||
|
|
|
@ -131,3 +131,26 @@ def test_while_in_while():
|
|||
output = while_in_while(c1, c2, c3)
|
||||
expect = Tensor([1274], mstype.int32)
|
||||
assert output == expect
|
||||
|
||||
|
||||
@ms_function
|
||||
def while_by_while_in_while(x, y, z):
|
||||
out = c4
|
||||
while x < c2:
|
||||
y = c4 + c4
|
||||
while y < c2:
|
||||
y = y + 1
|
||||
out = out + y
|
||||
z = c4 + c4
|
||||
while z < c2:
|
||||
z = z + 1
|
||||
out = out + z
|
||||
x = x + 1
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
|
||||
def test_while_by_while_in_while():
|
||||
output = while_by_while_in_while(c1, c2, c3)
|
||||
expect = Tensor([350], mstype.int32)
|
||||
assert output == expect
|
||||
|
|
Loading…
Reference in New Issue