forked from mindspore-Ecosystem/mindspore
!8105 Fix MatMul's third input not in all_nodes after matmul_biasadd fusion
Merge pull request !8105 from huanghui/fix-matmul-biasadd-fusion
This commit is contained in:
commit
9ae5f96988
|
@ -14,38 +14,39 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kMatMulInputIndex = 1;
|
||||
constexpr size_t kBiasInputIndex = 2;
|
||||
} // namespace
|
||||
|
||||
const BaseRef MatmulBiasaddFusion::DefinePattern() const {
|
||||
VarPtr X0 = std::make_shared<Var>();
|
||||
VarPtr X1 = std::make_shared<Var>();
|
||||
VarPtr X2 = std::make_shared<Var>();
|
||||
const auto prim_bias_add = std::make_shared<Primitive>(kBiasAddOpName);
|
||||
return VectorRef({prim_bias_add, VectorRef({prim::kPrimMatMul, X0, X1}), X2});
|
||||
VectorRef matmul({matmul_var_, x0_, x1_});
|
||||
VectorRef pattern({prim::kPrimBiasAdd, matmul, x2_});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
CheckCNodeInputSize(cnode, kBiasAddInputNum);
|
||||
AnfNodePtr matmul = cnode->input(kMatMulInputIndex);
|
||||
MS_EXCEPTION_IF_NULL(matmul);
|
||||
auto matmul_cnode = matmul->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(matmul_cnode);
|
||||
matmul_cnode->add_input(cnode->input(kBiasInputIndex));
|
||||
AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul);
|
||||
return matmul;
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimMatMul->name())));
|
||||
inputs.emplace_back(GetAnfNodeByVar(equiv, x0_));
|
||||
inputs.emplace_back(GetAnfNodeByVar(equiv, x1_));
|
||||
inputs.emplace_back(GetAnfNodeByVar(equiv, x2_));
|
||||
auto new_node = graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
new_node->set_scope(node->scope());
|
||||
new_node->set_abstract(node->abstract());
|
||||
|
||||
auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
|
||||
if (matmul == nullptr || !matmul->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Get CNode MatMul failed!";
|
||||
}
|
||||
AnfAlgo::CopyNodeAttrs(matmul, new_node);
|
||||
return new_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,17 +16,29 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MatmulBiasaddFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {}
|
||||
explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {
|
||||
x0_ = std::make_shared<Var>();
|
||||
x1_ = std::make_shared<Var>();
|
||||
x2_ = std::make_shared<Var>();
|
||||
matmul_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMatMul->name()));
|
||||
}
|
||||
|
||||
~MatmulBiasaddFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
VarPtr x0_;
|
||||
VarPtr x1_;
|
||||
VarPtr x2_;
|
||||
VarPtr matmul_var_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue