!8105 Fix MatMul's third input not in all_nodes after matmul_biasadd fusion

Merge pull request !8105 from huanghui/fix-matmul-biasadd-fusion
Merged by: mindspore-ci-bot, 2020-11-03 11:27:55 +08:00 (committed by Gitee)
commit 9ae5f96988
2 changed files with 36 additions and 23 deletions
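
Context, as inferred from the title and the diffs below: the old pass fused BiasAdd(MatMul(x0, x1), x2) by mutating the matched MatMul in place, appending the bias as a third input via add_input and tagging the node with a has_bias attribute, then returning that same MatMul. The bias edge was added behind the pattern engine's bookkeeping, so the fused MatMul's third input never appeared in all_nodes. The fix instead matches the MatMul through a dedicated pattern variable, builds a brand-new MatMul CNode from the matched inputs (x0_, x1_, x2_), copies the original node's scope, abstract, and attributes onto it, and returns the new node, so every input edge is registered at creation time. A toy model of the failure mode follows the diffs.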

mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc

@@ -14,38 +14,39 @@
  * limitations under the License.
  */
 #include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h"
 #include <memory>
+#include <vector>
 #include "backend/optimizer/common/helper.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
-namespace {
-constexpr size_t kMatMulInputIndex = 1;
-constexpr size_t kBiasInputIndex = 2;
-}  // namespace
-
 const BaseRef MatmulBiasaddFusion::DefinePattern() const {
-  VarPtr X0 = std::make_shared<Var>();
-  VarPtr X1 = std::make_shared<Var>();
-  VarPtr X2 = std::make_shared<Var>();
-  const auto prim_bias_add = std::make_shared<Primitive>(kBiasAddOpName);
-  return VectorRef({prim_bias_add, VectorRef({prim::kPrimMatMul, X0, X1}), X2});
+  VectorRef matmul({matmul_var_, x0_, x1_});
+  VectorRef pattern({prim::kPrimBiasAdd, matmul, x2_});
+  return pattern;
 }
 
-const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
+const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                              const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  CheckCNodeInputSize(cnode, kBiasAddInputNum);
-  AnfNodePtr matmul = cnode->input(kMatMulInputIndex);
-  MS_EXCEPTION_IF_NULL(matmul);
-  auto matmul_cnode = matmul->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(matmul_cnode);
-  matmul_cnode->add_input(cnode->input(kBiasInputIndex));
-  AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul);
-  return matmul;
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> inputs;
+  inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
+  inputs.emplace_back(GetAnfNodeByVar(equiv, x0_));
+  inputs.emplace_back(GetAnfNodeByVar(equiv, x1_));
+  inputs.emplace_back(GetAnfNodeByVar(equiv, x2_));
+  auto new_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_scope(node->scope());
+  new_node->set_abstract(node->abstract());
+  auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
+  if (matmul == nullptr || !matmul->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "Get CNode MatMul failed!";
+  }
+  AnfAlgo::CopyNodeAttrs(matmul, new_node);
+  return new_node;
 }
 }  // namespace opt
 }  // namespace mindspore

mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h

@@ -16,17 +16,29 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
 
+#include <memory>
 #include "backend/optimizer/common/optimizer.h"
 
 namespace mindspore {
 namespace opt {
 class MatmulBiasaddFusion : public PatternProcessPass {
  public:
-  explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {}
+  explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {
+    x0_ = std::make_shared<Var>();
+    x1_ = std::make_shared<Var>();
+    x2_ = std::make_shared<Var>();
+    matmul_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMatMul->name()));
+  }
   ~MatmulBiasaddFusion() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr x0_;
+  VarPtr x1_;
+  VarPtr x2_;
+  VarPtr matmul_var_;
 };
 }  // namespace opt
 }  // namespace mindspore
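
To make the failure mode concrete, here is a minimal, self-contained toy model of the bookkeeping problem. Node, Manager, NewCNode, and AllNodes are illustrative stand-ins invented for this sketch, not MindSpore's real AnfNode/FuncGraphManager API: the manager only learns about edges created through it, so splicing the bias into an existing node's input list leaves that edge untracked, while building a fresh node registers all three inputs.

#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Toy IR node: a name plus raw input pointers, standing in for CNode.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Toy graph manager: it only knows about edges created through NewCNode,
// the way a real manager tracks edges made through its own API.
struct Manager {
  std::map<NodePtr, std::vector<NodePtr>> known_edges;

  NodePtr NewCNode(const std::string &name, const std::vector<NodePtr> &inputs) {
    auto node = std::make_shared<Node>(Node{name, inputs});
    known_edges[node] = inputs;  // every input edge is registered on creation
    return node;
  }

  // all_nodes = everything reachable from `root` along *registered* edges.
  std::set<NodePtr> AllNodes(const NodePtr &root) {
    std::set<NodePtr> seen;
    std::vector<NodePtr> stack{root};
    while (!stack.empty()) {
      NodePtr n = stack.back();
      stack.pop_back();
      if (!seen.insert(n).second) continue;
      for (const auto &in : known_edges[n]) stack.push_back(in);
    }
    return seen;
  }
};

int main() {
  Manager m;
  auto x0 = m.NewCNode("x0", {});
  auto x1 = m.NewCNode("x1", {});
  auto bias = m.NewCNode("bias", {});
  auto matmul = m.NewCNode("MatMul", {x0, x1});
  auto bias_add = m.NewCNode("BiasAdd", {matmul, bias});
  (void)bias_add;

  // Old approach: append bias to the matched MatMul's own input list and
  // return it as the fusion result. The manager never sees the new edge,
  // so bias is unreachable from the fused root in its bookkeeping.
  matmul->inputs.push_back(bias);
  std::cout << "old fusion, bias in all_nodes: "
            << (m.AllNodes(matmul).count(bias) ? "yes" : "no") << "\n";  // prints "no"

  // New approach (what the patch does): build a fresh MatMul through the
  // manager from the matched x0, x1, and bias, so all three edges are known.
  auto fused = m.NewCNode("MatMul", {x0, x1, bias});
  std::cout << "new fusion, bias in all_nodes: "
            << (m.AllNodes(fused).count(bias) ? "yes" : "no") << "\n";  // prints "yes"
  return 0;
}

The patched Process follows the second shape: it assembles the new node's inputs, creates it with graph->NewCNode(inputs), and copies scope, abstract, and attributes over, rather than mutating the matched MatMul.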