benchmark_train output: add MatMulActivationFusion pattern
parent 0b25747876
commit 1419c57ec0
@@ -359,6 +359,7 @@ set(TRAIN_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
         ${TOOLS_DIR}/converter/optimizer.cc
         ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc
@@ -20,6 +20,7 @@
 #include <memory>
 #include "tools/converter/optimizer.h"
 #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
+#include "mindspore/lite/src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
@@ -42,6 +43,7 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) {
   auto old_nodes = GetGraphNodes(*graph);
   Optimizer fusion_optimizer;
   fusion_optimizer.AddPass(new (std::nothrow) MatMulBiasAddFusionPass());
+  fusion_optimizer.AddPass(new (std::nothrow) MatMulActivationFusionPass());
   fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
   fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
   auto status = fusion_optimizer.Run(graph);
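The registration order in the hunk above matters: the new pass (defined in the file added below) does not delete the matched Activation node itself, it only clears the node's input/output indices, and the IsolatedNodeRemovePass registered after it sweeps such detached nodes away. A minimal standalone sketch of that two-step contract, using hypothetical stand-in types rather than the converter's real Optimizer/FusionPass API:

// Illustrative stand-ins only -- not the converter's real Optimizer/FusionPass API.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<size_t> inputIndex;
  std::vector<size_t> outputIndex;
};

// Step 1 (fusion pass): fold node b into node a, then detach b. The detached node
// stays in the list so indices held by other matched patterns remain valid.
void FuseInto(Node *a, Node *b) {
  a->outputIndex = b->outputIndex;
  b->inputIndex.clear();
  b->outputIndex.clear();
}

// Step 2 (isolated-node removal): erase every detached node in one sweep.
void RemoveIsolated(std::vector<Node> *nodes) {
  nodes->erase(std::remove_if(nodes->begin(), nodes->end(),
                              [](const Node &n) { return n.inputIndex.empty() && n.outputIndex.empty(); }),
               nodes->end());
}

int main() {
  std::vector<Node> nodes = {{"MatMul", {0, 1, 2}, {3}}, {"Relu", {3}, {4}}};
  FuseInto(&nodes[0], &nodes[1]);  // only safe because the removal sweep runs afterwards
  RemoveIsolated(&nodes);
  std::cout << nodes.size() << " node left: " << nodes[0].name << std::endl;  // 1 node left: MatMul
  return 0;
}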
@@ -0,0 +1,99 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
#include "schema/inner/model_generated.h"
#include "tools/common/meta_graph_utils.h"
namespace {
constexpr int kNumMatchPathLen = 2;
constexpr int kMatmulInputIndexSize = 3;
constexpr std::string_view MatMulName = "MATMUL";
constexpr std::string_view ActName = "ACTIVATION";
}  // namespace
namespace mindspore {
namespace lite {
STATUS MatMulActivationFusionPass::DefinePattern() {
  auto matmul_op = std::make_shared<PatternOp>();
  MS_CHECK_TRUE_RET(matmul_op != nullptr, RET_NULL_PTR);
  matmul_op->id = MatMulName;
  matmul_op->types = {schema::PrimitiveType_MatMulFusion};
  auto act_op = std::make_shared<PatternOp>();
  MS_CHECK_TRUE_RET(act_op != nullptr, RET_NULL_PTR);
  act_op->id = ActName;
  act_op->types = {schema::PrimitiveType_Activation};
  act_op->left = matmul_op;
  auto fusion_pattern = std::make_unique<FusionPattern>("MatMulActivationFusion");
  MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed");
  fusion_pattern->AddPatternOp(matmul_op);
  fusion_pattern->AddPatternOp(act_op);
  fusion_pattern->Finish();
  this->patterns.emplace_back(fusion_pattern.release());
  return RET_OK;
}

STATUS MatMulActivationFusionPass::DoFusion(
  MetaGraphT *graph, const std::string &pattern_name,
  const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) {
  MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR);
  if (matched_path.size() != kNumMatchPathLen) {
    MS_LOG(ERROR) << "MatMul-Activation-Fusion should have two NodeIndex in matchedPair";
    return RET_PARAM_INVALID;
  }
  auto matmul_path_iter = matched_path.find(std::string(MatMulName));
  MS_CHECK_TRUE_RET(matmul_path_iter != matched_path.end(), RET_ERROR);
  auto &matmul_path = matmul_path_iter->second;
  MS_CHECK_TRUE_RET(matmul_path != nullptr, RET_NULL_PTR);
  auto act_path_iter = matched_path.find(std::string(ActName));
  MS_CHECK_TRUE_RET(act_path_iter != matched_path.end(), RET_ERROR);
  auto &act_path = act_path_iter->second;
  MS_CHECK_TRUE_RET(act_path != nullptr, RET_NULL_PTR);
  size_t matmul_index = matmul_path->nodeIdx;
  MS_CHECK_TRUE_RET(matmul_index < graph->nodes.size(), RET_ERROR);
  size_t act_index = act_path->nodeIdx;
  MS_CHECK_TRUE_RET(act_index < graph->nodes.size(), RET_ERROR);
  auto &matmul_node = graph->nodes.at(matmul_index);
  MS_CHECK_TRUE_RET(matmul_node != nullptr, RET_NULL_PTR);
  auto &act_node = graph->nodes.at(act_index);
  MS_CHECK_TRUE_RET(act_node != nullptr, RET_NULL_PTR);
  if (matmul_node->inputIndex.size() != kMatmulInputIndexSize ||
      matmul_node->quantType == schema::QuantType_QUANT_ALL ||
      matmul_node->quantType == schema::QuantType_QUANT_DYNAMIC) {
    MS_LOG(DEBUG) << "cannot fusion.";
    return RET_NO_CHANGE;
  }
  MS_CHECK_TRUE_RET(matmul_node->primitive != nullptr, RET_NULL_PTR);
  auto matmul_type = matmul_node->primitive->value.AsMatMulFusion();
  MS_CHECK_TRUE_RET(matmul_type->activation_type == ActivationType::ActivationType_NO_ACTIVATION, RET_NO_CHANGE);
  MS_CHECK_TRUE_RET(act_node->primitive != nullptr, RET_NULL_PTR);
  auto act_type = act_node->primitive->value.AsActivation()->activation_type;
  MS_CHECK_TRUE_RET(act_type == ActivationType::ActivationType_RELU || act_type == ActivationType::ActivationType_RELU6,
                    RET_NO_CHANGE);
  matmul_type->activation_type = act_type;
  matmul_node->outputIndex = {act_node->outputIndex};
  // cannot delete node here, otherwise will destroy order in other pattern's node index
  // make it an isolated node to be removed in IsolatedNodeRemovePass
  act_node->inputIndex.clear();
  act_node->outputIndex.clear();
  return RET_OK;
}

MatMulActivationFusionPass::~MatMulActivationFusionPass() = default;
}  // namespace lite
}  // namespace mindspore
@@ -0,0 +1,43 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_

#include <string>
#include <unordered_map>
#include <memory>
#include <algorithm>
#include <utility>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"

namespace mindspore {
namespace lite {
class MatMulActivationFusionPass : public FusionPass {
 public:
  MatMulActivationFusionPass() = default;

  ~MatMulActivationFusionPass() override;

  STATUS DefinePattern() override;

  STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name,
                  const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override;
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_ACTIVATION_FUSION_PASS_H_
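Conceptually, the new pass folds a trailing ReLU/ReLU6 Activation node into the activation_type field of the preceding MatMulFusion node, so the single fused node computes exactly what the two-node subgraph computed. A tiny standalone sketch of that value-preserving rewrite (plain C++ with made-up helpers, not MindSpore schema code):

// Conceptual sketch only: models a MatMulFusion node that applies its own activation.
#include <cassert>
#include <cstddef>
#include <vector>

enum class ActType { NO_ACTIVATION, RELU, RELU6 };

// Apply the activation selected by the node's activation_type field.
static float ApplyAct(float v, ActType t) {
  if (t == ActType::RELU) return v > 0.0f ? v : 0.0f;
  if (t == ActType::RELU6) return v < 0.0f ? 0.0f : (v > 6.0f ? 6.0f : v);
  return v;
}

// y = W * x + b followed by the node's own activation (W is rows x cols, row-major).
static std::vector<float> MatMulFusionNode(const std::vector<float> &w, const std::vector<float> &x,
                                           const std::vector<float> &b, size_t rows, size_t cols, ActType act) {
  std::vector<float> y(rows, 0.0f);
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) y[r] += w[r * cols + c] * x[c];
    y[r] = ApplyAct(y[r] + b[r], act);
  }
  return y;
}

int main() {
  const std::vector<float> w = {1.0f, -2.0f, 3.0f, -4.0f};  // 2x2 weight
  const std::vector<float> x = {0.5f, 1.5f};
  const std::vector<float> b = {-1.0f, 2.0f};

  // Before fusion: MatMulFusion with NO_ACTIVATION, then a separate Activation(RELU) node.
  std::vector<float> before = MatMulFusionNode(w, x, b, 2, 2, ActType::NO_ACTIVATION);
  for (auto &v : before) v = ApplyAct(v, ActType::RELU);

  // After fusion: a single MatMulFusion node whose activation_type is RELU.
  std::vector<float> after = MatMulFusionNode(w, x, b, 2, 2, ActType::RELU);

  assert(before == after);  // the rewrite is value-preserving
  return 0;
}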