forked from mindspore-Ecosystem/mindspore
Never inline middle after block for control flow.
This commit is contained in:
parent
0c7ba7a7fa
commit
25a2b8cd5b
|
@ -134,12 +134,24 @@ class InlinerBase : public AnfVisitor {
|
|||
|
||||
std::vector<AnfNodePtr> args;
|
||||
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
|
||||
// compare size to avoid the case that the function has default value after grad.
|
||||
// Compare size to avoid the case that the function has default value after grad.
|
||||
// for which after renormalize, the function default value will be an input
|
||||
if (fg->parameters().size() != args.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (IsUniqueUse(nullptr, fg, nullptr)) {
|
||||
// The other branch calling the last after block.
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
|
||||
// Check if parameters' changed.
|
||||
auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
|
||||
if (param_simplified_caller != nullptr) {
|
||||
return param_simplified_caller;
|
||||
}
|
||||
}
|
||||
|
||||
// For the single used fg, including non-after and after not matched above,
|
||||
// we move the whole fg nodes.
|
||||
if (use_move_) {
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
|
@ -148,10 +160,20 @@ class InlinerBase : public AnfVisitor {
|
|||
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
|
||||
return out_node;
|
||||
}
|
||||
} else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) {
|
||||
// Not to inline after block if it has switch call inside, to avoid switch expansion.
|
||||
return TransformBranchCall(fg, node, args);
|
||||
} else {
|
||||
// We don't expand the middle multiple used after block, except the last one.
|
||||
if (GraphHasBranch(fg)) {
|
||||
return nullptr;
|
||||
}
|
||||
// Check if parameters' changed for the first met branch calling.
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
|
||||
auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
|
||||
if (param_simplified_caller != nullptr) {
|
||||
return param_simplified_caller;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Or, just make a clone for not single used fg.
|
||||
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
|
||||
}
|
||||
|
||||
|
@ -183,37 +205,34 @@ class InlinerBase : public AnfVisitor {
|
|||
|
||||
// For after block which contains branch call, delete the parameters which is not used.
|
||||
// In most cases, it may be a `Module` or other constant input.
|
||||
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
|
||||
AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node,
|
||||
const std::vector<AnfNodePtr> &args) {
|
||||
auto &fg_params = fg->parameters();
|
||||
std::vector<int64_t> used_param_index;
|
||||
auto mng = fg->manager();
|
||||
bool should_simplify = false;
|
||||
for (size_t i = 0; i < fg_params.size(); i++) {
|
||||
if (mng->node_users()[fg_params[i]].size() != 0) {
|
||||
used_param_index.emplace_back(i);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Not used parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString();
|
||||
should_simplify = true;
|
||||
}
|
||||
}
|
||||
// If all parameters are used by cnodes
|
||||
if (used_param_index.size() == fg_params.size()) {
|
||||
if (!should_simplify) {
|
||||
return nullptr;
|
||||
}
|
||||
if (transformed_branch_chache_.find(fg) == transformed_branch_chache_.end()) {
|
||||
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
|
||||
// clone a new graph and ignore the not used parameters
|
||||
FuncGraphPtr new_fg = TransformableClone(fg);
|
||||
auto &new_fg_params = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_params;
|
||||
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
|
||||
[&new_fg_params](size_t i) { return new_fg_params[i]; });
|
||||
new_fg->set_parameters(new_params);
|
||||
// New func graph must set FUNC_GRAPH_FLAG_AFTER_BLOCK flag otherwise the new graph will be inlined.
|
||||
new_fg->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
|
||||
// Add new graph to the cache to improve perfomance when call HasBranchCall.
|
||||
graph_branch_cache_[new_fg] = true;
|
||||
// If a graph be called at two or more locations, it should not be cloned once again, so add it to the cache.
|
||||
transformed_branch_chache_[fg] = new_fg;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
|
||||
// Clone a new graph and ignore the not used parameters
|
||||
auto new_fg = TransformableClone(fg);
|
||||
auto &new_fg_params = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_params;
|
||||
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
|
||||
[&new_fg_params](size_t i) { return new_fg_params[i]; });
|
||||
new_fg->set_parameters(new_params);
|
||||
|
||||
std::vector<AnfNodePtr> node_inputs;
|
||||
node_inputs.push_back(NewValueNode(transformed_branch_chache_[fg]));
|
||||
node_inputs.push_back(NewValueNode(new_fg));
|
||||
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
|
||||
[&args](size_t i) { return args[i]; });
|
||||
return node->func_graph()->NewCNode(node_inputs);
|
||||
|
@ -273,8 +292,6 @@ class InlinerBase : public AnfVisitor {
|
|||
bool use_move_;
|
||||
std::vector<std::vector<CriterionFuncType>> criterions_;
|
||||
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
|
||||
// Key is the old func graph, and the value is the new func_graph
|
||||
std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_;
|
||||
};
|
||||
|
||||
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
|
||||
|
|
Loading…
Reference in New Issue