Never inline middle after block for control flow.

This commit is contained in:
Zhang Qinghua 2020-12-04 10:19:17 +08:00
parent 0c7ba7a7fa
commit 25a2b8cd5b
1 changed files with 43 additions and 26 deletions

View File

@ -134,12 +134,24 @@ class InlinerBase : public AnfVisitor {
std::vector<AnfNodePtr> args;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
// compare size to avoid the case that the function has default value after grad.
// Compare size to avoid the case that the function has default value after grad.
// for which after renormalize, the function default value will be an input
if (fg->parameters().size() != args.size()) {
return nullptr;
}
if (IsUniqueUse(nullptr, fg, nullptr)) {
// The other branch calling the last after block.
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
// Check if parameters' changed.
auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
if (param_simplified_caller != nullptr) {
return param_simplified_caller;
}
}
// For the single used fg, including non-after and after not matched above,
// we move the whole fg nodes.
if (use_move_) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
@ -148,10 +160,20 @@ class InlinerBase : public AnfVisitor {
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
} else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) {
// Not to inline after block if it has switch call inside, to avoid switch expansion.
return TransformBranchCall(fg, node, args);
} else {
// We don't expand the middle multiple used after block, except the last one.
if (GraphHasBranch(fg)) {
return nullptr;
}
// Check if parameters' changed for the first met branch calling.
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
if (param_simplified_caller != nullptr) {
return param_simplified_caller;
}
}
}
// Or, just make a clone for not single used fg.
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
@ -183,37 +205,34 @@ class InlinerBase : public AnfVisitor {
// For after block which contains branch call, delete the parameters which is not used.
// In most cases, it may be a `Module` or other constant input.
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node,
const std::vector<AnfNodePtr> &args) {
auto &fg_params = fg->parameters();
std::vector<int64_t> used_param_index;
auto mng = fg->manager();
bool should_simplify = false;
for (size_t i = 0; i < fg_params.size(); i++) {
if (mng->node_users()[fg_params[i]].size() != 0) {
used_param_index.emplace_back(i);
} else {
MS_LOG(DEBUG) << "Not used parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString();
should_simplify = true;
}
}
// If all parameters are used by cnodes
if (used_param_index.size() == fg_params.size()) {
if (!should_simplify) {
return nullptr;
}
if (transformed_branch_chache_.find(fg) == transformed_branch_chache_.end()) {
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
// clone a new graph and ignore the not used parameters
FuncGraphPtr new_fg = TransformableClone(fg);
auto &new_fg_params = new_fg->parameters();
std::vector<AnfNodePtr> new_params;
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
[&new_fg_params](size_t i) { return new_fg_params[i]; });
new_fg->set_parameters(new_params);
// New func graph must set FUNC_GRAPH_FLAG_AFTER_BLOCK flag otherwise the new graph will be inlined.
new_fg->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
// Add new graph to the cache to improve perfomance when call HasBranchCall.
graph_branch_cache_[new_fg] = true;
// If a graph be called at two or more locations, it should not be cloned once again, so add it to the cache.
transformed_branch_chache_[fg] = new_fg;
}
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
// Clone a new graph and ignore the not used parameters
auto new_fg = TransformableClone(fg);
auto &new_fg_params = new_fg->parameters();
std::vector<AnfNodePtr> new_params;
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
[&new_fg_params](size_t i) { return new_fg_params[i]; });
new_fg->set_parameters(new_params);
std::vector<AnfNodePtr> node_inputs;
node_inputs.push_back(NewValueNode(transformed_branch_chache_[fg]));
node_inputs.push_back(NewValueNode(new_fg));
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
[&args](size_t i) { return args[i]; });
return node->func_graph()->NewCNode(node_inputs);
@ -273,8 +292,6 @@ class InlinerBase : public AnfVisitor {
bool use_move_;
std::vector<std::vector<CriterionFuncType>> criterions_;
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
// Key is the old func graph, and the value is the new func_graph
std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_;
};
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {