forked from mindspore-Ecosystem/mindspore
!13703 [lite]matmul and add fusion
From: @xu_anyue Reviewed-by: Signed-off-by:
This commit is contained in:
commit
69af6643d3
|
@ -45,7 +45,7 @@ int GatherFp16CPUKernel::Init() {
|
|||
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
|
||||
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
|
||||
}
|
||||
|
||||
(reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ = *(reinterpret_cast<int *>(in_tensors_.at(2)->data_c()));
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "src/runtime/kernel/npu/matmul_npu.h"
|
||||
#include <memory>
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/agent/npu/npu_converter_utils.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kNPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -24,6 +26,11 @@ using mindspore::schema::PrimitiveType_MatMul;
|
|||
namespace mindspore::kernel {
|
||||
int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *opParameter) {
|
||||
if (inputs.size() == 3) {
|
||||
if (inputs[2]->shape().size() != 1) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -33,7 +40,33 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
|
|||
op_->set_input_x1(*npu_inputs[0]);
|
||||
op_->set_input_x2(*npu_inputs[1]);
|
||||
if (npu_inputs.size() == 3) {
|
||||
op_->set_input_bias(*npu_inputs[2]);
|
||||
matmul_parameter_->has_bias_ = true;
|
||||
add_op_ = new (std::nothrow) hiai::op::Add(name_ + "_add");
|
||||
if (add_op_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new add op failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
add_op_->set_input_x1(*op_);
|
||||
auto bias_shape = inputs[2]->shape();
|
||||
auto bias_tensor = std::make_shared<ge::Tensor>();
|
||||
if (bias_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "new bias_tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ge::TensorDesc bias_tensor_desc(lite::ConverterToNPUShape({1, bias_shape[0], 1, 1}), ge::FORMAT_NCHW,
|
||||
lite::ConverterToNPUDataType(inputs[2]->data_type()));
|
||||
if (outputs[0]->shape().size() == 2) {
|
||||
bias_tensor_desc.SetShape(lite::ConverterToNPUShape({1, bias_shape[0]}));
|
||||
}
|
||||
bias_tensor->SetTensorDesc(bias_tensor_desc);
|
||||
bias_tensor->SetData(reinterpret_cast<const uint8_t *>(inputs[2]->data_c()), inputs[2]->Size());
|
||||
bias_ = new (std::nothrow) hiai::op::Const(name_ + "_bias");
|
||||
if (bias_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new bias const failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
bias_->set_attr_value(bias_tensor);
|
||||
add_op_->set_input_x2(*bias_);
|
||||
}
|
||||
|
||||
op_->set_attr_transpose_x1(matmul_parameter_->a_transpose_);
|
||||
|
@ -41,13 +74,26 @@ int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, con
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { return this->op_; }
|
||||
ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() {
|
||||
if (matmul_parameter_->has_bias_) {
|
||||
return add_op_;
|
||||
}
|
||||
return op_;
|
||||
}
|
||||
|
||||
MatMulNPUKernel::~MatMulNPUKernel() {
|
||||
if (op_ != nullptr) {
|
||||
delete op_;
|
||||
op_ = nullptr;
|
||||
}
|
||||
if (add_op_ != nullptr) {
|
||||
delete add_op_;
|
||||
add_op_ = nullptr;
|
||||
}
|
||||
if (bias_ != nullptr) {
|
||||
delete bias_;
|
||||
bias_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>)
|
||||
|
|
|
@ -39,6 +39,8 @@ class MatMulNPUKernel : public NPUKernel {
|
|||
|
||||
private:
|
||||
hiai::op::MatMul *op_ = nullptr;
|
||||
hiai::op::Add *add_op_ = nullptr;
|
||||
hiai::op::Const *bias_ = nullptr;
|
||||
MatMulParameter *matmul_parameter_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -11,12 +11,12 @@ STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_C_FLAGS "$
|
|||
STRING(REPLACE " -fvisibility=hidden " " -fvisibility=default " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
|
||||
if(ENABLE_CONVERTER)
|
||||
set(CCSRC_SRC
|
||||
## ccsrc
|
||||
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/visit.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
|
||||
)
|
||||
set(CCSRC_SRC
|
||||
## ccsrc
|
||||
${CCSRC_DIR}/backend/optimizer/common/pattern_engine.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/visit.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
|
||||
)
|
||||
else()
|
||||
set(TEST_LITE_SRC ${LITE_DIR}/src/common/log_adapter.cc)
|
||||
add_compile_definitions(USE_ANDROID_LOG)
|
||||
|
@ -38,10 +38,10 @@ file(GLOB KERNEL_OP_SRC
|
|||
file(GLOB KERNEL_OP_TRAIN_SRC
|
||||
${LITE_DIR}/nnacl/fp32_grad/*.c
|
||||
${LITE_DIR}/src/runtime/kernel/arm/fp32_grad/*.cc
|
||||
)
|
||||
)
|
||||
|
||||
if(SUPPORT_TRAIN)
|
||||
list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
|
||||
list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
|
||||
endif()
|
||||
if(PLATFORM_ARM64)
|
||||
# assembly
|
||||
|
@ -114,9 +114,9 @@ if(SUPPORT_GPU STREQUAL vulkan)
|
|||
endif()
|
||||
|
||||
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
|
||||
if(ENABLE_CONVERTER)
|
||||
set(BUILD_MINDDATA "off")
|
||||
endif()
|
||||
if(ENABLE_CONVERTER)
|
||||
set(BUILD_MINDDATA "off")
|
||||
endif()
|
||||
endif()
|
||||
### runtime framework
|
||||
add_definitions(-DENABLE_V0)
|
||||
|
@ -189,19 +189,19 @@ if(ENABLE_MINDRT)
|
|||
include_directories(${CORE_DIR}/mindrt/)
|
||||
include_directories(${CORE_DIR}/mindrt/src/)
|
||||
set(TEST_LITE_SRC ${TEST_LITE_SRC}
|
||||
${LITE_DIR}/src/lite_mindrt.cc
|
||||
${LITE_DIR}/src/mindrt_executor.cc
|
||||
${CORE_DIR}/mindrt/src/litebus.cc
|
||||
${CORE_DIR}/mindrt/src/actor/actor.cc
|
||||
${CORE_DIR}/mindrt/src/actor/actormgr.cc
|
||||
${CORE_DIR}/mindrt/src/actor/actorpolicy.cc
|
||||
${CORE_DIR}/mindrt/src/actor/actorthread.cc
|
||||
${CORE_DIR}/mindrt/src/actor/aid.cc
|
||||
${CORE_DIR}/mindrt/src/async/async.cc
|
||||
${CORE_DIR}/mindrt/src/async/future.cc
|
||||
${CORE_DIR}/mindrt/src/async/uuid_base.cc
|
||||
${CORE_DIR}/mindrt/src/async/uuid_generator.cc
|
||||
)
|
||||
${LITE_DIR}/src/lite_mindrt.cc
|
||||
${LITE_DIR}/src/mindrt_executor.cc
|
||||
${CORE_DIR}/mindrt/src/litebus.cc
|
||||
${CORE_DIR}/mindrt/src/actor/actor.cc
|
||||
${CORE_DIR}/mindrt/src/actor/actormgr.cc
|
||||
${CORE_DIR}/mindrt/src/actor/actorpolicy.cc
|
||||
${CORE_DIR}/mindrt/src/actor/actorthread.cc
|
||||
${CORE_DIR}/mindrt/src/actor/aid.cc
|
||||
${CORE_DIR}/mindrt/src/async/async.cc
|
||||
${CORE_DIR}/mindrt/src/async/future.cc
|
||||
${CORE_DIR}/mindrt/src/async/uuid_base.cc
|
||||
${CORE_DIR}/mindrt/src/async/uuid_generator.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
|
@ -242,6 +242,7 @@ if(ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
|
||||
|
@ -286,16 +287,16 @@ else()
|
|||
endif()
|
||||
### test src
|
||||
file(GLOB_RECURSE TEST_CASE_KERNEL_SRC
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
|
||||
${TEST_DIR}/ut/nnacl/infer/*.cc
|
||||
)
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
|
||||
${TEST_DIR}/ut/nnacl/infer/*.cc
|
||||
)
|
||||
|
||||
file(GLOB_RECURSE TEST_CASE_KERNEL_TRAIN_SRC
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
|
||||
)
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp32_grad/*.cc
|
||||
)
|
||||
|
||||
set(TEST_SRC
|
||||
${TEST_LITE_SRC}
|
||||
|
@ -306,7 +307,7 @@ set(TEST_SRC
|
|||
${TEST_DIR}/ut/src/infer_test.cc
|
||||
${TEST_DIR}/ut/src/utils_test.cc
|
||||
${TEST_DIR}/ut/src/scheduler_test.cc
|
||||
)
|
||||
)
|
||||
|
||||
if(ENABLE_CONVERTER)
|
||||
set(TEST_SRC
|
||||
|
@ -358,7 +359,7 @@ endif()
|
|||
|
||||
if(ENABLE_FP16 AND SUPPORT_TRAIN)
|
||||
file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
|
||||
${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
|
||||
list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})
|
||||
endif()
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/fusion/tf_lstm_cell_fusion.cc
|
||||
../optimizer/fusion/tf_bidirection_gru_fusion.cc
|
||||
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
|
||||
../optimizer/fusion/matmul_add_fusion.cc
|
||||
../optimizer/graph/weight_format_transform_pass.cc
|
||||
../optimizer/graph/weight_format_hardcode_pass.cc
|
||||
../optimizer/graph/clip_convert_activation_pass.cc
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.h"
|
||||
#include "tools/optimizer/fusion/matmul_add_fusion.h"
|
||||
#include "tools/optimizer/graph/primitive_adjust_pass.h"
|
||||
#include "tools/optimizer/graph/mindir_adjust_pass.h"
|
||||
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
|
||||
|
@ -107,6 +108,9 @@ int AnfTransform::AddFusionPass(const std::shared_ptr<opt::GraphOptimizer> &opti
|
|||
fusion_pm->AddPass(remove_unused_transpose_pass);
|
||||
}
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
|
||||
if (!config->trainModel) {
|
||||
fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
|
||||
}
|
||||
optimizer->AddPassManager(fusion_pm);
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_gemm_parser.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/make_tuple.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ops::PrimitiveC *OnnxGemmParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) {
|
||||
auto prim = std::make_unique<ops::MakeTuple>();
|
||||
|
||||
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("MatMul");
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto *matmul_primitive = node_parser->Parse(onnx_graph, onnx_node);
|
||||
prim->AddAttr("MatMul", std::shared_ptr<ops::PrimitiveC>(matmul_primitive));
|
||||
|
||||
node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser("BiasAdd");
|
||||
if (node_parser == nullptr) {
|
||||
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto *bias_add_primitive = node_parser->Parse(onnx_graph, onnx_node);
|
||||
prim->AddAttr("BiasAdd", std::shared_ptr<ops::PrimitiveC>(bias_add_primitive));
|
||||
|
||||
return prim.release();
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxGemmParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -1,34 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxGemmParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxGemmParser() : OnnxNodeParser("Gemm") {}
|
||||
~OnnxGemmParser() override = default;
|
||||
|
||||
ops::PrimitiveC *Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_GEMM_PARSER_H
|
|
@ -46,5 +46,6 @@ ops::PrimitiveC *OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, con
|
|||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxMatmulParser("MatMul", new OnnxMatmulParser());
|
||||
OnnxNodeRegistrar g_onnxGemmParser("Gemm", new OnnxMatmulParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,7 +44,6 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
|||
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
||||
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
||||
|
||||
std::set<std::string> SPECIAL_NODE = {"Gemm"};
|
||||
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
NoSupportOp::GetInstance()->SetFmkType("ONNX");
|
||||
|
@ -215,11 +214,6 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F
|
|||
MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
|
||||
continue;
|
||||
}
|
||||
if (IsSpecialOnnxNode(onnx_node)) {
|
||||
auto status_node = ConvertSpecialOnnxNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
|
||||
status = status == RET_OK ? status_node : status;
|
||||
continue;
|
||||
}
|
||||
// build CNode
|
||||
status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
|
||||
if (status != RET_OK) {
|
||||
|
@ -1023,117 +1017,6 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
|
|||
return status;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
ops::PrimitiveC *primitive_c) {
|
||||
if (primitive_c == nullptr || anf_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "imitive_c is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
STATUS status = RET_OK;
|
||||
if (onnx_node.op_type() == "Gemm") {
|
||||
status = ConvertOnnxGemmNode(onnx_node, anf_graph, anf_nodes_map, primitive_c);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "the node is not special node.";
|
||||
status = RET_ERROR;
|
||||
}
|
||||
delete primitive_c;
|
||||
return status;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
ops::PrimitiveC *primitive_c) {
|
||||
if (primitive_c == nullptr || anf_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter has nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (onnx_node.op_type() != "Gemm") {
|
||||
MS_LOG(ERROR) << "this op is not gemm, it is " << onnx_node.op_type();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive_c is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "MatMul");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert gemm node failed.";
|
||||
return status;
|
||||
}
|
||||
status = BuildCNodeForGemm(onnx_node, anf_graph, anf_nodes_map, primitive_c, "BiasAdd");
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "convert gemm node failed.";
|
||||
return status;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
ops::PrimitiveC *primitive_c, const std::string &name) {
|
||||
if (primitive_c == nullptr || anf_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter has nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto value = primitive_c->GetAttr(name);
|
||||
primitive_c->EraseAttr(name);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "op parse failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto prim_ptr = value->cast<std::shared_ptr<ops::PrimitiveC>>();
|
||||
if (prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive parse failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto type_ptr = TypeIdToType(kTypeUnknown);
|
||||
std::vector<int64_t> shape_vector;
|
||||
std::vector<AnfNodePtr> op_inputs;
|
||||
auto quant_params_holder = std::make_shared<QuantParamHolder>();
|
||||
auto quant_params_holder_origin = primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
|
||||
if (name == "MatMul") {
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
if (anf_nodes_map->find(onnx_node.input(i)) == anf_nodes_map->end()) {
|
||||
MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
op_inputs.push_back(anf_nodes_map->at(onnx_node.input(i)));
|
||||
quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(i));
|
||||
}
|
||||
}
|
||||
quant_params_holder->AddOutputQuantParam(std::vector<schema::QuantParamT>(1));
|
||||
auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
|
||||
if (new_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "new cnode error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0));
|
||||
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
||||
anf_nodes_map->emplace("Gemm_MatMul_" + onnx_node.output(0), new_cnode);
|
||||
} else {
|
||||
if (anf_nodes_map->find("Gemm_MatMul_" + onnx_node.output(0)) == anf_nodes_map->end() ||
|
||||
anf_nodes_map->find(onnx_node.input(2)) == anf_nodes_map->end()) {
|
||||
MS_LOG(ERROR) << "op " << onnx_node.op_type() << " inputs get failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
op_inputs.push_back(anf_nodes_map->at("Gemm_MatMul_" + onnx_node.output(0)));
|
||||
op_inputs.push_back(anf_nodes_map->at(onnx_node.input(2)));
|
||||
quant_params_holder->AddInputQuantParam(std::vector<schema::QuantParamT>(1));
|
||||
quant_params_holder->AddInputQuantParam(quant_params_holder_origin->input_quant_params().at(2));
|
||||
quant_params_holder->AddOutputQuantParam(quant_params_holder_origin->output_quant_params().front());
|
||||
auto new_cnode = anf_graph->NewCNode(prim_ptr, op_inputs);
|
||||
if (new_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "new cnode error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0));
|
||||
new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
||||
anf_nodes_map->emplace(onnx_node.output(0), new_cnode);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "value is nullptr.";
|
||||
|
@ -1281,10 +1164,6 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_t
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool OnnxModelParser::IsSpecialOnnxNode(const onnx::NodeProto &onnx_node) {
|
||||
return SPECIAL_NODE.find(onnx_node.op_type()) != SPECIAL_NODE.end();
|
||||
}
|
||||
|
||||
TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
|
||||
auto iter = TYPE_MAP.find(onnx_type);
|
||||
if (iter == TYPE_MAP.end()) {
|
||||
|
|
|
@ -69,21 +69,11 @@ class OnnxModelParser : public ModelParser {
|
|||
ops::PrimitiveC *primitive_c, std::string loop_name);
|
||||
static STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
|
||||
static STATUS ConvertSpecialOnnxNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
ops::PrimitiveC *primitive_c);
|
||||
static STATUS ConvertOnnxGemmNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
ops::PrimitiveC *primitive_c);
|
||||
static STATUS BuildCNodeForGemm(const onnx::NodeProto &onnx_node, const FuncGraphPtr &func_graph_ptr,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
ops::PrimitiveC *primitive_c, const std::string &name);
|
||||
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
|
||||
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
|
||||
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
|
||||
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
|
||||
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
|
||||
static bool IsSpecialOnnxNode(const onnx::NodeProto &onnx_node);
|
||||
STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
const std::string &root_node_name);
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/fusion/matmul_add_fusion.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t AddInputSize = 3;
|
||||
constexpr size_t MatMulInputSize = 3;
|
||||
bool CheckAndGetMatMulIndex(const CNodePtr &cnode, size_t *index) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
MS_ASSERT(index != nullptr);
|
||||
if (cnode->size() != AddInputSize) {
|
||||
return false;
|
||||
}
|
||||
size_t matmul_index = 0;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
if (CheckPrimitiveType(cnode->input(i), prim::kPrimMatMul)) {
|
||||
auto matmul_cnode = cnode->input(i)->cast<CNodePtr>();
|
||||
if (matmul_cnode->size() > MatMulInputSize) {
|
||||
continue;
|
||||
}
|
||||
matmul_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (matmul_index == 0) {
|
||||
return false;
|
||||
}
|
||||
*index = matmul_index;
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nulltr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (!CheckPrimitiveType(node, prim::kPrimAddFusion) && !CheckPrimitiveType(node, prim::kPrimBiasAdd)) {
|
||||
continue;
|
||||
}
|
||||
size_t index = 0;
|
||||
if (!CheckAndGetMatMulIndex(cnode, &index)) {
|
||||
continue;
|
||||
}
|
||||
auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
auto bias_node = cnode->input(AddInputSize - index);
|
||||
if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
|
||||
continue;
|
||||
}
|
||||
matmul_cnode->add_input(bias_node);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
matmul_cnode->set_fullname_with_scope(node->fullname_with_scope());
|
||||
manager->Replace(node, matmul_cnode);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MatMulAddFusion : public Pass {
|
||||
public:
|
||||
MatMulAddFusion() : Pass("matmul_add_fusion") {}
|
||||
~MatMulAddFusion() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ADD_FUSION_H_
|
Loading…
Reference in New Issue