forked from mindspore-Ecosystem/mindspore
remove redundant phi
This commit is contained in:
parent
bbcefa731d
commit
acbccea644
|
@ -52,9 +52,10 @@ std::string GetNodeRepr(AnfNodePtr node) {
|
||||||
|
|
||||||
void ResolveFuncGraph_(const FuncGraphPtr &fg) {
|
void ResolveFuncGraph_(const FuncGraphPtr &fg) {
|
||||||
auto manager = Manage(fg, false);
|
auto manager = Manage(fg, false);
|
||||||
|
auto use_sig = parse::python_adapter::UseSignatureInResolve();
|
||||||
parse::python_adapter::set_use_signature_in_resolve(false);
|
parse::python_adapter::set_use_signature_in_resolve(false);
|
||||||
parse::ResolveAll(manager);
|
parse::ResolveAll(manager);
|
||||||
parse::python_adapter::set_use_signature_in_resolve(true);
|
parse::python_adapter::set_use_signature_in_resolve(use_sig);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {
|
bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {
|
||||||
|
|
|
@ -145,6 +145,12 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb
|
||||||
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
|
void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
|
||||||
std::string var = phi_nodes_[phi];
|
std::string var = phi_nodes_[phi];
|
||||||
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
|
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
|
||||||
|
auto removable = CollectRemovablePhi(phi);
|
||||||
|
// If the phi node is not necessary, not need to add to jumps_ of the prev blocks.
|
||||||
|
if (removable) {
|
||||||
|
MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var;
|
||||||
|
return;
|
||||||
|
}
|
||||||
for (auto &pred : prev_blocks_) {
|
for (auto &pred : prev_blocks_) {
|
||||||
MS_EXCEPTION_IF_NULL(pred);
|
MS_EXCEPTION_IF_NULL(pred);
|
||||||
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
|
MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
|
||||||
|
@ -152,16 +158,6 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
|
||||||
CNodePtr jump = pred->jumps_[this];
|
CNodePtr jump = pred->jumps_[this];
|
||||||
jump->add_input(arg_node);
|
jump->add_input(arg_node);
|
||||||
}
|
}
|
||||||
// If the phi node in the body part of a for/while loop is being removed,
|
|
||||||
// then the closure convert phase will generate a cycle in graph if the
|
|
||||||
// loop is kept after specialization. This should be investigate further.
|
|
||||||
// Just now user has to set a flag on a function to indicate the for loop
|
|
||||||
// will definitely can be unroll as the sequence in for statement is fixed
|
|
||||||
// size in compile time.
|
|
||||||
if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) ||
|
|
||||||
parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) {
|
|
||||||
CollectRemovablePhi(phi);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
|
AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
|
||||||
|
@ -207,13 +203,13 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
|
||||||
// 2. it's costly to iterate the graph to replace the phi for each phi.
|
// 2. it's costly to iterate the graph to replace the phi for each phi.
|
||||||
// Args :
|
// Args :
|
||||||
// phi : This parameter node is functioning as a phi node.
|
// phi : This parameter node is functioning as a phi node.
|
||||||
void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
||||||
MS_EXCEPTION_IF_NULL(phi);
|
MS_EXCEPTION_IF_NULL(phi);
|
||||||
std::string var = phi_nodes_[phi];
|
std::string var = phi_nodes_[phi];
|
||||||
MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString();
|
MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
|
||||||
if (prev_blocks_.size() == 0) {
|
if (prev_blocks_.size() == 0) {
|
||||||
MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString();
|
MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
|
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
|
||||||
if (arg_node != nullptr) {
|
if (arg_node != nullptr) {
|
||||||
|
@ -235,13 +231,16 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
||||||
const auto ¶m = phi_iter.second->cast<ParameterPtr>();
|
const auto ¶m = phi_iter.second->cast<ParameterPtr>();
|
||||||
if (param == phi) {
|
if (param == phi) {
|
||||||
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString()
|
MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString()
|
||||||
<< " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString();
|
<< " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString()
|
||||||
|
<< " in graph " << arg_node->func_graph()->ToString();
|
||||||
prev->removable_phis_[phi_iter.first] = arg_node;
|
prev->removable_phis_[phi_iter.first] = arg_node;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// A block should be marked matured if its predecessor blocks have been processed
|
// A block should be marked matured if its predecessor blocks have been processed
|
||||||
|
|
|
@ -52,7 +52,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
||||||
AnfNodePtr ReadVariable(const std::string &var_name);
|
AnfNodePtr ReadVariable(const std::string &var_name);
|
||||||
void AddPrevBlock(const FunctionBlockPtr &block);
|
void AddPrevBlock(const FunctionBlockPtr &block);
|
||||||
void SetPhiArgument(const ParameterPtr &phi);
|
void SetPhiArgument(const ParameterPtr &phi);
|
||||||
void CollectRemovablePhi(const ParameterPtr &phi);
|
bool CollectRemovablePhi(const ParameterPtr &phi);
|
||||||
// A block is matured if all its predecessors is generated
|
// A block is matured if all its predecessors is generated
|
||||||
void Mature();
|
void Mature();
|
||||||
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
|
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
|
||||||
|
|
|
@ -1436,6 +1436,15 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje
|
||||||
return block;
|
return block;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
|
||||||
|
const auto &inp = node->cast<ParameterPtr>();
|
||||||
|
const auto &iter = removable_phis.find(inp);
|
||||||
|
if (iter == removable_phis.end()) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
return FindPhis(removable_phis, iter->second);
|
||||||
|
}
|
||||||
|
|
||||||
void Parser::RemoveUnnecessaryPhis() {
|
void Parser::RemoveUnnecessaryPhis() {
|
||||||
// merge all removable phis to one map;
|
// merge all removable phis to one map;
|
||||||
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
|
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
|
||||||
|
@ -1443,28 +1452,39 @@ void Parser::RemoveUnnecessaryPhis() {
|
||||||
MS_EXCEPTION_IF_NULL(block);
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
|
removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (removable_phis.size() == 0) {
|
if (removable_phis.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) {
|
|
||||||
if (node->isa<CNode>()) {
|
auto fg_name = func_graph_->ToString();
|
||||||
const auto &cnode = node->cast<CNodePtr>();
|
auto mng = Manage(func_graph_, false);
|
||||||
auto &inputs = cnode->inputs();
|
// replace the nodes
|
||||||
for (std::size_t i = 0; i < inputs.size(); i++) {
|
for (auto iter : removable_phis) {
|
||||||
if (inputs[i]->isa<Parameter>()) {
|
auto new_node = FindPhis(removable_phis, iter.first);
|
||||||
const auto &inp = inputs[i]->cast<ParameterPtr>();
|
MS_LOG(DEBUG) << "phi " << iter.first->DebugString() << " to " << new_node->DebugString();
|
||||||
const auto &iter = removable_phis.find(inp);
|
mng->Replace(iter.first, new_node);
|
||||||
if (iter == removable_phis.end()) {
|
}
|
||||||
continue;
|
// remove the parameter
|
||||||
}
|
for (FunctionBlockPtr &block : func_block_list_) {
|
||||||
auto &argNode = iter->second;
|
MS_EXCEPTION_IF_NULL(block);
|
||||||
MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in "
|
auto &local_removable_phis = block->removable_phis();
|
||||||
<< cnode->DebugString() << " with " << argNode->DebugString();
|
if (local_removable_phis.size() == 0) {
|
||||||
cnode->set_input(i, argNode);
|
continue;
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
auto func_graph = block->func_graph();
|
||||||
|
auto ¶meters = func_graph->parameters();
|
||||||
|
std::vector<AnfNodePtr> new_parameters(parameters.size());
|
||||||
|
auto it = std::copy_if(
|
||||||
|
parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) {
|
||||||
|
return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
|
||||||
|
});
|
||||||
|
|
||||||
|
// shrink container to new size
|
||||||
|
new_parameters.resize(std::distance(new_parameters.begin(), it));
|
||||||
|
func_graph->set_parameters(new_parameters);
|
||||||
|
}
|
||||||
|
for (auto fg : mng->func_graphs()) {
|
||||||
|
fg->ClearAllManagerInfo();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -111,6 +111,27 @@ std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) {
|
||||||
return sorted_nodes;
|
return sorted_nodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) {
|
||||||
|
std::deque<FuncGraphPtr> todo;
|
||||||
|
todo.push_back(root);
|
||||||
|
std::vector<FuncGraphPtr> sorted;
|
||||||
|
auto seen = NewSeenGeneration();
|
||||||
|
while (!todo.empty()) {
|
||||||
|
FuncGraphPtr top = todo.front();
|
||||||
|
todo.pop_front();
|
||||||
|
sorted.push_back(top);
|
||||||
|
auto used = top->func_graphs_used();
|
||||||
|
for (auto &item : used) {
|
||||||
|
if (item.first->seen_ == seen) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
todo.push_back(item.first);
|
||||||
|
item.first->seen_ = seen;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sorted;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
|
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
|
||||||
std::vector<AnfNodePtr> vecs;
|
std::vector<AnfNodePtr> vecs;
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
|
|
|
@ -70,6 +70,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ =
|
||||||
const IncludeFunc &include = AlwaysInclude);
|
const IncludeFunc &include = AlwaysInclude);
|
||||||
|
|
||||||
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret);
|
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret);
|
||||||
|
std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root);
|
||||||
class FuncGraphIndex {
|
class FuncGraphIndex {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
|
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
|
||||||
|
|
|
@ -77,6 +77,17 @@ std::string CNode::DebugString(int recursive_level) const {
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string Parameter::DebugString(int recursive_level) const {
|
||||||
|
std::ostringstream buffer;
|
||||||
|
if (recursive_level > 0) {
|
||||||
|
if (func_graph() != nullptr) {
|
||||||
|
buffer << func_graph()->ToString() << ":";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buffer << ToString();
|
||||||
|
return buffer.str();
|
||||||
|
}
|
||||||
|
|
||||||
std::string ValueNode::ToString() const {
|
std::string ValueNode::ToString() const {
|
||||||
MS_EXCEPTION_IF_NULL(value_);
|
MS_EXCEPTION_IF_NULL(value_);
|
||||||
if (value_->isa<FuncGraph>()) {
|
if (value_->isa<FuncGraph>()) {
|
||||||
|
|
|
@ -249,7 +249,7 @@ class Parameter : public ANode {
|
||||||
MS_DECLARE_PARENT(Parameter, ANode);
|
MS_DECLARE_PARENT(Parameter, ANode);
|
||||||
|
|
||||||
void accept(AnfVisitor *v) override;
|
void accept(AnfVisitor *v) override;
|
||||||
|
std::string DebugString(int recursive_level = 1) const override;
|
||||||
std::string name() const { return name_; }
|
std::string name() const { return name_; }
|
||||||
void set_name(const std::string &name) { name_ = name; }
|
void set_name(const std::string &name) { name_ = name; }
|
||||||
std::string fullname_with_scope() override { return name(); };
|
std::string fullname_with_scope() override { return name(); };
|
||||||
|
|
|
@ -417,6 +417,15 @@ std::shared_ptr<std::list<FuncGraphPtr>> FuncGraph::recursive_graphs() {
|
||||||
return mng->recursive_graphs(shared_from_base<FuncGraph>());
|
return mng->recursive_graphs(shared_from_base<FuncGraph>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FuncGraph::ClearAllManagerInfo() {
|
||||||
|
ClearNodes();
|
||||||
|
ClearValueNodes();
|
||||||
|
ClearFuncGraphCNodesIndex();
|
||||||
|
ClearFreeVariables();
|
||||||
|
ClearFuncGraphsUsed();
|
||||||
|
ClearJFuncGraphs();
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
|
AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) {
|
||||||
auto itr = this->parameter_default_value_.find(name);
|
auto itr = this->parameter_default_value_.find(name);
|
||||||
if (itr == parameter_default_value_.end()) {
|
if (itr == parameter_default_value_.end()) {
|
||||||
|
|
|
@ -229,7 +229,8 @@ class FuncGraph : public FuncGraphBase {
|
||||||
}
|
}
|
||||||
this->debug_info_ = info;
|
this->debug_info_ = info;
|
||||||
}
|
}
|
||||||
|
// clear all info from manager
|
||||||
|
void ClearAllManagerInfo();
|
||||||
// get all nodes belonging to this func graph
|
// get all nodes belonging to this func graph
|
||||||
const AnfNodeSet &nodes();
|
const AnfNodeSet &nodes();
|
||||||
void CopyNodes(const FuncGraphPtr &source);
|
void CopyNodes(const FuncGraphPtr &source);
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "utils/profile.h"
|
#include "utils/profile.h"
|
||||||
#include "utils/context/ms_context.h"
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "utils/graph_utils.h"
|
||||||
|
|
||||||
// namespace to support intermediate representation definition
|
// namespace to support intermediate representation definition
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -400,11 +401,16 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cloner::Lift() {
|
void Cloner::Lift() {
|
||||||
for (auto &func_graph_params : repl_func_graph_params_) {
|
// lift inner graph first
|
||||||
auto &func_graph = func_graph_params.first;
|
auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin()));
|
||||||
auto ¶ms = func_graph_params.second;
|
for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
|
||||||
for (auto &cnode : func_graph->func_graph_cnodes_index()) {
|
auto func_graph = *r_iter;
|
||||||
LiftParameters(cnode.first->first->func_graph(), func_graph, params);
|
auto iter = repl_func_graph_params_.find(func_graph);
|
||||||
|
if (iter != repl_func_graph_params_.end()) {
|
||||||
|
auto ¶ms = iter->second;
|
||||||
|
for (auto &cnode : func_graph->func_graph_cnodes_index()) {
|
||||||
|
LiftParameters(cnode.first->first->func_graph(), func_graph, params);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -520,12 +520,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
|
||||||
target->CopyFuncGraphsUsed(source);
|
target->CopyFuncGraphsUsed(source);
|
||||||
target->CopyJFuncGraphs(source);
|
target->CopyJFuncGraphs(source);
|
||||||
signals_->InvalidateComputer();
|
signals_->InvalidateComputer();
|
||||||
source->ClearNodes();
|
source->ClearAllManagerInfo();
|
||||||
source->ClearValueNodes();
|
|
||||||
source->ClearFuncGraphCNodesIndex();
|
|
||||||
source->ClearFreeVariables();
|
|
||||||
source->ClearFuncGraphsUsed();
|
|
||||||
source->ClearJFuncGraphs();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphTransaction FuncGraphManager::Transact() {
|
FuncGraphTransaction FuncGraphManager::Transact() {
|
||||||
|
|
|
@ -72,6 +72,7 @@ class PyFuncGraphFetcher {
|
||||||
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
|
mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
|
||||||
if (doResolve_) {
|
if (doResolve_) {
|
||||||
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
|
std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
|
||||||
|
mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
|
||||||
mindspore::parse::ResolveAll(manager);
|
mindspore::parse::ResolveAll(manager);
|
||||||
}
|
}
|
||||||
return func_graph;
|
return func_graph;
|
||||||
|
|
|
@ -131,3 +131,26 @@ def test_while_in_while():
|
||||||
output = while_in_while(c1, c2, c3)
|
output = while_in_while(c1, c2, c3)
|
||||||
expect = Tensor([1274], mstype.int32)
|
expect = Tensor([1274], mstype.int32)
|
||||||
assert output == expect
|
assert output == expect
|
||||||
|
|
||||||
|
|
||||||
|
@ms_function
|
||||||
|
def while_by_while_in_while(x, y, z):
|
||||||
|
out = c4
|
||||||
|
while x < c2:
|
||||||
|
y = c4 + c4
|
||||||
|
while y < c2:
|
||||||
|
y = y + 1
|
||||||
|
out = out + y
|
||||||
|
z = c4 + c4
|
||||||
|
while z < c2:
|
||||||
|
z = z + 1
|
||||||
|
out = out + z
|
||||||
|
x = x + 1
|
||||||
|
out = out + x
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def test_while_by_while_in_while():
|
||||||
|
output = while_by_while_in_while(c1, c2, c3)
|
||||||
|
expect = Tensor([350], mstype.int32)
|
||||||
|
assert output == expect
|
||||||
|
|
Loading…
Reference in New Issue