!45180 adapt maxpool for ocr

Merge pull request !45180 from zhengyuanhua/br3
This commit is contained in:
i-robot 2022-11-07 03:17:46 +00:00 committed by Gitee
commit 15671f253c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 54 additions and 8 deletions

View File

@ -16,6 +16,8 @@
#include "tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.h"
#include <memory>
#include <string>
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
#include "include/registry/converter_context.h"
@ -24,6 +26,11 @@
namespace mindspore {
namespace lite {
namespace {
constexpr auto kCommonAttrValueNum = 2;
constexpr auto kMaxKernelSize = 20;
constexpr auto kMaxKernelSize_H_Mul_W = 255;
} // namespace
STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
@ -35,17 +42,21 @@ STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) {
auto attr_val = src_prim->GetAttr(ops::kFmkType);
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
PrimitivePtr dst_prim = nullptr;
std::string max_pool_name;
if (fmk_type == converter::kFmkTypeCaffe) {
dst_prim = std::make_shared<acl::Pooling>();
} else if (fmk_type == converter::kFmkTypeOnnx) {
max_pool_name = acl::kNamePooling;
} else if (fmk_type == converter::kFmkTypeOnnx && IsKernelSizeValid(src_prim)) {
dst_prim = std::make_shared<acl::MaxPoolV3>();
max_pool_name = acl::kNameMaxPoolV3;
} else {
ops::MaxPool max_pool_op;
dst_prim = max_pool_op.GetPrim();
max_pool_name = ops::kNameMaxPool;
}
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "Get primitive by fmk type failed.");
dst_prim->SetAttrs(src_prim->attrs());
if (AdjustPoolAttr(fmk_type, kNameMaxPoolFusion, dst_prim) != lite::RET_OK) {
if (AdjustPoolAttr(fmk_type, max_pool_name, dst_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Adjust pool attr failed.";
return lite::RET_ERROR;
}
@ -53,6 +64,24 @@ STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) {
return lite::RET_OK;
}
bool MaxPoolFusionMapper::IsKernelSizeValid(const PrimitivePtr &prim) const {
auto attr_val = prim->GetAttr(ops::kKernelSize);
if (attr_val == nullptr) {
return true;
}
auto kernel_value = opt::CastToInt(attr_val);
if (kernel_value.size() != kCommonAttrValueNum) {
MS_LOG(ERROR) << " Kernel size value num must be two, but size is " << kernel_value.size();
return false;
}
int64_t kernel_h = kernel_value[0];
int64_t kernel_w = kernel_value[1];
if ((kernel_h <= kMaxKernelSize && kernel_w <= kMaxKernelSize) || (kernel_h * kernel_w <= kMaxKernelSize_H_Mul_W)) {
return true;
}
return false;
}
REGISTER_PRIMITIVE_MAPPER(kNameMaxPoolFusion, MaxPoolFusionMapper)
} // namespace lite
} // namespace mindspore

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MAXPOOL_FUSION_MAPPER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MAXPOOL_FUSION_MAPPER_H_
#include <vector>
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/fusion/max_pool_fusion.h"
@ -30,6 +31,9 @@ class MaxPoolFusionMapper : public PrimitiveMapper {
~MaxPoolFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
private:
bool IsKernelSizeValid(const PrimitivePtr &prim) const;
};
} // namespace lite
} // namespace mindspore

View File

@ -19,11 +19,13 @@
#include <vector>
#include "tools/converter/adapter/acl/common/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
#include "ir/graph_utils.h"
#include "include/errorcode.h"
#include "include/registry/converter_context.h"
#include "ops/op_utils.h"
#include "ops/fusion/avg_pool_fusion.h"
#include "ops/fusion/max_pool_fusion.h"
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "src/common/log_util.h"
@ -110,25 +112,36 @@ void PrimitiveMapper::AdjustCaffePoolAttr(const std::string &src_prim_name, cons
dst_prim->AddAttr(ops::kMode, MakeValue(mode));
auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
if (run_mode_val == nullptr) {
MS_LOG(INFO) << "There is no attr run mode";
return;
}
auto run_mode = GetValue<int64_t>(run_mode_val);
int64_t run_mode_ge = run_mode == RoundMode::FLOOR ? 1 : 0;
dst_prim->set_attr(ops::kRoundMode, MakeValue(run_mode_ge));
}
void PrimitiveMapper::AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim) const {
void PrimitiveMapper::AdjustOnnxPoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const {
auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode);
if (pad_mode_val == nullptr) {
MS_LOG(INFO) << "There is no attr pad mode";
return;
}
static std::map<int64_t, std::string> kPadModToStrMap = {
{PadMode::PAD, "CALCULATED"},
{PadMode::SAME, "SAME"},
{PadMode::VALID, "VALID"},
};
auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode);
auto pad_mode = GetValue<int64_t>(pad_mode_val);
std::string padding_mode = "CALCULATED";
if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) {
padding_mode = kPadModToStrMap[pad_mode];
}
dst_prim->AddAttr(kNamePaddingMode, MakeValue(padding_mode));
if (src_prim_name == ops::kNameMaxPool && padding_mode == "CALCULATED") {
padding_mode = "VALID";
}
std::string pad_mode_name = src_prim_name == acl::kNameMaxPoolV3 ? kNamePaddingMode : ops::kPadMode;
dst_prim->AddAttr(pad_mode_name, MakeValue(padding_mode));
auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
int64_t run_mode = GetValue<int64_t>(run_mode_val);
bool ceil_mode = run_mode == RoundMode::CEIL;
@ -141,7 +154,7 @@ STATUS PrimitiveMapper::AdjustPoolAttr(int fmk_type, const std::string &src_prim
AdjustCaffePoolAttr(src_prim_name, dst_prim);
return lite::RET_OK;
} else if (fmk_type == converter::kFmkTypeOnnx) {
AdjustOnnxPoolAttr(dst_prim);
AdjustOnnxPoolAttr(src_prim_name, dst_prim);
}
// adjust common attr
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");

View File

@ -52,7 +52,7 @@ class PrimitiveMapper {
private:
void AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const;
void AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim) const;
void AdjustOnnxPoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const;
std::string name_;
};