diff --git a/mindspore/lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc b/mindspore/lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc index 4318d3813ce..5333a090b5c 100644 --- a/mindspore/lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc +++ b/mindspore/lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.cc @@ -16,6 +16,8 @@ #include "tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.h" #include +#include +#include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "include/registry/converter_context.h" @@ -24,6 +26,11 @@ namespace mindspore { namespace lite { +namespace { +constexpr auto kCommonAttrValueNum = 2; +constexpr auto kMaxKernelSize = 20; +constexpr auto kMaxKernelSize_H_Mul_W = 255; +} // namespace STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) { ValueNodePtr value_node = nullptr; PrimitivePtr src_prim = nullptr; @@ -35,17 +42,21 @@ STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) { auto attr_val = src_prim->GetAttr(ops::kFmkType); int fmk_type = attr_val != nullptr ? GetValue(attr_val) : converter::kFmkTypeTf; PrimitivePtr dst_prim = nullptr; + std::string max_pool_name; if (fmk_type == converter::kFmkTypeCaffe) { dst_prim = std::make_shared(); - } else if (fmk_type == converter::kFmkTypeOnnx) { + max_pool_name = acl::kNamePooling; + } else if (fmk_type == converter::kFmkTypeOnnx && IsKernelSizeValid(src_prim)) { dst_prim = std::make_shared(); + max_pool_name = acl::kNameMaxPoolV3; } else { ops::MaxPool max_pool_op; dst_prim = max_pool_op.GetPrim(); + max_pool_name = ops::kNameMaxPool; } MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "Get primitive by fmk type failed."); dst_prim->SetAttrs(src_prim->attrs()); - if (AdjustPoolAttr(fmk_type, kNameMaxPoolFusion, dst_prim) != lite::RET_OK) { + if (AdjustPoolAttr(fmk_type, max_pool_name, dst_prim) != lite::RET_OK) { MS_LOG(ERROR) << "Adjust pool attr failed."; return lite::RET_ERROR; } @@ -53,6 +64,24 @@ STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) { return lite::RET_OK; } +bool MaxPoolFusionMapper::IsKernelSizeValid(const PrimitivePtr &prim) const { + auto attr_val = prim->GetAttr(ops::kKernelSize); + if (attr_val == nullptr) { + return true; + } + auto kernel_value = opt::CastToInt(attr_val); + if (kernel_value.size() != kCommonAttrValueNum) { + MS_LOG(ERROR) << " Kernel size value num must be two, but size is " << kernel_value.size(); + return false; + } + int64_t kernel_h = kernel_value[0]; + int64_t kernel_w = kernel_value[1]; + if ((kernel_h <= kMaxKernelSize && kernel_w <= kMaxKernelSize) || (kernel_h * kernel_w <= kMaxKernelSize_H_Mul_W)) { + return true; + } + return false; +} + REGISTER_PRIMITIVE_MAPPER(kNameMaxPoolFusion, MaxPoolFusionMapper) } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.h b/mindspore/lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.h index 18daebb7c4e..329968a6068 100644 --- a/mindspore/lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.h +++ b/mindspore/lite/tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MAXPOOL_FUSION_MAPPER_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MAXPOOL_FUSION_MAPPER_H_ +#include #include "tools/converter/adapter/acl/mapper/primitive_mapper.h" #include "ops/fusion/max_pool_fusion.h" @@ -30,6 +31,9 @@ class MaxPoolFusionMapper : public PrimitiveMapper { ~MaxPoolFusionMapper() override = default; STATUS Mapper(const CNodePtr &cnode) override; + + private: + bool IsKernelSizeValid(const PrimitivePtr &prim) const; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc b/mindspore/lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc index 4d3890d52d8..1ff54a82c3e 100644 --- a/mindspore/lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc +++ b/mindspore/lite/tools/converter/adapter/acl/mapper/primitive_mapper.cc @@ -19,11 +19,13 @@ #include #include "tools/converter/adapter/acl/common/utils.h" #include "tools/optimizer/common/gllo_utils.h" +#include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "ir/graph_utils.h" #include "include/errorcode.h" #include "include/registry/converter_context.h" #include "ops/op_utils.h" #include "ops/fusion/avg_pool_fusion.h" +#include "ops/fusion/max_pool_fusion.h" #include "plugin/device/cpu/kernel/nnacl/op_base.h" #include "src/common/log_util.h" @@ -110,25 +112,36 @@ void PrimitiveMapper::AdjustCaffePoolAttr(const std::string &src_prim_name, cons dst_prim->AddAttr(ops::kMode, MakeValue(mode)); auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode); + if (run_mode_val == nullptr) { + MS_LOG(INFO) << "There is no attr run mode"; + return; + } auto run_mode = GetValue(run_mode_val); int64_t run_mode_ge = run_mode == RoundMode::FLOOR ? 1 : 0; dst_prim->set_attr(ops::kRoundMode, MakeValue(run_mode_ge)); } -void PrimitiveMapper::AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim) const { +void PrimitiveMapper::AdjustOnnxPoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const { + auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode); + if (pad_mode_val == nullptr) { + MS_LOG(INFO) << "There is no attr pad mode"; + return; + } static std::map kPadModToStrMap = { {PadMode::PAD, "CALCULATED"}, {PadMode::SAME, "SAME"}, {PadMode::VALID, "VALID"}, }; - auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode); auto pad_mode = GetValue(pad_mode_val); std::string padding_mode = "CALCULATED"; if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) { padding_mode = kPadModToStrMap[pad_mode]; } - dst_prim->AddAttr(kNamePaddingMode, MakeValue(padding_mode)); - + if (src_prim_name == ops::kNameMaxPool && padding_mode == "CALCULATED") { + padding_mode = "VALID"; + } + std::string pad_mode_name = src_prim_name == acl::kNameMaxPoolV3 ? kNamePaddingMode : ops::kPadMode; + dst_prim->AddAttr(pad_mode_name, MakeValue(padding_mode)); auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode); int64_t run_mode = GetValue(run_mode_val); bool ceil_mode = run_mode == RoundMode::CEIL; @@ -141,7 +154,7 @@ STATUS PrimitiveMapper::AdjustPoolAttr(int fmk_type, const std::string &src_prim AdjustCaffePoolAttr(src_prim_name, dst_prim); return lite::RET_OK; } else if (fmk_type == converter::kFmkTypeOnnx) { - AdjustOnnxPoolAttr(dst_prim); + AdjustOnnxPoolAttr(src_prim_name, dst_prim); } // adjust common attr MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr."); diff --git a/mindspore/lite/tools/converter/adapter/acl/mapper/primitive_mapper.h b/mindspore/lite/tools/converter/adapter/acl/mapper/primitive_mapper.h index 44e0f0cbd06..57a069e3ccc 100644 --- a/mindspore/lite/tools/converter/adapter/acl/mapper/primitive_mapper.h +++ b/mindspore/lite/tools/converter/adapter/acl/mapper/primitive_mapper.h @@ -52,7 +52,7 @@ class PrimitiveMapper { private: void AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const; - void AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim) const; + void AdjustOnnxPoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const; std::string name_; };