!12371 [auto-monad] Optimize merge_addn and matmul_biasadd_fusion
From: @hwhewei Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
32c8733d3b
|
@ -39,9 +39,7 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is a side-effect operator in the fusion, do not merge
|
// If there is a side-effect operator in the fusion, do not merge
|
||||||
MonadState state_matmul = GetMonadState(matmul);
|
if (!IsStateEquivalent(node, matmul)) {
|
||||||
MonadState state_node = GetMonadState(node, matmul);
|
|
||||||
if (!IsStateEquivalent(state_matmul, state_node)) {
|
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,16 +71,15 @@ class MergeAddN : public AnfVisitor {
|
||||||
is_inner_ = true;
|
is_inner_ = true;
|
||||||
|
|
||||||
// {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}
|
// {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}
|
||||||
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]);
|
const auto &first_input = inputs.at(1);
|
||||||
|
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(first_input);
|
||||||
if (is_match_) {
|
if (is_match_) {
|
||||||
if (!is_unique(inputs[1])) {
|
if (!is_unique(first_input)) {
|
||||||
is_match_ = false;
|
is_match_ = false;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
MonadState state_input = GetMonadState(inputs[1]);
|
if (!IsStateEquivalent(cnode, first_input)) {
|
||||||
MonadState state_cnode = GetMonadState(cnode, inputs[1]);
|
|
||||||
if (!IsStateEquivalent(state_cnode, state_input)) {
|
|
||||||
is_match_ = false;
|
is_match_ = false;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -92,16 +91,15 @@ class MergeAddN : public AnfVisitor {
|
||||||
}
|
}
|
||||||
|
|
||||||
// {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}
|
// {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}
|
||||||
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back());
|
const auto &last_input = inputs.back();
|
||||||
|
AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(last_input);
|
||||||
if (is_match_) {
|
if (is_match_) {
|
||||||
if (!is_unique(inputs.back())) {
|
if (!is_unique(last_input)) {
|
||||||
is_match_ = false;
|
is_match_ = false;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
MonadState state_input = GetMonadState(inputs.back());
|
if (!IsStateEquivalent(cnode, last_input)) {
|
||||||
MonadState state_cnode = GetMonadState(cnode, inputs.back());
|
|
||||||
if (!IsStateEquivalent(state_cnode, state_input)) {
|
|
||||||
is_match_ = false;
|
is_match_ = false;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -300,6 +300,46 @@ bool IsStateEquivalent(const MonadState &state1, const MonadState &state2) {
|
||||||
(state1.io == nullptr || state2.io == nullptr || state1.io == state2.io);
|
(state1.io == nullptr || state2.io == nullptr || state1.io == state2.io);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsStateStrictEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) {
|
||||||
|
MonadState state_matmul = GetMonadState(inner);
|
||||||
|
MonadState state_node = GetMonadState(outer, inner);
|
||||||
|
return IsStateEquivalent(state_matmul, state_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::set<CNodePtr> GetLoadInputs(const AnfNodePtr &node) {
|
||||||
|
std::set<CNodePtr> loads;
|
||||||
|
auto cnode = dyn_cast<CNode>(node);
|
||||||
|
if (cnode == nullptr) {
|
||||||
|
return loads;
|
||||||
|
}
|
||||||
|
auto &inputs = cnode->inputs();
|
||||||
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
|
auto &input = inputs.at(i);
|
||||||
|
if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
|
||||||
|
loads.insert(input->cast<CNodePtr>());
|
||||||
|
} else if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
|
||||||
|
loads.merge(GetLoadInputs(input));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return loads;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) {
|
||||||
|
constexpr size_t kMonadInput = 2;
|
||||||
|
auto outer_loads = GetLoadInputs(outer);
|
||||||
|
if (outer_loads.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto inner_loads = GetLoadInputs(inner);
|
||||||
|
if (inner_loads.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
outer_loads.merge(inner_loads);
|
||||||
|
auto &monad = (*outer_loads.begin())->inputs().at(kMonadInput);
|
||||||
|
return std::all_of(++outer_loads.begin(), outer_loads.end(),
|
||||||
|
[&monad](const CNodePtr &load) { return load->inputs().at(kMonadInput) == monad; });
|
||||||
|
}
|
||||||
|
|
||||||
size_t NewSeenGeneration() {
|
size_t NewSeenGeneration() {
|
||||||
static size_t seen_generation = 0;
|
static size_t seen_generation = 0;
|
||||||
return ++seen_generation;
|
return ++seen_generation;
|
||||||
|
@ -353,6 +393,26 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) {
|
||||||
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||||
return default_target;
|
return default_target;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_target, const AnfNodePtr &attr_input,
|
||||||
|
const std::string &primitive_target, const std::string &default_target) {
|
||||||
|
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
|
||||||
|
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
|
||||||
|
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) ||
|
||||||
|
IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) ||
|
||||||
|
IsPrimitive(attr_input, prim::kPrimPartial)) {
|
||||||
|
primitive->EraseAttr(primitive_target);
|
||||||
|
return default_target;
|
||||||
|
}
|
||||||
|
if (!att_target->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
||||||
|
}
|
||||||
|
auto target = GetValue<std::string>(att_target);
|
||||||
|
if (kTargetSet.find(target) == kTargetSet.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
|
||||||
|
}
|
||||||
|
return target;
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::string GetCNodeTarget(const AnfNodePtr &node) {
|
std::string GetCNodeTarget(const AnfNodePtr &node) {
|
||||||
|
@ -387,22 +447,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
|
||||||
auto primitive = value->cast<PrimitivePtr>();
|
auto primitive = value->cast<PrimitivePtr>();
|
||||||
auto att_target = primitive->GetAttr(primitive_target);
|
auto att_target = primitive->GetAttr(primitive_target);
|
||||||
if (att_target != nullptr) {
|
if (att_target != nullptr) {
|
||||||
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
|
return GetAttrTarget(primitive, att_target, attr_input, primitive_target, default_target);
|
||||||
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
|
|
||||||
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) ||
|
|
||||||
IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) ||
|
|
||||||
IsPrimitive(attr_input, prim::kPrimPartial)) {
|
|
||||||
primitive->EraseAttr(primitive_target);
|
|
||||||
return default_target;
|
|
||||||
}
|
|
||||||
if (!att_target->isa<StringImm>()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
|
|
||||||
}
|
|
||||||
auto target = GetValue<std::string>(att_target);
|
|
||||||
if (kTargetSet.find(target) == kTargetSet.end()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
|
|
||||||
}
|
|
||||||
return target;
|
|
||||||
}
|
}
|
||||||
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||||
auto &inputs = cnode->inputs();
|
auto &inputs = cnode->inputs();
|
||||||
|
|
|
@ -530,6 +530,12 @@ MonadState GetMonadState(const AnfNodePtr &node, const AnfNodePtr &skip_input =
|
||||||
// Check if two states are equivalent.
|
// Check if two states are equivalent.
|
||||||
bool IsStateEquivalent(const MonadState &state1, const MonadState &state2);
|
bool IsStateEquivalent(const MonadState &state1, const MonadState &state2);
|
||||||
|
|
||||||
|
// Check if monad state is strictly equivalent for the two connected nodes.
|
||||||
|
bool IsStateStrictEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner);
|
||||||
|
|
||||||
|
// Check if monad state is equivalent for the two connected nodes; not strict, but faster.
|
||||||
|
bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner);
|
||||||
|
|
||||||
// used to check whether a ValueNode has some kind of value
|
// used to check whether a ValueNode has some kind of value
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static bool IsValueNode(const AnfNodePtr &node) {
|
static bool IsValueNode(const AnfNodePtr &node) {
|
||||||
|
|
Loading…
Reference in New Issue