!23669 Fix inline pass problem in switch.

Merge pull request !23669 from LiangZhibo/inline
This commit is contained in:
i-robot 2021-09-23 11:52:14 +00:00 committed by Gitee
commit d9e6edfc9f
6 changed files with 46 additions and 15 deletions

View File

@ -30,6 +30,8 @@
#include "ir/func_graph_cloner.h"
#include "ir/tensor.h"
#include "frontend/operator/ops.h"
#include "abstract/abstract_value.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
@ -170,6 +172,7 @@ class InlinerBase : public AnfVisitor {
}
}
// Or, just make a clone for not single used fg.
MS_LOG(INFO) << "Run InlineClone in inline pass, subgraph number may increase.";
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
@ -279,6 +282,30 @@ class InlinerBase : public AnfVisitor {
return node->func_graph()->NewCNode(node_inputs);
}
bool CheckSwitchBranchAbstract(const AbstractBasePtr &branch_abstract) {
if (branch_abstract != nullptr && branch_abstract->isa<abstract::AbstractError>()) {
auto branch_abstract_value = branch_abstract->GetValueTrack();
MS_EXCEPTION_IF_NULL(branch_abstract_value);
auto branch_abstract_value_string_imm = branch_abstract_value->cast<StringImmPtr>();
if (branch_abstract_value_string_imm != nullptr) {
auto branch_abstract_value_string_imm_value = branch_abstract_value_string_imm->value();
return branch_abstract_value_string_imm_value == kDeadNodeName ||
branch_abstract_value_string_imm_value == kPolyNodeName;
}
}
return false;
}
// Inspect the inputs of a switch CNode and report whether they satisfy the branch check.
// Returns true when either branch abstract is flagged by CheckSwitchBranchAbstract (dead node or
// poly node, in which case inline should not be performed), or when the condition input
// (sw_inputs[1]) is not a ValueNode or is a ValueNode holding a tensor::Tensor.
// The caller is expected to have validated the size of sw_inputs beforehand; no bounds check is
// done here.
bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
  const auto &true_abs = sw_inputs[kSwitchTrueKernelGraphIndex]->abstract();
  const auto &false_abs = sw_inputs[kSwitchFalseKernelGraphIndex]->abstract();
  // A dead or poly node in either branch is enough to report true.
  const bool branch_is_dead_or_poly = CheckSwitchBranchAbstract(true_abs) || CheckSwitchBranchAbstract(false_abs);
  if (branch_is_dead_or_poly) {
    return true;
  }
  const auto &cond_node = sw_inputs[1];
  if (!cond_node->isa<ValueNode>()) {
    return true;
  }
  return IsValueNode<tensor::Tensor>(cond_node);
}
// This is a best-effort algorithm to find a graph which may generate a branch call.
// It does not handle higher-order function calls, so a graph with a higher-order call branch may still be inlined.
bool GraphHasBranch(FuncGraphPtr fg) {
@ -293,7 +320,7 @@ class InlinerBase : public AnfVisitor {
if (sw_inputs.size() != 4) {
MS_LOG(EXCEPTION) << "switch inputs should be 4";
}
if (!sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1])) {
if (CheckSwitchInputs(sw_inputs)) {
has_branch = true;
break;
}

View File

@ -41,18 +41,18 @@ class ParameterEliminator {
MS_EXCEPTION_IF_NULL(manager);
bool changes = false;
while (true) {
auto tr = manager->Transact();
const auto &[fg, callers] = SearchFuncGraphCallers(func_graph);
if (fg == nullptr) {
break;
}
const auto &erase_indexes = EraseUnusedParameters(fg, &tr);
auto manager = fg->manager();
MS_EXCEPTION_IF_NULL(manager);
const auto &erase_indexes = EraseUnusedParameters(fg, manager);
for (auto caller : callers) {
// Erase the corresponding args.
EraseArgs(caller, erase_indexes, &tr);
EraseArgs(caller, erase_indexes, manager);
}
changes = true;
tr.Commit();
}
return changes;
}
@ -99,7 +99,7 @@ class ParameterEliminator {
return {nullptr, {}};
}
static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, FuncGraphTransaction *tr) {
static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(fg->manager());
const auto &manager_node_users = fg->manager()->node_users();
const auto &parameters = fg->parameters();
@ -122,12 +122,12 @@ class ParameterEliminator {
MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
}
}
tr->SetParameters(fg, new_parameters);
manager->SetParameters(fg, new_parameters);
return unused_parameter_indexes;
}
static void EraseArgs(const CNodePtr &caller, const std::unordered_set<size_t> &unused_parameter_indexes,
FuncGraphTransaction *tr) {
const FuncGraphManagerPtr &manager) {
std::vector<AnfNodePtr> new_args = {caller->inputs()[0]};
for (size_t i = 0; i < caller->inputs().size() - 1; i++) {
if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
@ -139,7 +139,7 @@ class ParameterEliminator {
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
auto new_caller = caller->func_graph()->NewCNode(new_args);
new_caller->set_abstract(caller->abstract());
tr->Replace(caller, new_caller);
manager->Replace(caller, new_caller);
}
};
} // namespace irpass

View File

@ -334,7 +334,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
// Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
OptPassGroupMap map_a({{"a_1", a_1},
OptPassGroupMap map_a({{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
{"a_1", a_1},
{"updatestate_depend_eliminate", updatestate_depend_eliminate},
{"updatestate_assign_eliminate", updatestate_assign_eliminate},
{"updatestate_loads_eliminate", updatestate_loads_eliminate},
@ -360,7 +361,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Build the pass-group map from GetOptPassesA(irpass) and return only its first a1_a2_len
// entries, copied into a new OptPassGroupMap.
// a1_a2_len is a hard-coded prefix length, so it depends on the order of the groups produced by
// GetOptPassesA(). NOTE(review): confirm which groups the prefix is meant to cover whenever that
// order changes.
OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
auto opt_a = GetOptPassesA(irpass);
constexpr auto a1_a2_len = 6;
constexpr auto a1_a2_len = 7;
// Copy the leading a1_a2_len groups only.
OptPassGroupMap a1_a2(opt_a.begin(), opt_a.begin() + a1_a2_len);
return a1_a2;
}

View File

@ -24,6 +24,7 @@
#include "frontend/operator/composite/do_signature.h"
#include "abstract/abstract_function.h"
#include "abstract/utils.h"
#include "utils/utils.h"
#include "ir/graph_utils.h"
#include "utils/log_adapter.h"
#include "debug/trace.h"
@ -396,8 +397,8 @@ AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf)
}
namespace {
const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");
const StringImmPtr kDeadNode = std::make_shared<StringImm>(kDeadNodeName);
const StringImmPtr kPolyNode = std::make_shared<StringImm>(kPolyNodeName);
inline bool CanSpecializeNode(const AnfNodePtr &node) {
if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {

View File

@ -213,6 +213,8 @@ constexpr auto kFusedCastAdamWeightDecayName = "FusedCastAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
constexpr auto kDeadNodeName = "DeadNode";
constexpr auto kPolyNodeName = "PolyNode";
constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";

View File

@ -133,8 +133,8 @@ def control_flow_if_after_if_in_for(input_net, x, expect1, expect2):
assert graph_backward_res == expect2
@pytest.mark.skip(reason="ME EvalCNode error")
@pytest.mark.level1
# @pytest.mark.skip(reason="ME EvalCNode error")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training