!45180 adapt maxpool for ocr
Merge pull request !45180 from zhengyuanhua/br3
This commit is contained in:
commit
15671f253c
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "tools/converter/adapter/acl/mapper/maxpool_fusion_mapper.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||
#include "include/registry/converter_context.h"
|
||||
|
@ -24,6 +26,11 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr auto kCommonAttrValueNum = 2;
|
||||
constexpr auto kMaxKernelSize = 20;
|
||||
constexpr auto kMaxKernelSize_H_Mul_W = 255;
|
||||
} // namespace
|
||||
STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
|
@ -35,17 +42,21 @@ STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) {
|
|||
auto attr_val = src_prim->GetAttr(ops::kFmkType);
|
||||
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
|
||||
PrimitivePtr dst_prim = nullptr;
|
||||
std::string max_pool_name;
|
||||
if (fmk_type == converter::kFmkTypeCaffe) {
|
||||
dst_prim = std::make_shared<acl::Pooling>();
|
||||
} else if (fmk_type == converter::kFmkTypeOnnx) {
|
||||
max_pool_name = acl::kNamePooling;
|
||||
} else if (fmk_type == converter::kFmkTypeOnnx && IsKernelSizeValid(src_prim)) {
|
||||
dst_prim = std::make_shared<acl::MaxPoolV3>();
|
||||
max_pool_name = acl::kNameMaxPoolV3;
|
||||
} else {
|
||||
ops::MaxPool max_pool_op;
|
||||
dst_prim = max_pool_op.GetPrim();
|
||||
max_pool_name = ops::kNameMaxPool;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "Get primitive by fmk type failed.");
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
if (AdjustPoolAttr(fmk_type, kNameMaxPoolFusion, dst_prim) != lite::RET_OK) {
|
||||
if (AdjustPoolAttr(fmk_type, max_pool_name, dst_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adjust pool attr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
@ -53,6 +64,24 @@ STATUS MaxPoolFusionMapper::Mapper(const CNodePtr &cnode) {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool MaxPoolFusionMapper::IsKernelSizeValid(const PrimitivePtr &prim) const {
|
||||
auto attr_val = prim->GetAttr(ops::kKernelSize);
|
||||
if (attr_val == nullptr) {
|
||||
return true;
|
||||
}
|
||||
auto kernel_value = opt::CastToInt(attr_val);
|
||||
if (kernel_value.size() != kCommonAttrValueNum) {
|
||||
MS_LOG(ERROR) << " Kernel size value num must be two, but size is " << kernel_value.size();
|
||||
return false;
|
||||
}
|
||||
int64_t kernel_h = kernel_value[0];
|
||||
int64_t kernel_w = kernel_value[1];
|
||||
if ((kernel_h <= kMaxKernelSize && kernel_w <= kMaxKernelSize) || (kernel_h * kernel_w <= kMaxKernelSize_H_Mul_W)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_MAPPER(kNameMaxPoolFusion, MaxPoolFusionMapper)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MAXPOOL_FUSION_MAPPER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_MAXPOOL_FUSION_MAPPER_H_
|
||||
|
||||
#include <vector>
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
||||
#include "ops/fusion/max_pool_fusion.h"
|
||||
|
||||
|
@ -30,6 +31,9 @@ class MaxPoolFusionMapper : public PrimitiveMapper {
|
|||
~MaxPoolFusionMapper() override = default;
|
||||
|
||||
STATUS Mapper(const CNodePtr &cnode) override;
|
||||
|
||||
private:
|
||||
bool IsKernelSizeValid(const PrimitivePtr &prim) const;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,11 +19,13 @@
|
|||
#include <vector>
|
||||
#include "tools/converter/adapter/acl/common/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "include/registry/converter_context.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/fusion/avg_pool_fusion.h"
|
||||
#include "ops/fusion/max_pool_fusion.h"
|
||||
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
||||
|
@ -110,25 +112,36 @@ void PrimitiveMapper::AdjustCaffePoolAttr(const std::string &src_prim_name, cons
|
|||
dst_prim->AddAttr(ops::kMode, MakeValue(mode));
|
||||
|
||||
auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
|
||||
if (run_mode_val == nullptr) {
|
||||
MS_LOG(INFO) << "There is no attr run mode";
|
||||
return;
|
||||
}
|
||||
auto run_mode = GetValue<int64_t>(run_mode_val);
|
||||
int64_t run_mode_ge = run_mode == RoundMode::FLOOR ? 1 : 0;
|
||||
dst_prim->set_attr(ops::kRoundMode, MakeValue(run_mode_ge));
|
||||
}
|
||||
|
||||
void PrimitiveMapper::AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim) const {
|
||||
void PrimitiveMapper::AdjustOnnxPoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const {
|
||||
auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode);
|
||||
if (pad_mode_val == nullptr) {
|
||||
MS_LOG(INFO) << "There is no attr pad mode";
|
||||
return;
|
||||
}
|
||||
static std::map<int64_t, std::string> kPadModToStrMap = {
|
||||
{PadMode::PAD, "CALCULATED"},
|
||||
{PadMode::SAME, "SAME"},
|
||||
{PadMode::VALID, "VALID"},
|
||||
};
|
||||
auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode);
|
||||
auto pad_mode = GetValue<int64_t>(pad_mode_val);
|
||||
std::string padding_mode = "CALCULATED";
|
||||
if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) {
|
||||
padding_mode = kPadModToStrMap[pad_mode];
|
||||
}
|
||||
dst_prim->AddAttr(kNamePaddingMode, MakeValue(padding_mode));
|
||||
|
||||
if (src_prim_name == ops::kNameMaxPool && padding_mode == "CALCULATED") {
|
||||
padding_mode = "VALID";
|
||||
}
|
||||
std::string pad_mode_name = src_prim_name == acl::kNameMaxPoolV3 ? kNamePaddingMode : ops::kPadMode;
|
||||
dst_prim->AddAttr(pad_mode_name, MakeValue(padding_mode));
|
||||
auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
|
||||
int64_t run_mode = GetValue<int64_t>(run_mode_val);
|
||||
bool ceil_mode = run_mode == RoundMode::CEIL;
|
||||
|
@ -141,7 +154,7 @@ STATUS PrimitiveMapper::AdjustPoolAttr(int fmk_type, const std::string &src_prim
|
|||
AdjustCaffePoolAttr(src_prim_name, dst_prim);
|
||||
return lite::RET_OK;
|
||||
} else if (fmk_type == converter::kFmkTypeOnnx) {
|
||||
AdjustOnnxPoolAttr(dst_prim);
|
||||
AdjustOnnxPoolAttr(src_prim_name, dst_prim);
|
||||
}
|
||||
// adjust common attr
|
||||
MS_CHECK_TRUE_MSG(dst_prim != nullptr, lite::RET_ERROR, "dst_prim is nullptr.");
|
||||
|
|
|
@ -52,7 +52,7 @@ class PrimitiveMapper {
|
|||
private:
|
||||
void AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const;
|
||||
|
||||
void AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim) const;
|
||||
void AdjustOnnxPoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) const;
|
||||
|
||||
std::string name_;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue