!23669 Fix inline pass problem in switch.
Merge pull request !23669 from LiangZhibo/inline
commit d9e6edfc9f
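What the change does, as read from the hunks below:

- The inliner's GraphHasBranch walk now treats a switch branch whose abstract is an AbstractError tagged "Dead Node" or "Poly Node" as a branch that must not be inlined (new CheckSwitchBranchAbstract / CheckSwitchInputs helpers).
- ParameterEliminator applies its edits directly through the FuncGraphManager instead of batching them in a FuncGraphTransaction.
- A standalone switch_simplify pass group is prepended to map_a, and a1_a2_len grows from 6 to 7 so the GetA1A2 prefix stays aligned.
- The "Dead Node" / "Poly Node" string literals are hoisted into shared kDeadNodeName / kPolyNodeName constants so the inline pass and the specializer agree on them.
- A control-flow test previously skipped with "ME EvalCNode error" is re-enabled.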
@@ -30,6 +30,8 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/tensor.h"
 #include "frontend/operator/ops.h"
+#include "abstract/abstract_value.h"
+#include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
@@ -170,6 +172,7 @@ class InlinerBase : public AnfVisitor {
       }
     }
     // Or, just make a clone for the fg that is not singly used.
+    MS_LOG(INFO) << "Run InlineClone in inline pass, subgraph number may increase.";
     return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
   }
 
@@ -279,6 +282,30 @@ class InlinerBase : public AnfVisitor {
     return node->func_graph()->NewCNode(node_inputs);
   }
 
+  bool CheckSwitchBranchAbstract(const AbstractBasePtr &branch_abstract) {
+    if (branch_abstract != nullptr && branch_abstract->isa<abstract::AbstractError>()) {
+      auto branch_abstract_value = branch_abstract->GetValueTrack();
+      MS_EXCEPTION_IF_NULL(branch_abstract_value);
+      auto branch_abstract_value_string_imm = branch_abstract_value->cast<StringImmPtr>();
+      if (branch_abstract_value_string_imm != nullptr) {
+        auto branch_abstract_value_string_imm_value = branch_abstract_value_string_imm->value();
+        return branch_abstract_value_string_imm_value == kDeadNodeName ||
+               branch_abstract_value_string_imm_value == kPolyNodeName;
+      }
+    }
+    return false;
+  }
+
+  bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
+    auto true_branch_abstract = sw_inputs[kSwitchTrueKernelGraphIndex]->abstract();
+    auto false_branch_abstract = sw_inputs[kSwitchFalseKernelGraphIndex]->abstract();
+    // When a branch has a dead node or a poly node, do not perform inlining.
+    if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) {
+      return true;
+    }
+    return !sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1]);
+  }
+
   // This is a best-effort algorithm to find a graph which may generate a branch call.
   // It does not handle high-order function calls; a branch reached through a high-order call may still be inlined.
   bool GraphHasBranch(FuncGraphPtr fg) {
@@ -293,7 +320,7 @@ class InlinerBase : public AnfVisitor {
       if (sw_inputs.size() != 4) {
         MS_LOG(EXCEPTION) << "switch inputs should be 4";
       }
-      if (!sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1])) {
+      if (CheckSwitchInputs(sw_inputs)) {
         has_branch = true;
         break;
       }
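The mechanism behind the new guard, as far as the hunks show: during abstract specialization, a branch that is never actually evaluated is left with an AbstractError whose value track is a StringImm marker, "Dead Node" (or "Poly Node" for an unresolved polymorphic call). Inlining a switch that still carries such a branch would pull the unreachable marker node into the caller, so CheckSwitchInputs reports the switch as a real branch and GraphHasBranch keeps the graph un-inlined. A self-contained sketch of the guard pattern, using stand-in types rather than MindSpore's real AbstractBasePtr/StringImmPtr API:

```cpp
// Stand-in sketch of the new guard: a branch whose abstract is an error
// tagged "Dead Node"/"Poly Node" must keep the switch from being inlined.
// Abstract/ErrorAbstract are simplified assumptions, not MindSpore classes.
#include <iostream>
#include <memory>
#include <string>

constexpr auto kDeadNodeName = "Dead Node";
constexpr auto kPolyNodeName = "Poly Node";

struct Abstract {
  virtual ~Abstract() = default;
};

// Plays the role of AbstractError: its value track is a string marker.
struct ErrorAbstract : Abstract {
  explicit ErrorAbstract(std::string m) : marker(std::move(m)) {}
  std::string marker;
};

// Mirrors CheckSwitchBranchAbstract: null or ordinary abstracts pass; only
// Dead Node / Poly Node error markers block inlining.
bool IsDeadOrPolyBranch(const std::shared_ptr<Abstract> &abs) {
  if (auto err = std::dynamic_pointer_cast<ErrorAbstract>(abs)) {
    return err->marker == kDeadNodeName || err->marker == kPolyNodeName;
  }
  return false;
}

int main() {
  auto dead = std::make_shared<ErrorAbstract>(kDeadNodeName);
  auto live = std::make_shared<Abstract>();
  std::cout << std::boolalpha << IsDeadOrPolyBranch(dead) << ' '
            << IsDeadOrPolyBranch(live) << '\n';  // true false
}
```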
@@ -41,18 +41,18 @@ class ParameterEliminator {
     MS_EXCEPTION_IF_NULL(manager);
     bool changes = false;
     while (true) {
-      auto tr = manager->Transact();
       const auto &[fg, callers] = SearchFuncGraphCallers(func_graph);
       if (fg == nullptr) {
         break;
       }
-      const auto &erase_indexes = EraseUnusedParameters(fg, &tr);
+      auto manager = fg->manager();
+      MS_EXCEPTION_IF_NULL(manager);
+      const auto &erase_indexes = EraseUnusedParameters(fg, manager);
       for (auto caller : callers) {
         // Erase the corresponding args.
-        EraseArgs(caller, erase_indexes, &tr);
+        EraseArgs(caller, erase_indexes, manager);
       }
       changes = true;
-      tr.Commit();
     }
     return changes;
   }
@@ -99,7 +99,7 @@ class ParameterEliminator {
     return {nullptr, {}};
   }
 
-  static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, FuncGraphTransaction *tr) {
+  static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) {
     MS_EXCEPTION_IF_NULL(fg->manager());
     const auto &manager_node_users = fg->manager()->node_users();
     const auto &parameters = fg->parameters();
@@ -122,12 +122,12 @@ class ParameterEliminator {
         MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
       }
     }
-    tr->SetParameters(fg, new_parameters);
+    manager->SetParameters(fg, new_parameters);
     return unused_parameter_indexes;
   }
 
   static void EraseArgs(const CNodePtr &caller, const std::unordered_set<size_t> &unused_parameter_indexes,
-                        FuncGraphTransaction *tr) {
+                        const FuncGraphManagerPtr &manager) {
     std::vector<AnfNodePtr> new_args = {caller->inputs()[0]};
     for (size_t i = 0; i < caller->inputs().size() - 1; i++) {
       if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
@@ -139,7 +139,7 @@ class ParameterEliminator {
     TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
     auto new_caller = caller->func_graph()->NewCNode(new_args);
     new_caller->set_abstract(caller->abstract());
-    tr->Replace(caller, new_caller);
+    manager->Replace(caller, new_caller);
   }
 };
 }  // namespace irpass
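The ParameterEliminator hunks swap the mutation API: edits previously queued on a FuncGraphTransaction and applied at tr.Commit() now go through the owning graph's FuncGraphManager one call at a time, so each SetParameters/Replace takes effect immediately. A toy sketch of that behavioral difference, with illustrative stand-in types (not MindSpore's actual classes):

```cpp
// Toy contrast between the old transactional style and the new direct-manager
// style. Manager/Transaction here are illustrative stand-ins, not real APIs.
#include <functional>
#include <iostream>
#include <vector>

struct Manager {
  void SetParameters(int fg_id, int n_params) {
    std::cout << "graph " << fg_id << " now has " << n_params << " params\n";
  }
};

// Old style: queue operations, apply them all at Commit().
class Transaction {
 public:
  explicit Transaction(Manager *m) : m_(m) {}
  void SetParameters(int fg_id, int n_params) {
    ops_.push_back([this, fg_id, n_params] { m_->SetParameters(fg_id, n_params); });
  }
  void Commit() {
    for (auto &op : ops_) op();
    ops_.clear();
  }

 private:
  Manager *m_;
  std::vector<std::function<void()>> ops_;
};

int main() {
  Manager manager;
  Transaction tr(&manager);
  tr.SetParameters(1, 2);       // queued; no effect yet
  tr.Commit();                  // applied here, all at once
  manager.SetParameters(1, 3);  // new style: applied immediately
}
```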
@@ -334,7 +334,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
 
   // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
-  OptPassGroupMap map_a({{"a_1", a_1},
+  OptPassGroupMap map_a({{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
+                         {"a_1", a_1},
                          {"updatestate_depend_eliminate", updatestate_depend_eliminate},
                          {"updatestate_assign_eliminate", updatestate_assign_eliminate},
                          {"updatestate_loads_eliminate", updatestate_loads_eliminate},
@@ -360,7 +361,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
 
 OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
   auto opt_a = GetOptPassesA(irpass);
-  constexpr auto a1_a2_len = 6;
+  constexpr auto a1_a2_len = 7;
   OptPassGroupMap a1_a2(opt_a.begin(), opt_a.begin() + a1_a2_len);
   return a1_a2;
 }
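switch_simplify becomes its own group at the head of map_a, which shifts every following group by one. GetA1A2 slices a fixed-length prefix of that list, so a1_a2_len has to move from 6 to 7 in the same commit; the comment above map_a warns about exactly this coupling. A minimal sketch of the prefix-length pitfall (group names beyond those visible in the hunk are hypothetical):

```cpp
// Why a1_a2_len must track insertions at the front of map_a: GetA1A2 takes a
// length-based prefix, so a new leading group evicts the prefix's last group
// unless the length is bumped along with it.
#include <cassert>
#include <string>
#include <vector>

using PassGroups = std::vector<std::string>;

// Mirrors the shape of GetA1A2: take the first a1_a2_len groups of map_a.
PassGroups GetA1A2(const PassGroups &map_a, size_t a1_a2_len) {
  return PassGroups(map_a.begin(), map_a.begin() + a1_a2_len);
}

int main() {
  // The first five names come from the hunk; "pass_6"/"pass_7" are
  // hypothetical stand-ins for the groups not shown.
  const PassGroups map_a = {"switch_simplify", "a_1", "updatestate_depend_eliminate",
                            "updatestate_assign_eliminate", "updatestate_loads_eliminate",
                            "pass_6", "pass_7"};
  // The groups that formed the old prefix now sit one index later, so the
  // slice must take 7 groups to keep covering the same trailing group.
  assert(GetA1A2(map_a, 7).size() == 7);
  assert(GetA1A2(map_a, 6).back() != GetA1A2(map_a, 7).back());
  return 0;
}
```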
@@ -24,6 +24,7 @@
 #include "frontend/operator/composite/do_signature.h"
 #include "abstract/abstract_function.h"
 #include "abstract/utils.h"
+#include "utils/utils.h"
 #include "ir/graph_utils.h"
 #include "utils/log_adapter.h"
 #include "debug/trace.h"
@@ -396,8 +397,8 @@ AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf)
 }
 
 namespace {
-const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
-const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");
+const StringImmPtr kDeadNode = std::make_shared<StringImm>(kDeadNodeName);
+const StringImmPtr kPolyNode = std::make_shared<StringImm>(kPolyNodeName);
 
 inline bool CanSpecializeNode(const AnfNodePtr &node) {
   if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
@@ -213,6 +213,8 @@ constexpr auto kFusedCastAdamWeightDecayName = "FusedCastAdamWeightDecay";
 constexpr auto kFusedAdamName = "FusedAdam";
 constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
 constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
+constexpr auto kDeadNodeName = "Dead Node";
+constexpr auto kPolyNodeName = "Poly Node";
 constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
 constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
 constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";
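These two hunks are the supporting refactor for the guard above: the "Dead Node" / "Poly Node" literals that the specializer used to construct inline are hoisted into kDeadNodeName / kPolyNodeName, so the inline pass can compare a branch's marker against the very same strings instead of duplicating them.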
@@ -133,8 +133,8 @@ def control_flow_if_after_if_in_for(input_net, x, expect1, expect2):
     assert graph_backward_res == expect2
 
 
-@pytest.mark.skip(reason="ME EvalCNode error")
-@pytest.mark.level1
+# @pytest.mark.skip(reason="ME EvalCNode error")
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
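Finally, the test hunk re-enables a case that used to be skipped with "ME EvalCNode error" and tightens it from level1 to level0 across the GPU and Ascend platform marks, which suggests this control-flow pattern is exactly what the inline-pass fix addresses.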