!25692 310 support ocr model

Merge pull request !25692 from zhengyuanhua/model_br1
This commit is contained in:
i-robot 2021-11-01 02:36:41 +00:00 committed by Gitee
commit 027adf9fff
20 changed files with 346 additions and 379 deletions

View File

@ -910,6 +910,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
} else if (vars_[name] != nullptr) { } else if (vars_[name] != nullptr) {
MS_LOG(INFO) << "add var input " << it->ToString(); MS_LOG(INFO) << "add var input " << it->ToString();
auto op = Convert(it); auto op = Convert(it);
UpdateConstOpDesc(it, vars_[name]);
MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(op);
inputs.push_back(*op); inputs.push_back(*op);
} }
@ -942,6 +943,40 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
return *this; return *this;
} }
void DfGraphConvertor::UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
if (!it->isa<Parameter>()) {
MS_LOG(DEBUG) << "It is not parameter, name: " << it->DebugString();
return;
}
auto para = it->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para);
std::string format = kOpFormat_NCHW;
std::string param_debug_info = para->DebugString();
auto param_format = param_format_.find(param_debug_info);
if (param_format != param_format_.end()) {
format = param_format->second;
MS_LOG(DEBUG) << "Parameter debug info: " << param_debug_info << ", format is " << format;
}
if (format == kOpFormat_NCHW) {
MS_LOG(DEBUG) << "Format is not changed, no need to update op desc, name: " << param_debug_info;
return;
}
if (!para->has_default()) {
MS_LOG(DEBUG) << "Parameter has no default, no need to update op desc, name: " << param_debug_info;
return;
}
auto value = para->default_param();
MS_EXCEPTION_IF_NULL(value);
auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
MS_EXCEPTION_IF_NULL(tensor);
auto const_op_desc = TransformUtil::GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
if (const_op_desc == nullptr) {
MS_LOG(WARNING) << "Create parameter " << para->name() << " output descriptor failed!";
return;
}
(void)std::static_pointer_cast<Constant>(op)->update_output_desc_y(*const_op_desc);
}
void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const { void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
auto node = std::static_pointer_cast<AnfNode>(it); auto node = std::static_pointer_cast<AnfNode>(it);
if (node == nullptr) { if (node == nullptr) {
@ -1543,6 +1578,27 @@ void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
op_cache_[node.get()] = op; op_cache_[node.get()] = op;
} }
void DfGraphConvertor::ConvertConv2D(const CNodePtr node) {
MS_EXCEPTION_IF_NULL(node);
OpAdapterPtr adpt = FindAdapter(node, training_);
if (adpt == nullptr) {
return;
}
auto op = adpt->generate(node);
MS_EXCEPTION_IF_NULL(op);
auto value_node = node->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(value_node->value());
auto primitive = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(primitive);
auto value = primitive->GetAttr("padding");
if (value != nullptr) {
std::string pad_mode = GetValue<std::string>(value);
(void)op->SetAttr("padding", pad_mode);
}
op_cache_[node.get()] = op;
}
AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) { AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
const int TUPLE_GET_ITEM_INDEX = 2; const int TUPLE_GET_ITEM_INDEX = 2;
if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs
@ -1735,6 +1791,12 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
return true; return true;
} }
// Add attr pad mode to Conv2D
if (name == prim::kPrimConv2D->name() || name == prim::kPrimDepthwiseConv2dNative->name()) {
ConvertConv2D(node);
return true;
}
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
if (name == prim::kPrimMakeTuple->name()) { if (name == prim::kPrimMakeTuple->name()) {
ConvertMakeTuple(node); ConvertMakeTuple(node);
@ -1813,7 +1875,7 @@ void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &attr.second); CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &attr.second);
} }
std::string format = attr.second->ToString(); std::string format = attr.second->ToString();
if (format != "NCDHW") { if (format != "NCDHW" && format != "NHWC") {
break; break;
} }
for (size_t i = 1; i < node->size(); i++) { for (size_t i = 1; i < node->size(); i++) {

View File

@ -164,6 +164,7 @@ class DfGraphConvertor {
void ConvertMakeTuple(const CNodePtr node); void ConvertMakeTuple(const CNodePtr node);
void ConvertTopK(const CNodePtr node); void ConvertTopK(const CNodePtr node);
void ConvertReshape(const CNodePtr node); void ConvertReshape(const CNodePtr node);
void ConvertConv2D(const CNodePtr node);
std::vector<int64_t> CastToInt(const ValuePtr &value); std::vector<int64_t> CastToInt(const ValuePtr &value);
bool CheckCNode(const std::string &name, const CNodePtr node); bool CheckCNode(const std::string &name, const CNodePtr node);
void TraceOutput(AnfNodePtr node); void TraceOutput(AnfNodePtr node);
@ -177,6 +178,7 @@ class DfGraphConvertor {
void BuildSaveCheckpointGraph(); void BuildSaveCheckpointGraph();
void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
void UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
void AddGraphConstInput(const OperatorPtr &op); void AddGraphConstInput(const OperatorPtr &op);
OperatorPtr ToOperatorPtr(const AnfNodePtr &node); OperatorPtr ToOperatorPtr(const AnfNodePtr &node);
bool IsSourceEdgeNode(const AnfNodePtr &node); bool IsSourceEdgeNode(const AnfNodePtr &node);

View File

@ -36,6 +36,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
static const std::set<std::string> kDevice = {"Ascend310", "Ascend710"}; static const std::set<std::string> kDevice = {"Ascend310", "Ascend710"};
static const std::map<int64_t, std::string> kEnumFormatToStrMap = {{Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}};
namespace { namespace {
constexpr auto kMakeTuple = "MakeTuple"; constexpr auto kMakeTuple = "MakeTuple";
constexpr auto kOutputNames = "outputs_names"; constexpr auto kOutputNames = "outputs_names";
@ -51,6 +52,50 @@ constexpr auto kDelRedundantTranspose = "DeleteRedundantTranspose";
constexpr size_t kDependInputNum = 3; constexpr size_t kDependInputNum = 3;
constexpr size_t kDependFirstInputIdx = 1; constexpr size_t kDependFirstInputIdx = 1;
constexpr size_t kTupleGetItemFirstInputIdx = 1; constexpr size_t kTupleGetItemFirstInputIdx = 1;
STATUS PreProcForMindIr(const FuncGraphPtr &func_graph) { return lite::RET_OK; }
STATUS PreProcForTF(const FuncGraphPtr &func_graph) {
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass})) {
MS_LOG(ERROR) << "Infer shape pass failed.";
return lite::RET_ERROR;
}
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
CHECK_NULL_RETURN(node);
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
CHECK_NULL_RETURN(cnode);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
CHECK_NULL_RETURN(prim);
if (prim->GetAttr(ops::kFormat) != nullptr) {
auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
if (kEnumFormatToStrMap.find(node_format) != kEnumFormatToStrMap.end()) {
std::string format = kEnumFormatToStrMap.at(node_format);
prim->AddAttr("io_format", MakeValue(format));
}
}
}
return lite::RET_OK;
}
STATUS PreProcForCaffe(const FuncGraphPtr &func_graph) {
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
MS_LOG(ERROR) << "To nchw format failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS PreProcForOnnx(const FuncGraphPtr &func_graph) {
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
MS_LOG(ERROR) << "To nchw format failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
} // namespace } // namespace
AclPassImpl::AclPassImpl(const converter::Flags &config) AclPassImpl::AclPassImpl(const converter::Flags &config)
@ -174,15 +219,21 @@ STATUS AclPassImpl::PreProcGraph(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "Common pass failed."; MS_LOG(ERROR) << "Common pass failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
} }
if (fmk_type_ == converter::kFmkTypeMs) { std::map<converter::FmkType, std::function<STATUS(const FuncGraphPtr &)>> fmk_proc_func = {
MS_LOG(DEBUG) << "MindIr no need to change format."; {converter::kFmkTypeMs, PreProcForMindIr},
return lite::RET_OK; {converter::kFmkTypeTf, PreProcForTF},
} {converter::kFmkTypeCaffe, PreProcForCaffe},
// The format of nodes (cnode, parameter, val) must be nchw due to interface of convert om {converter::kFmkTypeOnnx, PreProcForOnnx},
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) { };
MS_LOG(ERROR) << "To nchw format failed."; if (fmk_proc_func.find(fmk_type_) != fmk_proc_func.end()) {
auto func = fmk_proc_func.at(fmk_type_);
if (func(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "Pre proc failed, fmk " << fmk_type_;
return lite::RET_ERROR; return lite::RET_ERROR;
} }
} else {
MS_LOG(WARNING) << "Not support fmk type " << fmk_type_;
}
MS_LOG(DEBUG) << "Pre proc graph success."; MS_LOG(DEBUG) << "Pre proc graph success.";
return lite::RET_OK; return lite::RET_OK;
} }

View File

@ -1,34 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/add_fusion_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
namespace mindspore {
namespace lite {
STATUS AddFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Add>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "AddFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameAddFusion, AddFusionMapper)
} // namespace lite
} // namespace mindspore

View File

@ -1,37 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_MAPPER_PRIMITIVE_ADDFUSION_MAPPER_H
#define ACL_MAPPER_PRIMITIVE_ADDFUSION_MAPPER_H
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/fusion/add_fusion.h"
using mindspore::ops::kNameAddFusion;
namespace mindspore {
namespace lite {
class AddFusionMapper : public PrimitiveMapper {
public:
AddFusionMapper() : PrimitiveMapper(kNameAddFusion) {}
~AddFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_MAPPER_PRIMITIVE_ADDFUSION_MAPPER_H

View File

@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/arithmetic_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
namespace mindspore {
namespace lite {
STATUS AddFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Add>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "AddFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS DivFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Div>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "DivFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS MulFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Mul>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "MulFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS PowFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Pow>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "PowFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS SubFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Sub>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "SubFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameAddFusion, AddFusionMapper)
REGISTER_PRIMITIVE_MAPPER(kNameDivFusion, DivFusionMapper)
REGISTER_PRIMITIVE_MAPPER(kNameMulFusion, MulFusionMapper)
REGISTER_PRIMITIVE_MAPPER(kNamePowFusion, PowFusionMapper)
REGISTER_PRIMITIVE_MAPPER(kNameSubFusion, SubFusionMapper)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,82 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_MAPPER_PRIMITIVE_ARITHMETIC_MAPPER_H
#define ACL_MAPPER_PRIMITIVE_ARITHMETIC_MAPPER_H
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/div_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/pow_fusion.h"
#include "ops/fusion/sub_fusion.h"
namespace mindspore {
namespace lite {
using mindspore::ops::kNameAddFusion;
using mindspore::ops::kNameDivFusion;
using mindspore::ops::kNameMulFusion;
using mindspore::ops::kNamePowFusion;
using mindspore::ops::kNameSubFusion;
class AddFusionMapper : public PrimitiveMapper {
public:
AddFusionMapper() : PrimitiveMapper(kNameAddFusion) {}
~AddFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
class DivFusionMapper : public PrimitiveMapper {
public:
DivFusionMapper() : PrimitiveMapper(kNameDivFusion) {}
~DivFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
class MulFusionMapper : public PrimitiveMapper {
public:
MulFusionMapper() : PrimitiveMapper(kNameMulFusion) {}
~MulFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
class PowFusionMapper : public PrimitiveMapper {
public:
PowFusionMapper() : PrimitiveMapper(kNamePowFusion) {}
~PowFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
class SubFusionMapper : public PrimitiveMapper {
public:
SubFusionMapper() : PrimitiveMapper(kNameSubFusion) {}
~SubFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_MAPPER_PRIMITIVE_ARITHMETIC_MAPPER_H

View File

@ -15,11 +15,25 @@
*/ */
#include "tools/converter/adapter/acl/mapper/conv2d_fusion_mapper.h" #include "tools/converter/adapter/acl/mapper/conv2d_fusion_mapper.h"
#include <vector>
#include <map>
#include <string>
#include "memory" #include "memory"
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
#include "src/common/log_util.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
static const std::map<int64_t, std::string> kPadModToStrMap = {
{PadMode::PAD, "CALCULATED"},
{PadMode::SAME, "SAME"},
{PadMode::VALID, "VALID"},
};
namespace {
constexpr auto kNamePaddingMode = "padding";
} // namespace
STATUS Conv2DFusionMapper::Mapper(const CNodePtr &cnode) { STATUS Conv2DFusionMapper::Mapper(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr; ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr; PrimitivePtr src_prim = nullptr;
@ -27,10 +41,19 @@ STATUS Conv2DFusionMapper::Mapper(const CNodePtr &cnode) {
MS_LOG(ERROR) << "Get primitive from cnode failed."; MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
} }
auto dst_prim = std::make_shared<ops::Conv2D>(); bool is_depth_wise = false;
MS_ASSERT(dst_prim != nullptr); auto depth_wise_ptr = src_prim->GetAttr(ops::kIsDepthWise);
if (depth_wise_ptr != nullptr) {
is_depth_wise = GetValue<bool>(depth_wise_ptr);
}
PrimitivePtr dst_prim = nullptr;
if (!is_depth_wise) {
dst_prim = std::make_shared<ops::Conv2D>();
} else {
dst_prim = std::make_shared<acl::DepthwiseConv2dNative>();
}
CHECK_NULL_RETURN(dst_prim);
dst_prim->SetAttrs(src_prim->attrs()); dst_prim->SetAttrs(src_prim->attrs());
auto status = AttrAdjust(dst_prim, ops::kStride); auto status = AttrAdjust(dst_prim, ops::kStride);
if (status != lite::RET_OK) { if (status != lite::RET_OK) {
MS_LOG(ERROR) << "adjust stride failed."; MS_LOG(ERROR) << "adjust stride failed.";
@ -41,10 +64,34 @@ STATUS Conv2DFusionMapper::Mapper(const CNodePtr &cnode) {
MS_LOG(ERROR) << "adjust dilation failed."; MS_LOG(ERROR) << "adjust dilation failed.";
return status; return status;
} }
status = AdjustAttrPad(dst_prim);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "adjust pad failed.";
return status;
}
value_node->set_value(dst_prim); value_node->set_value(dst_prim);
return lite::RET_OK; return lite::RET_OK;
} }
STATUS Conv2DFusionMapper::AdjustAttrPad(const PrimitivePtr &prim) {
// attr pad val
auto pad_ptr = prim->GetAttr(ops::kPadList);
if (pad_ptr == nullptr) {
std::vector<int64_t> pad_list = {0, 0, 0, 0};
prim->AddAttr(ops::kPadList, MakeValue(pad_list));
}
// attr pad mode
auto pad_mode_val = prim->GetAttr(ops::kPadMode);
if (pad_mode_val != nullptr) {
auto pad_mode = GetValue<int64_t>(pad_mode_val);
if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) {
std::string padding_mode = kPadModToStrMap.at(pad_mode);
prim->AddAttr(kNamePaddingMode, MakeValue(padding_mode));
}
}
return lite::RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameConv2DFusion, Conv2DFusionMapper) REGISTER_PRIMITIVE_MAPPER(kNameConv2DFusion, Conv2DFusionMapper)
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -30,6 +30,9 @@ class Conv2DFusionMapper : public PrimitiveMapper {
~Conv2DFusionMapper() override = default; ~Conv2DFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override; STATUS Mapper(const CNodePtr &cnode) override;
private:
STATUS AdjustAttrPad(const PrimitivePtr &prim);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -1,35 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/div_fusion_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "src/common/log_util.h"
namespace mindspore {
namespace lite {
STATUS DivFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Div>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "DivFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameDivFusion, DivFusionMapper)
} // namespace lite
} // namespace mindspore

View File

@ -1,37 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_MAPPER_PRIMITIVE_DIV_FUSION_MAPPER_H
#define ACL_MAPPER_PRIMITIVE_DIV_FUSION_MAPPER_H
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/fusion/div_fusion.h"
namespace mindspore {
namespace lite {
using mindspore::ops::kNameDivFusion;
class DivFusionMapper : public PrimitiveMapper {
public:
DivFusionMapper() : PrimitiveMapper(kNameDivFusion) {}
~DivFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_MAPPER_PRIMITIVE_DIV_FUSION_MAPPER_H

View File

@ -1,34 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/mul_fusion_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
namespace mindspore {
namespace lite {
STATUS MulFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Mul>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "MulFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameMulFusion, MulFusionMapper)
} // namespace lite
} // namespace mindspore

View File

@ -1,37 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_MAPPER_PRIMITIVE_MULFUSION_MAPPER_H
#define ACL_MAPPER_PRIMITIVE_MULFUSION_MAPPER_H
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/fusion/mul_fusion.h"
using mindspore::ops::kNameMulFusion;
namespace mindspore {
namespace lite {
class MulFusionMapper : public PrimitiveMapper {
public:
MulFusionMapper() : PrimitiveMapper(kNameMulFusion) {}
~MulFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_MAPPER_PRIMITIVE_MULFUSION_MAPPER_H

View File

@ -1,35 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/pow_fusion_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "src/common/log_util.h"
namespace mindspore {
namespace lite {
STATUS PowFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Pow>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "PowFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNamePowFusion, PowFusionMapper)
} // namespace lite
} // namespace mindspore

View File

@ -1,37 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_MAPPER_PRIMITIVE_POW_FUSION_MAPPER_H
#define ACL_MAPPER_PRIMITIVE_POW_FUSION_MAPPER_H
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/fusion/pow_fusion.h"
namespace mindspore {
namespace lite {
using mindspore::ops::kNamePowFusion;
class PowFusionMapper : public PrimitiveMapper {
public:
PowFusionMapper() : PrimitiveMapper(kNamePowFusion) {}
~PowFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_MAPPER_PRIMITIVE_POW_FUSION_MAPPER_H

View File

@ -89,11 +89,15 @@ STATUS PrimitiveMapper::AttrAdjust(const PrimitivePtr &prim, const std::string &
MS_LOG(ERROR) << name << " Value num must be two."; MS_LOG(ERROR) << name << " Value num must be two.";
return lite::RET_ERROR; return lite::RET_ERROR;
} }
std::vector<int64_t> new_value; int64_t format = Format::NCHW;
new_value.push_back(1); if (prim->GetAttr(ops::kFormat) != nullptr) {
new_value.push_back(1); format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
new_value.push_back(static_cast<int64_t>(origin_value[0])); }
new_value.push_back(static_cast<int64_t>(origin_value[1])); std::vector<int64_t> new_value = {1, 1, static_cast<int64_t>(origin_value[0]), static_cast<int64_t>(origin_value[1])};
if (format == Format::NHWC) {
std::vector<int64_t> tmp = {1, static_cast<int64_t>(origin_value[0]), static_cast<int64_t>(origin_value[1]), 1};
new_value.swap(tmp);
}
prim->AddAttr(name, MakeValue(new_value)); prim->AddAttr(name, MakeValue(new_value));
return lite::RET_OK; return lite::RET_OK;
} }

View File

@ -87,14 +87,14 @@ static STATUS AdapteNodeWithMultiOutputs(const FuncGraphPtr &func_graph, const C
auto input_cnode = input->cast<CNodePtr>(); auto input_cnode = input->cast<CNodePtr>();
std::string input_func_name = GetCNodeFuncName(input_cnode); std::string input_func_name = GetCNodeFuncName(input_cnode);
if (kCNodeWithMultiOutputs.find(input_func_name) != kCNodeWithMultiOutputs.end()) { if (kCNodeWithMultiOutputs.find(input_func_name) != kCNodeWithMultiOutputs.end()) {
MS_LOG(INFO) << "Adapter cnode with multioutputs: " << cnode_func_name; MS_LOG(INFO) << "Input " << input_func_name << " of cnode " << cnode_func_name << " has multioutputs";
CNodePtr get_item_cnode = CreateTupleGetItemNode(func_graph, input_cnode); CNodePtr get_item_cnode = CreateTupleGetItemNode(func_graph, input_cnode);
if (get_item_cnode == nullptr) { if (get_item_cnode == nullptr) {
MS_LOG(ERROR) << "Create tuple item for " << cnode_func_name << " failed."; MS_LOG(ERROR) << "Create tuple item for " << input_func_name << " of " << cnode_func_name << " failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
} }
if (!manager->Replace(input_cnode, get_item_cnode)) { if (!manager->Replace(input_cnode, get_item_cnode)) {
MS_LOG(ERROR) << "Replace " << cnode_func_name << " failed."; MS_LOG(ERROR) << "Replace " << input_func_name << " of " << cnode_func_name << " failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
} }
} }

View File

@ -1,35 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/sub_fusion_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "src/common/log_util.h"
namespace mindspore {
namespace lite {
STATUS SubFusionMapper::Mapper(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Sub>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "SubFusion mapper failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameSubFusion, SubFusionMapper)
} // namespace lite
} // namespace mindspore

View File

@ -1,37 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_MAPPER_PRIMITIVE_SUB_FUSION_MAPPER_H
#define ACL_MAPPER_PRIMITIVE_SUB_FUSION_MAPPER_H
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/fusion/sub_fusion.h"
namespace mindspore {
namespace lite {
using mindspore::ops::kNameSubFusion;
class SubFusionMapper : public PrimitiveMapper {
public:
SubFusionMapper() : PrimitiveMapper(kNameSubFusion) {}
~SubFusionMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_MAPPER_PRIMITIVE_SUB_FUSION_MAPPER_H

View File

@ -41,7 +41,7 @@ ADD_CONVERTER_TBE_OP(BNInference)
ADD_CONVERTER_TBE_OP(Deconvolution) ADD_CONVERTER_TBE_OP(Deconvolution)
ADD_CONVERTER_TBE_OP(Upsample) ADD_CONVERTER_TBE_OP(Upsample)
ADD_CONVERTER_TBE_OP(Conv2DTransposeD) ADD_CONVERTER_TBE_OP(Conv2DTransposeD)
ADD_CONVERTER_TBE_OP(DepthwiseConv2dNative)
} // namespace acl } // namespace acl
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore