forked from mindspore-Ecosystem/mindspore

Fuse Conv2D and Pad ops

parent 3a50cb8432
commit 01a42bfcdd
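In short, this commit adds a ConvPadFusion pass that removes a constant, zero-valued Pad node feeding a Conv2D (optionally through a Transpose) and folds its spatial paddings into the convolution's pad_list, and it teaches the TF conv parser to map EXPLICIT padding onto PadMode::PAD. Below is a minimal standalone sketch of the folding step for the NHWC case; the helper name and layout comments are illustrative assumptions, not the MindSpore API.

// Illustrative sketch only: a standalone mirror of the folding done in
// ReplaceParamsAndNodes in the diff below. FoldPadIntoConvPadList is a hypothetical name.
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// An NHWC Pad op stores its paddings as a 4x2 tensor, flattened row-major:
// {batch_before, batch_after, top, bottom, left, right, ch_before, ch_after}.
// The conv pad_list keeps only the spatial part, ordered {top, bottom, left, right}.
std::vector<int64_t> FoldPadIntoConvPadList(const std::array<int32_t, 8> &nhwc_paddings,
                                            const std::vector<int64_t> &existing_pad_list) {
  constexpr size_t kNhwcTopPadPos = 2;  // spatial values start at flat index 2
  constexpr size_t kPadDims = 4;
  std::vector<int64_t> pad_list(kPadDims, 0);
  for (size_t i = 0; i < kPadDims; ++i) {
    pad_list[i] = nhwc_paddings[kNhwcTopPadPos + i];
  }
  // If the conv already carried explicit padding (pad_mode == PAD), accumulate it.
  if (existing_pad_list.size() == kPadDims) {
    for (size_t i = 0; i < kPadDims; ++i) {
      pad_list[i] += existing_pad_list[i];
    }
  }
  return pad_list;
}

int main() {
  // Pad of one pixel on every spatial side, no batch or channel padding.
  std::array<int32_t, 8> paddings = {0, 0, 1, 1, 1, 1, 0, 0};
  auto pad_list = FoldPadIntoConvPadList(paddings, {});
  // Prints "1 1 1 1": the values the fused conv would receive via set_pad_list.
  std::cout << pad_list[0] << " " << pad_list[1] << " " << pad_list[2] << " " << pad_list[3] << "\n";
  return 0;
}

The actual pass also spreads a SAME conv's implicit padding (kernel size / 2 per side) into pad_list before switching the pad mode to PAD, as the conv_pad_fusion.cc hunk below shows.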
@@ -229,6 +229,7 @@ if(MSLITE_ENABLE_CONVERTER)
         ${LITE_DIR}/tools/optimizer/fusion/batchmatmul_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/sigmoid_mul_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/conv_conv_fusion.cc
+        ${LITE_DIR}/tools/optimizer/fusion/conv_pad_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/conv_tuplegetitem_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc
         ${LITE_DIR}/tools/optimizer/fusion/tf_lstm_cell_fusion.cc

@@ -67,6 +67,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         ../optimizer/fusion/batchmatmul_fusion.cc
         ../optimizer/fusion/sigmoid_mul_fusion.cc
         ../optimizer/fusion/conv_conv_fusion.cc
+        ../optimizer/fusion/conv_pad_fusion.cc
         ../optimizer/fusion/tflite_lstm_cell_fusion.cc
         ../optimizer/fusion/tf_lstm_cell_fusion.cc
         ../optimizer/fusion/tf_bidirection_gru_fusion.cc

@@ -36,6 +36,7 @@
 #include "tools/optimizer/fusion/batchmatmul_fusion.h"
 #include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
 #include "tools/optimizer/fusion/conv_conv_fusion.h"
+#include "tools/optimizer/fusion/conv_pad_fusion.h"
 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
 #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
 #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
@@ -125,6 +126,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
     fusion_pm->AddPass(remove_unused_cast_pass);
   }
   fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
+  fusion_pm->AddPass(std::make_shared<opt::ConvPadFusion>());
   if (!config->trainModel) {
     fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
   }

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,6 +22,14 @@
 #include "schema/inner/model_generated.h"
 namespace mindspore {
 namespace lite {
+
+namespace {
+constexpr size_t kPadDims = 4;
+constexpr size_t kExplicitPaddingsDims = 8;
+constexpr size_t NHWCTopPadPos = 2;
+constexpr size_t NCHWTopPadPos = 4;
+}  // namespace
+
 STATUS TFConvBaseParser::ParseKernels(const tensorflow::NodeDef &node_def, const mindspore::Format &format,
                                       std::vector<int64_t> *kernel) {
   tensorflow::AttrValue attr_value;
@@ -60,6 +68,33 @@ STATUS TFConvBaseParser::ParseStrides(const tensorflow::NodeDef &node_def, const
   return RET_OK;
 }
 
+STATUS TFConvBaseParser::ParseExplicitPaddings(const tensorflow::NodeDef &node_def, const mindspore::Format &format,
+                                               std::vector<int64_t> *explicit_paddings) {
+  MS_ASSERT(explicit_paddings != nullptr);
+  tensorflow::AttrValue attr_value;
+  if (!TensorFlowUtils::FindAttrValue(node_def, "explicit_paddings", &attr_value)) {
+    MS_LOG(ERROR) << "The explicit paddings value should be specified";
+    return RET_ERROR;
+  } else {
+    auto explicit_paddings_list = attr_value.list();
+    if (explicit_paddings_list.i_size() != kExplicitPaddingsDims) {
+      MS_LOG(ERROR) << "The explicit paddings attr should contain only 8 elements";
+      return RET_ERROR;
+    }
+    explicit_paddings->clear();
+    if (format == mindspore::NHWC) {
+      for (size_t i = 0; i < kPadDims; ++i) {
+        explicit_paddings->push_back(explicit_paddings_list.i(i + NHWCTopPadPos));
+      }
+    } else {
+      for (size_t i = 0; i < kPadDims; ++i) {
+        explicit_paddings->push_back(explicit_paddings_list.i(i + NCHWTopPadPos));
+      }
+    }
+  }
+  return RET_OK;
+}
+
 STATUS TFConvBaseParser::ParseDilations(const tensorflow::NodeDef &node_def, const mindspore::Format &format,
                                         std::vector<int64_t> *dilations) {
   tensorflow::AttrValue attr_value;
@@ -87,6 +122,8 @@ mindspore::PadMode TFConvBaseParser::ParsePadMode(const tensorflow::NodeDef &nod
   }
   if (attr_value.s() == "SAME") {
     return mindspore::PadMode::SAME;
+  } else if (attr_value.s() == "EXPLICIT") {
+    return mindspore::PadMode::PAD;
   }
   return mindspore::PadMode::VALID;
 }

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,6 +34,8 @@ class TFConvBaseParser : public TFNodeParser {
                              std::vector<int64_t> *dilations);
   static STATUS ParseKernels(const tensorflow::NodeDef &node_def, const mindspore::Format &format,
                              std::vector<int64_t> *kernel);
+  static STATUS ParseExplicitPaddings(const tensorflow::NodeDef &node_def, const mindspore::Format &format,
+                                      std::vector<int64_t> *explicit_paddings);
   static mindspore::PadMode ParsePadMode(const tensorflow::NodeDef &node_def);
 };
 }  // namespace lite

@@ -68,6 +68,14 @@ ops::PrimitiveC *TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
 
   auto pad_mode = ParsePadMode(tf_op);
   prim->set_pad_mode(pad_mode);
+  if (pad_mode == PadMode::PAD) {
+    std::vector<int64_t> explicit_paddings;
+    if (ParseExplicitPaddings(tf_op, format, &explicit_paddings) != RET_OK) {
+      MS_LOG(ERROR) << "parse explicit paddings attr failed";
+      return nullptr;
+    }
+    prim->set_pad_list(explicit_paddings);
+  }
 
   *output_size = 1;
   if (AddOpInput(tf_op, 0, inputs) != RET_OK || AddOpInput(tf_op, 1, inputs) != RET_OK) {

@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/converter/parser/tf/tf_depth_to_space_parser.h"
+#include <string>
+#include <memory>
+#include <map>
+#include <vector>
+#include "tools/converter/parser/tf/tf_node_parser_registry.h"
+#include "ops/depth_to_space.h"
+
+namespace mindspore {
+namespace lite {
+ops::PrimitiveC *TFDepthToSpaceParser::Parse(const tensorflow::NodeDef &tf_op,
+                                             const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
+                                             std::vector<std::string> *inputs, int *output_size) {
+  auto prim = std::make_unique<ops::DepthToSpace>();
+
+  tensorflow::AttrValue attr_value;
+  if (!TensorFlowUtils::FindAttrValue(tf_op, "block_size", &attr_value)) {
+    MS_LOG(ERROR) << "The block_size attr should be specified";
+    return nullptr;
+  }
+  prim->set_block_size(attr_value.i());
+
+  *output_size = 1;
+  if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
+    MS_LOG(ERROR) << "add op input failed";
+    return nullptr;
+  }
+
+  return prim.release();
+}
+TFNodeRegistrar g_tfDepthToSpaceParser("DepthToSpace", new TFDepthToSpaceParser());
+}  // namespace lite
+}  // namespace mindspore

@@ -0,0 +1,37 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DEPTH_TO_SPACE_PARSER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DEPTH_TO_SPACE_PARSER_H_
+
+#include <string>
+#include <memory>
+#include <map>
+#include <vector>
+#include "tools/converter/parser/tf/tf_conv_base_parser.h"
+namespace mindspore {
+namespace lite {
+class TFDepthToSpaceParser : public TFConvBaseParser {
+ public:
+  TFDepthToSpaceParser() = default;
+  ~TFDepthToSpaceParser() override = default;
+
+  ops::PrimitiveC *Parse(const tensorflow::NodeDef &tf_op,
+                         const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
+                         std::vector<std::string> *inputs, int *output_size) override;
+};
+}  // namespace lite
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_DEPTH_TO_SPACE_PARSER_H_

@@ -0,0 +1,208 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/optimizer/fusion/conv_pad_fusion.h"
+#include <memory>
+#include <vector>
+#include "tools/common/tensor_util.h"
+#include "ops/fusion/pad_fusion.h"
+#include "ops/fusion/conv2d_fusion.h"
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kPadInputsLength = 3;
+constexpr size_t kConvInputIndex = 1;
+constexpr size_t kConvNoBiasLen = 3;
+constexpr size_t kConvWithBiasLen = 4;
+constexpr size_t kFilterDimsSize = 2;
+constexpr size_t NHWCTopPadPos = 2;
+constexpr size_t NCHWTopPadPos = 4;
+constexpr size_t kTop = 0;
+constexpr size_t kBottom = 1;
+constexpr size_t kLeft = 2;
+constexpr size_t kRight = 3;
+constexpr size_t kPadDims = 4;
+
+void ReplaceParamsAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &pad_cnode,
+                           const std::string &pattern_name) {
+  auto paddings = pad_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
+  auto pad_list = std::dynamic_pointer_cast<tensor::Tensor>(paddings->default_param());
+  auto pad_data = static_cast<int32_t *>(pad_list->data_c());
+
+  std::vector<int64_t> pad_list_data;
+  if (pattern_name == "PadConvPatternName") {
+    pad_list_data.push_back(pad_data[kTop + NHWCTopPadPos]);
+    pad_list_data.push_back(pad_data[kBottom + NHWCTopPadPos]);
+    pad_list_data.push_back(pad_data[kLeft + NHWCTopPadPos]);
+    pad_list_data.push_back(pad_data[kRight + NHWCTopPadPos]);
+  } else {
+    pad_list_data.push_back(pad_data[kTop + NCHWTopPadPos]);
+    pad_list_data.push_back(pad_data[kBottom + NCHWTopPadPos]);
+    pad_list_data.push_back(pad_data[kLeft + NCHWTopPadPos]);
+    pad_list_data.push_back(pad_data[kRight + NCHWTopPadPos]);
+  }
+
+  auto conv_primitive = GetValueNode<std::shared_ptr<ops::Conv2DFusion>>(conv_cnode->input(0));
+  MS_ASSERT(conv_primitive != nullptr);
+  int64_t conv_pad_mode = conv_primitive->GetAttr(ops::kPadMode) == nullptr ? 0 : conv_primitive->get_pad_mode();
+  if (conv_pad_mode == PadMode::PAD) {
+    auto pad_list_node = conv_primitive->GetAttr(ops::kPadList);
+    if (pad_list_node != nullptr) {
+      std::vector<int64_t> conv_pad_list = GetValue<std::vector<int64_t>>(pad_list_node);
+      if (conv_pad_list.size() == kPadDims) {
+        pad_list_data[kTop] += conv_pad_list[kTop];
+        pad_list_data[kBottom] += conv_pad_list[kBottom];
+        pad_list_data[kLeft] += conv_pad_list[kLeft];
+        pad_list_data[kRight] += conv_pad_list[kRight];
+      }
+    }
+  } else if (conv_pad_mode == PadMode::SAME) {
+    ValuePtr kernel_node = conv_primitive->GetAttr(ops::kKernelSize);
+    MS_ASSERT(kernel_node != nullptr);
+    std::vector<int64_t> kernel_list = GetValue<std::vector<int64_t>>(kernel_node);
+    if (kernel_list.size() != kFilterDimsSize) {
+      MS_LOG(ERROR) << "Filter Dims should be 2, Fusion failed! ,name:" << conv_cnode->fullname_with_scope();
+      return;
+    } else if (kernel_list[0] == kernel_list[1]) {
+      int64_t pad_size = std::floor(kernel_list[0] / 2);
+      for (size_t i = 0; i < pad_list_data.size(); ++i) {
+        pad_list_data[i] += pad_size;
+      }
+    } else {
+      int64_t top_pad_size = std::floor(kernel_list[0] / 2);
+      int64_t left_pad_size = std::floor(kernel_list[1] / 2);
+      pad_list_data[kTop] += top_pad_size;
+      pad_list_data[kBottom] += top_pad_size;
+      pad_list_data[kLeft] += left_pad_size;
+      pad_list_data[kRight] += left_pad_size;
+    }
+    conv_primitive->set_pad_mode(PadMode::PAD);
+  } else {
+    conv_primitive->set_pad_mode(PadMode::PAD);
+  }
+  conv_primitive->set_pad_list(pad_list_data);
+
+  // delete padFusion
+  auto manager = func_graph->manager();
+  manager->Replace(pad_cnode, pad_cnode->input(1));
+}
+
+bool IsPrimitiveProper(const CNodePtr &conv_cnode, const CNodePtr &pad_cnode) {
+  MS_ASSERT(conv_cnode != nullptr);
+  MS_ASSERT(pad_cnode != nullptr);
+  if (!utils::isa<Parameter>(pad_cnode->input(kInputIndexTwo))) {
+    return false;
+  }
+  auto pad_primitive = GetValueNode<std::shared_ptr<ops::PadFusion>>(pad_cnode->input(0));
+  MS_ASSERT(pad_primitive != nullptr);
+  int64_t pad_mode = pad_primitive->get_padding_mode();
+  if (pad_mode != PaddingMode::CONSTANT) {
+    return false;
+  }
+  ValuePtr pad_constant_node = pad_primitive->GetAttr(ops::kConstantValue);
+  MS_ASSERT(pad_constant_node != nullptr);
+  float pad_value = GetValue<float>(pad_constant_node);
+  if (pad_value != 0) {
+    return false;
+  }
+
+  return true;
+}
+}  // namespace
+
+VectorRef ConvPadFusion::DefinePadConvPattern() const {
+  auto pad_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPadFusion>);
+  auto conv_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimConv2DFusion>);
+  auto weight_var = std::make_shared<CondVar>(IsParamNode);
+  auto bias_var = std::make_shared<SeqVar>();
+  return VectorRef({conv_var, pad_var, weight_var, bias_var});
+}
+
+VectorRef ConvPadFusion::DefinePadTransposeConvPattern() const {
+  auto pad_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPadFusion>);
+  auto transpose_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
+  auto transpose_param = std::make_shared<CondVar>(IsParamNode);
+  VectorRef transpose_conv_ref = VectorRef({transpose_var, pad_var, transpose_param});
+
+  auto conv_var = std::make_shared<CondVar>(IsConvNode);
+  auto weight_var = std::make_shared<CondVar>(IsParamNode);
+  auto bias_var = std::make_shared<SeqVar>();
+  VectorRef trans_conv_ref = VectorRef({conv_var, transpose_conv_ref, weight_var, bias_var});
+  return trans_conv_ref;
+}
+
+std::unordered_map<std::string, VectorRef> ConvPadFusion::DefinePatterns() const {
+  std::unordered_map<std::string, VectorRef> patterns;
+  patterns["PadConvPatternName"] = DefinePadConvPattern();
+  patterns["PadTransposeConvPatternName"] = DefinePadTransposeConvPattern();
+  return patterns;
+}
+
+AnfNodePtr ConvPadFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph,
+                                  const AnfNodePtr &node, const EquivPtr &equiv) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
+    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
+    return nullptr;
+  }
+
+  auto conv_cnode = node->cast<CNodePtr>();
+  MS_ASSERT(conv_cnode != nullptr);
+  if (conv_cnode->inputs().size() != kConvWithBiasLen && conv_cnode->inputs().size() != kConvNoBiasLen) {
+    MS_LOG(WARNING) << "conv node inputs error ,name:" << conv_cnode->fullname_with_scope();
+    return nullptr;
+  }
+  CNodePtr pad_cnode = nullptr;
+  if (pattern_name == "PadTransposeConvPatternName") {
+    CNodePtr transpose_cnode = conv_cnode->input(1)->cast<CNodePtr>();
+    if (IsMultiOutputTensors(func_graph, transpose_cnode)) {
+      MS_LOG(WARNING) << "transpose node is used as input by multiple cnodes, Fusion failed! ,name:"
+                      << transpose_cnode->fullname_with_scope();
+      return nullptr;
+    }
+    MS_ASSERT(transpose_cnode != nullptr);
+    pad_cnode = transpose_cnode->input(1)->cast<CNodePtr>();
+  } else {
+    pad_cnode = conv_cnode->input(1)->cast<CNodePtr>();
+  }
+
+  if (IsMultiOutputTensors(func_graph, pad_cnode)) {
+    MS_LOG(WARNING) << "pad node is used as input by multiple cnodes, Fusion failed! ,name:"
+                    << pad_cnode->fullname_with_scope();
+    return nullptr;
+  }
+
+  MS_ASSERT(pad_cnode != nullptr);
+  if (CheckIfCNodeIsNull(pad_cnode) != lite::RET_OK || CheckInputSize(pad_cnode, kPadInputsLength) != lite::RET_OK) {
+    MS_LOG(WARNING) << "pad node inputs error ,name:" << pad_cnode->fullname_with_scope();
+    return nullptr;
+  }
+
+  if (!IsPrimitiveProper(conv_cnode, pad_cnode)) {
+    MS_LOG(WARNING) << conv_cnode->fullname_with_scope() << " is not match with previous "
+                    << pad_cnode->fullname_with_scope() << " op. Fusion failed!";
+    return nullptr;
+  }
+
+  ReplaceParamsAndNodes(func_graph, conv_cnode, pad_cnode, pattern_name);
+  return nullptr;
+}
+
+}  // namespace opt
+}  // namespace mindspore

@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_PAD_FUSION_H_
+#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_PAD_FUSION_H_
+
+#include <string>
+#include <unordered_map>
+#include "backend/optimizer/common/optimizer.h"
+#include "tools/optimizer/common/multiple_pattern_process_pass.h"
+
+namespace mindspore {
+namespace opt {
+class ConvPadFusion : public MultiplePatternProcessPass {
+ public:
+  explicit ConvPadFusion(const std::string &name = "conv_pad_fusion", bool multigraph = true)
+      : MultiplePatternProcessPass(name, multigraph) {}
+  ~ConvPadFusion() override = default;
+
+  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
+  VectorRef DefinePadConvPattern() const;
+  VectorRef DefinePadTransposeConvPattern() const;
+  AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &,
+                     const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_PAD_FUSION_H_