forked from mindspore-Ecosystem/mindspore
!25692 310 support ocr model
Merge pull request !25692 from zhengyuanhua/model_br1
This commit is contained in:
commit
027adf9fff
|
@ -910,6 +910,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
||||||
} else if (vars_[name] != nullptr) {
|
} else if (vars_[name] != nullptr) {
|
||||||
MS_LOG(INFO) << "add var input " << it->ToString();
|
MS_LOG(INFO) << "add var input " << it->ToString();
|
||||||
auto op = Convert(it);
|
auto op = Convert(it);
|
||||||
|
UpdateConstOpDesc(it, vars_[name]);
|
||||||
MS_EXCEPTION_IF_NULL(op);
|
MS_EXCEPTION_IF_NULL(op);
|
||||||
inputs.push_back(*op);
|
inputs.push_back(*op);
|
||||||
}
|
}
|
||||||
|
@ -942,6 +943,40 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DfGraphConvertor::UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
|
||||||
|
if (!it->isa<Parameter>()) {
|
||||||
|
MS_LOG(DEBUG) << "It is not parameter, name: " << it->DebugString();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto para = it->cast<ParameterPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(para);
|
||||||
|
std::string format = kOpFormat_NCHW;
|
||||||
|
std::string param_debug_info = para->DebugString();
|
||||||
|
auto param_format = param_format_.find(param_debug_info);
|
||||||
|
if (param_format != param_format_.end()) {
|
||||||
|
format = param_format->second;
|
||||||
|
MS_LOG(DEBUG) << "Parameter debug info: " << param_debug_info << ", format is " << format;
|
||||||
|
}
|
||||||
|
if (format == kOpFormat_NCHW) {
|
||||||
|
MS_LOG(DEBUG) << "Format is not changed, no need to update op desc, name: " << param_debug_info;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!para->has_default()) {
|
||||||
|
MS_LOG(DEBUG) << "Parameter has no default, no need to update op desc, name: " << param_debug_info;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto value = para->default_param();
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
auto const_op_desc = TransformUtil::GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
|
||||||
|
if (const_op_desc == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "Create parameter " << para->name() << " output descriptor failed!";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
(void)std::static_pointer_cast<Constant>(op)->update_output_desc_y(*const_op_desc);
|
||||||
|
}
|
||||||
|
|
||||||
void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
|
void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
|
||||||
auto node = std::static_pointer_cast<AnfNode>(it);
|
auto node = std::static_pointer_cast<AnfNode>(it);
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
|
@ -1543,6 +1578,27 @@ void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
|
||||||
op_cache_[node.get()] = op;
|
op_cache_[node.get()] = op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DfGraphConvertor::ConvertConv2D(const CNodePtr node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
OpAdapterPtr adpt = FindAdapter(node, training_);
|
||||||
|
if (adpt == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto op = adpt->generate(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(op);
|
||||||
|
auto value_node = node->input(0)->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node->value());
|
||||||
|
auto primitive = value_node->value()->cast<PrimitivePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto value = primitive->GetAttr("padding");
|
||||||
|
if (value != nullptr) {
|
||||||
|
std::string pad_mode = GetValue<std::string>(value);
|
||||||
|
(void)op->SetAttr("padding", pad_mode);
|
||||||
|
}
|
||||||
|
op_cache_[node.get()] = op;
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
|
AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
|
||||||
const int TUPLE_GET_ITEM_INDEX = 2;
|
const int TUPLE_GET_ITEM_INDEX = 2;
|
||||||
if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs
|
if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs
|
||||||
|
@ -1735,6 +1791,12 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add attr pad mode to Conv2D
|
||||||
|
if (name == prim::kPrimConv2D->name() || name == prim::kPrimDepthwiseConv2dNative->name()) {
|
||||||
|
ConvertConv2D(node);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
|
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
|
||||||
if (name == prim::kPrimMakeTuple->name()) {
|
if (name == prim::kPrimMakeTuple->name()) {
|
||||||
ConvertMakeTuple(node);
|
ConvertMakeTuple(node);
|
||||||
|
@ -1813,7 +1875,7 @@ void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
|
||||||
CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &attr.second);
|
CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &attr.second);
|
||||||
}
|
}
|
||||||
std::string format = attr.second->ToString();
|
std::string format = attr.second->ToString();
|
||||||
if (format != "NCDHW") {
|
if (format != "NCDHW" && format != "NHWC") {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
for (size_t i = 1; i < node->size(); i++) {
|
for (size_t i = 1; i < node->size(); i++) {
|
||||||
|
|
|
@ -164,6 +164,7 @@ class DfGraphConvertor {
|
||||||
void ConvertMakeTuple(const CNodePtr node);
|
void ConvertMakeTuple(const CNodePtr node);
|
||||||
void ConvertTopK(const CNodePtr node);
|
void ConvertTopK(const CNodePtr node);
|
||||||
void ConvertReshape(const CNodePtr node);
|
void ConvertReshape(const CNodePtr node);
|
||||||
|
void ConvertConv2D(const CNodePtr node);
|
||||||
std::vector<int64_t> CastToInt(const ValuePtr &value);
|
std::vector<int64_t> CastToInt(const ValuePtr &value);
|
||||||
bool CheckCNode(const std::string &name, const CNodePtr node);
|
bool CheckCNode(const std::string &name, const CNodePtr node);
|
||||||
void TraceOutput(AnfNodePtr node);
|
void TraceOutput(AnfNodePtr node);
|
||||||
|
@ -177,6 +178,7 @@ class DfGraphConvertor {
|
||||||
void BuildSaveCheckpointGraph();
|
void BuildSaveCheckpointGraph();
|
||||||
void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
|
void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt);
|
||||||
void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
|
void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
|
||||||
|
void UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const;
|
||||||
void AddGraphConstInput(const OperatorPtr &op);
|
void AddGraphConstInput(const OperatorPtr &op);
|
||||||
OperatorPtr ToOperatorPtr(const AnfNodePtr &node);
|
OperatorPtr ToOperatorPtr(const AnfNodePtr &node);
|
||||||
bool IsSourceEdgeNode(const AnfNodePtr &node);
|
bool IsSourceEdgeNode(const AnfNodePtr &node);
|
||||||
|
|
|
@ -36,6 +36,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
static const std::set<std::string> kDevice = {"Ascend310", "Ascend710"};
|
static const std::set<std::string> kDevice = {"Ascend310", "Ascend710"};
|
||||||
|
static const std::map<int64_t, std::string> kEnumFormatToStrMap = {{Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}};
|
||||||
namespace {
|
namespace {
|
||||||
constexpr auto kMakeTuple = "MakeTuple";
|
constexpr auto kMakeTuple = "MakeTuple";
|
||||||
constexpr auto kOutputNames = "outputs_names";
|
constexpr auto kOutputNames = "outputs_names";
|
||||||
|
@ -51,6 +52,50 @@ constexpr auto kDelRedundantTranspose = "DeleteRedundantTranspose";
|
||||||
constexpr size_t kDependInputNum = 3;
|
constexpr size_t kDependInputNum = 3;
|
||||||
constexpr size_t kDependFirstInputIdx = 1;
|
constexpr size_t kDependFirstInputIdx = 1;
|
||||||
constexpr size_t kTupleGetItemFirstInputIdx = 1;
|
constexpr size_t kTupleGetItemFirstInputIdx = 1;
|
||||||
|
|
||||||
|
STATUS PreProcForMindIr(const FuncGraphPtr &func_graph) { return lite::RET_OK; }
|
||||||
|
|
||||||
|
STATUS PreProcForTF(const FuncGraphPtr &func_graph) {
|
||||||
|
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass})) {
|
||||||
|
MS_LOG(ERROR) << "Infer shape pass failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
for (auto &node : node_list) {
|
||||||
|
CHECK_NULL_RETURN(node);
|
||||||
|
if (!utils::isa<CNodePtr>(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
CHECK_NULL_RETURN(cnode);
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
CHECK_NULL_RETURN(prim);
|
||||||
|
if (prim->GetAttr(ops::kFormat) != nullptr) {
|
||||||
|
auto node_format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||||
|
if (kEnumFormatToStrMap.find(node_format) != kEnumFormatToStrMap.end()) {
|
||||||
|
std::string format = kEnumFormatToStrMap.at(node_format);
|
||||||
|
prim->AddAttr("io_format", MakeValue(format));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATUS PreProcForCaffe(const FuncGraphPtr &func_graph) {
|
||||||
|
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
|
||||||
|
MS_LOG(ERROR) << "To nchw format failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATUS PreProcForOnnx(const FuncGraphPtr &func_graph) {
|
||||||
|
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
|
||||||
|
MS_LOG(ERROR) << "To nchw format failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AclPassImpl::AclPassImpl(const converter::Flags &config)
|
AclPassImpl::AclPassImpl(const converter::Flags &config)
|
||||||
|
@ -174,14 +219,20 @@ STATUS AclPassImpl::PreProcGraph(const FuncGraphPtr &func_graph) {
|
||||||
MS_LOG(ERROR) << "Common pass failed.";
|
MS_LOG(ERROR) << "Common pass failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
if (fmk_type_ == converter::kFmkTypeMs) {
|
std::map<converter::FmkType, std::function<STATUS(const FuncGraphPtr &)>> fmk_proc_func = {
|
||||||
MS_LOG(DEBUG) << "MindIr no need to change format.";
|
{converter::kFmkTypeMs, PreProcForMindIr},
|
||||||
return lite::RET_OK;
|
{converter::kFmkTypeTf, PreProcForTF},
|
||||||
}
|
{converter::kFmkTypeCaffe, PreProcForCaffe},
|
||||||
// The format of nodes (cnode, parameter, val) must be nchw due to interface of convert om
|
{converter::kFmkTypeOnnx, PreProcForOnnx},
|
||||||
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
|
};
|
||||||
MS_LOG(ERROR) << "To nchw format failed.";
|
if (fmk_proc_func.find(fmk_type_) != fmk_proc_func.end()) {
|
||||||
return lite::RET_ERROR;
|
auto func = fmk_proc_func.at(fmk_type_);
|
||||||
|
if (func(func_graph) != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Pre proc failed, fmk " << fmk_type_;
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "Not support fmk type " << fmk_type_;
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Pre proc graph success.";
|
MS_LOG(DEBUG) << "Pre proc graph success.";
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/add_fusion_mapper.h"
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
STATUS AddFusionMapper::Mapper(const CNodePtr &cnode) {
|
|
||||||
auto dst_prim = std::make_shared<ops::Add>();
|
|
||||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "AddFusion mapper failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_MAPPER(kNameAddFusion, AddFusionMapper)
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,37 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef ACL_MAPPER_PRIMITIVE_ADDFUSION_MAPPER_H
|
|
||||||
#define ACL_MAPPER_PRIMITIVE_ADDFUSION_MAPPER_H
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
|
||||||
#include "ops/fusion/add_fusion.h"
|
|
||||||
|
|
||||||
using mindspore::ops::kNameAddFusion;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
class AddFusionMapper : public PrimitiveMapper {
|
|
||||||
public:
|
|
||||||
AddFusionMapper() : PrimitiveMapper(kNameAddFusion) {}
|
|
||||||
|
|
||||||
~AddFusionMapper() override = default;
|
|
||||||
|
|
||||||
STATUS Mapper(const CNodePtr &cnode) override;
|
|
||||||
};
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // ACL_MAPPER_PRIMITIVE_ADDFUSION_MAPPER_H
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "tools/converter/adapter/acl/mapper/arithmetic_mapper.h"
|
||||||
|
#include <memory>
|
||||||
|
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
STATUS AddFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
|
auto dst_prim = std::make_shared<ops::Add>();
|
||||||
|
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "AddFusion mapper failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATUS DivFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
|
auto dst_prim = std::make_shared<ops::Div>();
|
||||||
|
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "DivFusion mapper failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATUS MulFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
|
auto dst_prim = std::make_shared<ops::Mul>();
|
||||||
|
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "MulFusion mapper failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATUS PowFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
|
auto dst_prim = std::make_shared<ops::Pow>();
|
||||||
|
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "PowFusion mapper failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
STATUS SubFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
|
auto dst_prim = std::make_shared<ops::Sub>();
|
||||||
|
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "SubFusion mapper failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_PRIMITIVE_MAPPER(kNameAddFusion, AddFusionMapper)
|
||||||
|
REGISTER_PRIMITIVE_MAPPER(kNameDivFusion, DivFusionMapper)
|
||||||
|
REGISTER_PRIMITIVE_MAPPER(kNameMulFusion, MulFusionMapper)
|
||||||
|
REGISTER_PRIMITIVE_MAPPER(kNamePowFusion, PowFusionMapper)
|
||||||
|
REGISTER_PRIMITIVE_MAPPER(kNameSubFusion, SubFusionMapper)
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,82 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef ACL_MAPPER_PRIMITIVE_ARITHMETIC_MAPPER_H
|
||||||
|
#define ACL_MAPPER_PRIMITIVE_ARITHMETIC_MAPPER_H
|
||||||
|
|
||||||
|
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
||||||
|
#include "ops/fusion/add_fusion.h"
|
||||||
|
#include "ops/fusion/div_fusion.h"
|
||||||
|
#include "ops/fusion/mul_fusion.h"
|
||||||
|
#include "ops/fusion/pow_fusion.h"
|
||||||
|
#include "ops/fusion/sub_fusion.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace lite {
|
||||||
|
using mindspore::ops::kNameAddFusion;
|
||||||
|
using mindspore::ops::kNameDivFusion;
|
||||||
|
using mindspore::ops::kNameMulFusion;
|
||||||
|
using mindspore::ops::kNamePowFusion;
|
||||||
|
using mindspore::ops::kNameSubFusion;
|
||||||
|
|
||||||
|
class AddFusionMapper : public PrimitiveMapper {
|
||||||
|
public:
|
||||||
|
AddFusionMapper() : PrimitiveMapper(kNameAddFusion) {}
|
||||||
|
|
||||||
|
~AddFusionMapper() override = default;
|
||||||
|
|
||||||
|
STATUS Mapper(const CNodePtr &cnode) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DivFusionMapper : public PrimitiveMapper {
|
||||||
|
public:
|
||||||
|
DivFusionMapper() : PrimitiveMapper(kNameDivFusion) {}
|
||||||
|
|
||||||
|
~DivFusionMapper() override = default;
|
||||||
|
|
||||||
|
STATUS Mapper(const CNodePtr &cnode) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MulFusionMapper : public PrimitiveMapper {
|
||||||
|
public:
|
||||||
|
MulFusionMapper() : PrimitiveMapper(kNameMulFusion) {}
|
||||||
|
|
||||||
|
~MulFusionMapper() override = default;
|
||||||
|
|
||||||
|
STATUS Mapper(const CNodePtr &cnode) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class PowFusionMapper : public PrimitiveMapper {
|
||||||
|
public:
|
||||||
|
PowFusionMapper() : PrimitiveMapper(kNamePowFusion) {}
|
||||||
|
|
||||||
|
~PowFusionMapper() override = default;
|
||||||
|
|
||||||
|
STATUS Mapper(const CNodePtr &cnode) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SubFusionMapper : public PrimitiveMapper {
|
||||||
|
public:
|
||||||
|
SubFusionMapper() : PrimitiveMapper(kNameSubFusion) {}
|
||||||
|
|
||||||
|
~SubFusionMapper() override = default;
|
||||||
|
|
||||||
|
STATUS Mapper(const CNodePtr &cnode) override;
|
||||||
|
};
|
||||||
|
} // namespace lite
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // ACL_MAPPER_PRIMITIVE_ARITHMETIC_MAPPER_H
|
|
@ -15,11 +15,25 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/conv2d_fusion_mapper.h"
|
#include "tools/converter/adapter/acl/mapper/conv2d_fusion_mapper.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
#include "memory"
|
#include "memory"
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||||
|
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||||
|
#include "src/common/log_util.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
static const std::map<int64_t, std::string> kPadModToStrMap = {
|
||||||
|
{PadMode::PAD, "CALCULATED"},
|
||||||
|
{PadMode::SAME, "SAME"},
|
||||||
|
{PadMode::VALID, "VALID"},
|
||||||
|
};
|
||||||
|
namespace {
|
||||||
|
constexpr auto kNamePaddingMode = "padding";
|
||||||
|
} // namespace
|
||||||
|
|
||||||
STATUS Conv2DFusionMapper::Mapper(const CNodePtr &cnode) {
|
STATUS Conv2DFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
ValueNodePtr value_node = nullptr;
|
ValueNodePtr value_node = nullptr;
|
||||||
PrimitivePtr src_prim = nullptr;
|
PrimitivePtr src_prim = nullptr;
|
||||||
|
@ -27,10 +41,19 @@ STATUS Conv2DFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
auto dst_prim = std::make_shared<ops::Conv2D>();
|
bool is_depth_wise = false;
|
||||||
MS_ASSERT(dst_prim != nullptr);
|
auto depth_wise_ptr = src_prim->GetAttr(ops::kIsDepthWise);
|
||||||
|
if (depth_wise_ptr != nullptr) {
|
||||||
|
is_depth_wise = GetValue<bool>(depth_wise_ptr);
|
||||||
|
}
|
||||||
|
PrimitivePtr dst_prim = nullptr;
|
||||||
|
if (!is_depth_wise) {
|
||||||
|
dst_prim = std::make_shared<ops::Conv2D>();
|
||||||
|
} else {
|
||||||
|
dst_prim = std::make_shared<acl::DepthwiseConv2dNative>();
|
||||||
|
}
|
||||||
|
CHECK_NULL_RETURN(dst_prim);
|
||||||
dst_prim->SetAttrs(src_prim->attrs());
|
dst_prim->SetAttrs(src_prim->attrs());
|
||||||
|
|
||||||
auto status = AttrAdjust(dst_prim, ops::kStride);
|
auto status = AttrAdjust(dst_prim, ops::kStride);
|
||||||
if (status != lite::RET_OK) {
|
if (status != lite::RET_OK) {
|
||||||
MS_LOG(ERROR) << "adjust stride failed.";
|
MS_LOG(ERROR) << "adjust stride failed.";
|
||||||
|
@ -41,10 +64,34 @@ STATUS Conv2DFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
MS_LOG(ERROR) << "adjust dilation failed.";
|
MS_LOG(ERROR) << "adjust dilation failed.";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
status = AdjustAttrPad(dst_prim);
|
||||||
|
if (status != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "adjust pad failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
value_node->set_value(dst_prim);
|
value_node->set_value(dst_prim);
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
STATUS Conv2DFusionMapper::AdjustAttrPad(const PrimitivePtr &prim) {
|
||||||
|
// attr pad val
|
||||||
|
auto pad_ptr = prim->GetAttr(ops::kPadList);
|
||||||
|
if (pad_ptr == nullptr) {
|
||||||
|
std::vector<int64_t> pad_list = {0, 0, 0, 0};
|
||||||
|
prim->AddAttr(ops::kPadList, MakeValue(pad_list));
|
||||||
|
}
|
||||||
|
// attr pad mode
|
||||||
|
auto pad_mode_val = prim->GetAttr(ops::kPadMode);
|
||||||
|
if (pad_mode_val != nullptr) {
|
||||||
|
auto pad_mode = GetValue<int64_t>(pad_mode_val);
|
||||||
|
if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) {
|
||||||
|
std::string padding_mode = kPadModToStrMap.at(pad_mode);
|
||||||
|
prim->AddAttr(kNamePaddingMode, MakeValue(padding_mode));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_MAPPER(kNameConv2DFusion, Conv2DFusionMapper)
|
REGISTER_PRIMITIVE_MAPPER(kNameConv2DFusion, Conv2DFusionMapper)
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -30,6 +30,9 @@ class Conv2DFusionMapper : public PrimitiveMapper {
|
||||||
~Conv2DFusionMapper() override = default;
|
~Conv2DFusionMapper() override = default;
|
||||||
|
|
||||||
STATUS Mapper(const CNodePtr &cnode) override;
|
STATUS Mapper(const CNodePtr &cnode) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
STATUS AdjustAttrPad(const PrimitivePtr &prim);
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/div_fusion_mapper.h"
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
|
||||||
#include "src/common/log_util.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
STATUS DivFusionMapper::Mapper(const CNodePtr &cnode) {
|
|
||||||
auto dst_prim = std::make_shared<ops::Div>();
|
|
||||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "DivFusion mapper failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_MAPPER(kNameDivFusion, DivFusionMapper)
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,37 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef ACL_MAPPER_PRIMITIVE_DIV_FUSION_MAPPER_H
|
|
||||||
#define ACL_MAPPER_PRIMITIVE_DIV_FUSION_MAPPER_H
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
|
||||||
#include "ops/fusion/div_fusion.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
using mindspore::ops::kNameDivFusion;
|
|
||||||
|
|
||||||
class DivFusionMapper : public PrimitiveMapper {
|
|
||||||
public:
|
|
||||||
DivFusionMapper() : PrimitiveMapper(kNameDivFusion) {}
|
|
||||||
|
|
||||||
~DivFusionMapper() override = default;
|
|
||||||
|
|
||||||
STATUS Mapper(const CNodePtr &cnode) override;
|
|
||||||
};
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // ACL_MAPPER_PRIMITIVE_DIV_FUSION_MAPPER_H
|
|
|
@ -1,34 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/mul_fusion_mapper.h"
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
STATUS MulFusionMapper::Mapper(const CNodePtr &cnode) {
|
|
||||||
auto dst_prim = std::make_shared<ops::Mul>();
|
|
||||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "MulFusion mapper failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_MAPPER(kNameMulFusion, MulFusionMapper)
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,37 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef ACL_MAPPER_PRIMITIVE_MULFUSION_MAPPER_H
|
|
||||||
#define ACL_MAPPER_PRIMITIVE_MULFUSION_MAPPER_H
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
|
||||||
#include "ops/fusion/mul_fusion.h"
|
|
||||||
|
|
||||||
using mindspore::ops::kNameMulFusion;
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
class MulFusionMapper : public PrimitiveMapper {
|
|
||||||
public:
|
|
||||||
MulFusionMapper() : PrimitiveMapper(kNameMulFusion) {}
|
|
||||||
|
|
||||||
~MulFusionMapper() override = default;
|
|
||||||
|
|
||||||
STATUS Mapper(const CNodePtr &cnode) override;
|
|
||||||
};
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // ACL_MAPPER_PRIMITIVE_MULFUSION_MAPPER_H
|
|
|
@ -1,35 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/pow_fusion_mapper.h"
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
|
||||||
#include "src/common/log_util.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
STATUS PowFusionMapper::Mapper(const CNodePtr &cnode) {
|
|
||||||
auto dst_prim = std::make_shared<ops::Pow>();
|
|
||||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "PowFusion mapper failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_MAPPER(kNamePowFusion, PowFusionMapper)
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,37 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef ACL_MAPPER_PRIMITIVE_POW_FUSION_MAPPER_H
|
|
||||||
#define ACL_MAPPER_PRIMITIVE_POW_FUSION_MAPPER_H
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
|
||||||
#include "ops/fusion/pow_fusion.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
using mindspore::ops::kNamePowFusion;
|
|
||||||
|
|
||||||
class PowFusionMapper : public PrimitiveMapper {
|
|
||||||
public:
|
|
||||||
PowFusionMapper() : PrimitiveMapper(kNamePowFusion) {}
|
|
||||||
|
|
||||||
~PowFusionMapper() override = default;
|
|
||||||
|
|
||||||
STATUS Mapper(const CNodePtr &cnode) override;
|
|
||||||
};
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // ACL_MAPPER_PRIMITIVE_POW_FUSION_MAPPER_H
|
|
|
@ -89,11 +89,15 @@ STATUS PrimitiveMapper::AttrAdjust(const PrimitivePtr &prim, const std::string &
|
||||||
MS_LOG(ERROR) << name << " Value num must be two.";
|
MS_LOG(ERROR) << name << " Value num must be two.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
std::vector<int64_t> new_value;
|
int64_t format = Format::NCHW;
|
||||||
new_value.push_back(1);
|
if (prim->GetAttr(ops::kFormat) != nullptr) {
|
||||||
new_value.push_back(1);
|
format = GetValue<int64_t>(prim->GetAttr(ops::kFormat));
|
||||||
new_value.push_back(static_cast<int64_t>(origin_value[0]));
|
}
|
||||||
new_value.push_back(static_cast<int64_t>(origin_value[1]));
|
std::vector<int64_t> new_value = {1, 1, static_cast<int64_t>(origin_value[0]), static_cast<int64_t>(origin_value[1])};
|
||||||
|
if (format == Format::NHWC) {
|
||||||
|
std::vector<int64_t> tmp = {1, static_cast<int64_t>(origin_value[0]), static_cast<int64_t>(origin_value[1]), 1};
|
||||||
|
new_value.swap(tmp);
|
||||||
|
}
|
||||||
prim->AddAttr(name, MakeValue(new_value));
|
prim->AddAttr(name, MakeValue(new_value));
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,14 +87,14 @@ static STATUS AdapteNodeWithMultiOutputs(const FuncGraphPtr &func_graph, const C
|
||||||
auto input_cnode = input->cast<CNodePtr>();
|
auto input_cnode = input->cast<CNodePtr>();
|
||||||
std::string input_func_name = GetCNodeFuncName(input_cnode);
|
std::string input_func_name = GetCNodeFuncName(input_cnode);
|
||||||
if (kCNodeWithMultiOutputs.find(input_func_name) != kCNodeWithMultiOutputs.end()) {
|
if (kCNodeWithMultiOutputs.find(input_func_name) != kCNodeWithMultiOutputs.end()) {
|
||||||
MS_LOG(INFO) << "Adapter cnode with multioutputs: " << cnode_func_name;
|
MS_LOG(INFO) << "Input " << input_func_name << " of cnode " << cnode_func_name << " has multioutputs";
|
||||||
CNodePtr get_item_cnode = CreateTupleGetItemNode(func_graph, input_cnode);
|
CNodePtr get_item_cnode = CreateTupleGetItemNode(func_graph, input_cnode);
|
||||||
if (get_item_cnode == nullptr) {
|
if (get_item_cnode == nullptr) {
|
||||||
MS_LOG(ERROR) << "Create tuple item for " << cnode_func_name << " failed.";
|
MS_LOG(ERROR) << "Create tuple item for " << input_func_name << " of " << cnode_func_name << " failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
if (!manager->Replace(input_cnode, get_item_cnode)) {
|
if (!manager->Replace(input_cnode, get_item_cnode)) {
|
||||||
MS_LOG(ERROR) << "Replace " << cnode_func_name << " failed.";
|
MS_LOG(ERROR) << "Replace " << input_func_name << " of " << cnode_func_name << " failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/sub_fusion_mapper.h"
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
|
||||||
#include "src/common/log_util.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
STATUS SubFusionMapper::Mapper(const CNodePtr &cnode) {
|
|
||||||
auto dst_prim = std::make_shared<ops::Sub>();
|
|
||||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "SubFusion mapper failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_PRIMITIVE_MAPPER(kNameSubFusion, SubFusionMapper)
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,37 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef ACL_MAPPER_PRIMITIVE_SUB_FUSION_MAPPER_H
|
|
||||||
#define ACL_MAPPER_PRIMITIVE_SUB_FUSION_MAPPER_H
|
|
||||||
|
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
|
||||||
#include "ops/fusion/sub_fusion.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
using mindspore::ops::kNameSubFusion;
|
|
||||||
|
|
||||||
class SubFusionMapper : public PrimitiveMapper {
|
|
||||||
public:
|
|
||||||
SubFusionMapper() : PrimitiveMapper(kNameSubFusion) {}
|
|
||||||
|
|
||||||
~SubFusionMapper() override = default;
|
|
||||||
|
|
||||||
STATUS Mapper(const CNodePtr &cnode) override;
|
|
||||||
};
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // ACL_MAPPER_PRIMITIVE_SUB_FUSION_MAPPER_H
|
|
|
@ -41,7 +41,7 @@ ADD_CONVERTER_TBE_OP(BNInference)
|
||||||
ADD_CONVERTER_TBE_OP(Deconvolution)
|
ADD_CONVERTER_TBE_OP(Deconvolution)
|
||||||
ADD_CONVERTER_TBE_OP(Upsample)
|
ADD_CONVERTER_TBE_OP(Upsample)
|
||||||
ADD_CONVERTER_TBE_OP(Conv2DTransposeD)
|
ADD_CONVERTER_TBE_OP(Conv2DTransposeD)
|
||||||
|
ADD_CONVERTER_TBE_OP(DepthwiseConv2dNative)
|
||||||
} // namespace acl
|
} // namespace acl
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue