forked from mindspore-Ecosystem/mindspore
!24084 support yolo network
Merge pull request !24084 from zhengyuanhua/br1
This commit is contained in:
commit
e4b6008d1e
|
@ -59,6 +59,7 @@ if(BUILD_LITE)
|
|||
"${CMAKE_CURRENT_SOURCE_DIR}/model/acl/acl_model_multi.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/model/acl/acl_model.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/types.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/model/model_impl.cc"
|
||||
$<TARGET_OBJECTS:_mindspore_vm_obj>)
|
||||
|
|
|
@ -331,6 +331,7 @@ constexpr const char kNameGlobalAvgPool[] = "GlobalAveragePool";
|
|||
constexpr const char kNameStridedSliceV2[] = "StridedSliceV2";
|
||||
constexpr const char kNameBNInference[] = "BNInference";
|
||||
constexpr const char kNameDeconvolution[] = "Deconvolution";
|
||||
constexpr const char kNameUpsample[] = "Upsample";
|
||||
|
||||
class OpAdapterMap {
|
||||
public:
|
||||
|
|
|
@ -163,4 +163,12 @@ INPUT_MAP(GlobalAveragePool) = {{1, INPUT_DESC(x)}};
|
|||
ATTR_MAP(GlobalAveragePool) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(GlobalAveragePool) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(GlobalAveragePool, kNameGlobalAvgPool, ADPT_DESC(GlobalAveragePool))
|
||||
|
||||
// Upsample
|
||||
INPUT_MAP(Upsample) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(Upsample) = {{"scale", ATTR_DESC(scale, AnyTraits<float>())},
|
||||
{"h", ATTR_DESC(stride_h, AnyTraits<int64_t>())},
|
||||
{"w", ATTR_DESC(stride_w, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(Upsample) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Upsample, kNameUpsample, ADPT_DESC(Upsample))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -68,5 +68,8 @@ DECLARE_OP_USE_OUTPUT(AvgPoolV2)
|
|||
|
||||
DECLARE_OP_ADAPTER(GlobalAveragePool)
|
||||
DECLARE_OP_USE_OUTPUT(GlobalAveragePool)
|
||||
|
||||
DECLARE_OP_ADAPTER(Upsample)
|
||||
DECLARE_OP_USE_OUTPUT(Upsample)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_POOLING_OPS_DECLARE_H_
|
||||
|
|
|
@ -218,11 +218,11 @@ void FuncGraph::ClearNodes() { nodes_.clear(); }
|
|||
void FuncGraph::AddNode(const AnfNodePtr &node) { nodes_.add(node); }
|
||||
|
||||
void FuncGraph::DropNode(const AnfNodePtr &node) {
|
||||
nodes_.erase(node);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Node is nullptr";
|
||||
return;
|
||||
}
|
||||
nodes_.erase(node);
|
||||
auto graph = node->func_graph();
|
||||
if (node->isa<Parameter>()) {
|
||||
(void)parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
|
||||
|
|
|
@ -545,7 +545,8 @@ STATUS ModelProcess::ConstructTensor(std::vector<mindspore::MSTensor> *outputs)
|
|||
for (size_t i = 0; i < output_infos_.size(); ++i) {
|
||||
std::string lite_output_name = (*outputs)[i].Name();
|
||||
if (lite_output_name != names[i]) {
|
||||
MS_LOG(INFO) << "Lite output name: " << lite_output_name << "; Om output name: " << names[i];
|
||||
MS_LOG(DEBUG) << "Lite output name: " << lite_output_name << "; model output name: " << names[i]
|
||||
<< "shape: " << VectorToString(shapes[i]);
|
||||
}
|
||||
(*outputs)[i].SetFormat(Format::NCHW);
|
||||
(*outputs)[i].SetDataType(data_types[i]);
|
||||
|
@ -572,5 +573,17 @@ STATUS ModelProcess::ConstructTensor(std::vector<mindspore::MSTensor> *outputs)
|
|||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
std::string ModelProcess::VectorToString(const std::vector<int64_t> &val) {
|
||||
std::string str;
|
||||
auto size = val.size();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
str += std::to_string(val[i]);
|
||||
if (i != size - 1) {
|
||||
str += ",";
|
||||
}
|
||||
}
|
||||
return str;
|
||||
}
|
||||
} // namespace acl
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -74,6 +74,7 @@ class ModelProcess {
|
|||
STATUS ResetOutputSize();
|
||||
size_t GetDynamicDims(const std::vector<AclTensorInfo> &);
|
||||
STATUS ProcDynamicShape(const std::vector<mindspore::MSTensor> &inputs, size_t dynamic_nums);
|
||||
std::string VectorToString(const std::vector<int64_t> &);
|
||||
|
||||
void DestroyInputsDataset();
|
||||
void DestroyInputsDataMem();
|
||||
|
|
|
@ -9,7 +9,7 @@ set(CCSRC_SRC
|
|||
|
||||
include_directories(${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
|
||||
|
||||
if(NOT WIN32)
|
||||
if(NOT WIN32 AND NOT MSLITE_ENABLE_ACL)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -rdynamic -fvisibility=hidden")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -fvisibility=hidden")
|
||||
endif()
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "tools/converter/acl/acl_pass.h"
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
@ -42,6 +43,7 @@ constexpr auto kCustomNodeName = "custom_0";
|
|||
constexpr auto kNCHWFormat = "NCHW";
|
||||
constexpr auto kToNHWCFormatPass = "ToNHWCFormat";
|
||||
constexpr auto kToNCHWFormatPass = "ToNCHWFormat";
|
||||
constexpr auto kInferShapePass = "InferShapePass";
|
||||
constexpr auto kDelRedundantTranspose = "DeleteRedundantTranspose";
|
||||
constexpr size_t kDependInputNum = 3;
|
||||
constexpr size_t kDependFirstInputIdx = 1;
|
||||
|
@ -150,7 +152,7 @@ STATUS AclPass::PreProcGraph(const FuncGraphPtr &func_graph) {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
// The format of nodes (cnode, parameter, val) must be nchw due to interface of convert om
|
||||
if (!lite::RunOptimizerPass(func_graph, {kToNCHWFormatPass, kDelRedundantTranspose})) {
|
||||
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
|
||||
MS_LOG(ERROR) << "To nchw format success.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
@ -399,7 +401,6 @@ STATUS AclPass::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr
|
|||
return lite::RET_ERROR;
|
||||
}
|
||||
custom_node->AddAttr(kOutputNames, MakeValue(graph_output_names_));
|
||||
|
||||
TypeId type = lite::acl::GetTypeFromNode(graph_outputs_[0]);
|
||||
if (graph_outputs_.size() == 1) {
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_output_dims_[0], type);
|
||||
|
@ -417,15 +418,29 @@ STATUS AclPass::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void AclPass::SetCustomAttrs(const std::shared_ptr<ops::Custom> &prim) {
|
||||
// add output_shape attr
|
||||
std::string output_dim_str;
|
||||
for (const auto &item : graph_output_dims_) {
|
||||
output_dim_str += std::to_string(item.size()) + ",";
|
||||
for (const auto &val : item) {
|
||||
output_dim_str += std::to_string(val) + ",";
|
||||
}
|
||||
}
|
||||
std::vector<uint8_t> output_dim_char(output_dim_str.begin(), output_dim_str.end());
|
||||
std::map<std::string, std::vector<uint8_t>> attrs = {{lite::acl::kOutputShapes, output_dim_char}};
|
||||
prim->set_attr(attrs);
|
||||
}
|
||||
|
||||
CNodePtr AclPass::CreateCustomNode(const FuncGraphPtr &func_graph) {
|
||||
auto prim = std::make_unique<mindspore::ops::Custom>();
|
||||
auto prim = std::make_shared<mindspore::ops::Custom>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "New custom op failed.";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_type(kCustomPrimTypeACL);
|
||||
auto graph_input = func_graph->get_inputs();
|
||||
CNodePtr custom_node = func_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(prim.release()), graph_input);
|
||||
CNodePtr custom_node = func_graph->NewCNode(prim, graph_input);
|
||||
if (custom_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Custom cnode failed.";
|
||||
return nullptr;
|
||||
|
@ -437,6 +452,7 @@ CNodePtr AclPass::CreateCustomNode(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "Set custom outputs failed.";
|
||||
return nullptr;
|
||||
}
|
||||
SetCustomAttrs(prim);
|
||||
return custom_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,8 @@
|
|||
#include "include/api/types.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "cxx_api/model/acl/acl_model_options.h"
|
||||
#include "tools/converter/acl/common/acl_option_cfg.h"
|
||||
#include "tools/converter/acl/common/acl_types.h"
|
||||
#include "ops/custom.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -53,6 +54,7 @@ class AclPass : public Pass {
|
|||
STATUS ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_data);
|
||||
ParameterPtr CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om);
|
||||
CNodePtr CreateCustomNode(const FuncGraphPtr &func_graph);
|
||||
void SetCustomAttrs(const std::shared_ptr<ops::Custom> &prim);
|
||||
STATUS SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node);
|
||||
STATUS SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type);
|
||||
STATUS ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
|
||||
|
|
|
@ -38,6 +38,8 @@ struct AclModelOptionCfg {
|
|||
std::string buffer_optimize;
|
||||
std::string insert_op_config_file_path;
|
||||
};
|
||||
|
||||
constexpr auto kOutputShapes = "outputs_shape";
|
||||
} // namespace acl
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -19,9 +19,13 @@
|
|||
#include "include/registry/register_kernel_interface.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "common/log_adapter.h"
|
||||
#include "tools/converter/acl/common/acl_types.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr auto kBufMaxSize = 1024;
|
||||
|
||||
Status CustomInterface::Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const mindspore::schema::Primitive *primitive) {
|
||||
if (inputs == nullptr || (*inputs).empty()) {
|
||||
|
@ -32,17 +36,64 @@ Status CustomInterface::Infer(std::vector<mindspore::MSTensor> *inputs, std::vec
|
|||
MS_LOG(ERROR) << "Outputs is invalid.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "Primitive is nullptr.";
|
||||
return kLiteError;
|
||||
}
|
||||
MS_CHECK_TRUE_MSG(primitive != nullptr, kLiteNullptr, "Primitive is nullptr.");
|
||||
if (primitive->value_type() != schema::PrimitiveType_Custom) {
|
||||
MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom.";
|
||||
return kLiteError;
|
||||
}
|
||||
auto op = primitive->value_as_Custom();
|
||||
char buf[kBufMaxSize];
|
||||
if (GetCustomAttr(buf, kBufMaxSize, op, acl::kOutputShapes) != kSuccess) {
|
||||
MS_LOG(ERROR) << "Get custom attr output shape failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
uint32_t id = 0;
|
||||
char delims[] = ",";
|
||||
char *res = nullptr;
|
||||
char *save_ptr = nullptr;
|
||||
res = strtok_r(buf, delims, &save_ptr);
|
||||
while (res != nullptr && id < outputs->size()) {
|
||||
int64_t dims_num = strtol(res, &res, 10);
|
||||
std::vector<int64_t> shape(dims_num);
|
||||
for (int64_t j = 0; j < dims_num; j++) {
|
||||
res = strtok_r(nullptr, delims, &save_ptr);
|
||||
shape[j] = static_cast<int64_t>(strtol(res, &res, 10));
|
||||
}
|
||||
(*outputs)[id].SetShape(shape);
|
||||
id += 1;
|
||||
res = strtok_r(nullptr, delims, &save_ptr);
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status CustomInterface::GetCustomAttr(char *buf, uint32_t buf_size, const mindspore::schema::Custom *op,
|
||||
const std::string &attr_name) {
|
||||
MS_CHECK_TRUE_MSG(buf != nullptr, kLiteNullptr, "Buf is nullptr");
|
||||
MS_CHECK_TRUE_MSG(op != nullptr, kLiteNullptr, "Op is nullptr.");
|
||||
auto attr_ptr = op->attr();
|
||||
MS_CHECK_TRUE_MSG(attr_ptr != nullptr, kLiteNullptr, "Attr ptr is nullptr.");
|
||||
for (uint32_t i = 0; i < attr_ptr->size(); i++) {
|
||||
auto val = attr_ptr->Get(i);
|
||||
MS_CHECK_TRUE_MSG(val != nullptr, kLiteNullptr, "Attr val is nullptr.");
|
||||
MS_CHECK_TRUE_MSG(val->name() != nullptr, kLiteNullptr, "Attr val name is nullptr.");
|
||||
if (val->name()->str() == attr_name) {
|
||||
auto output_info = val->data();
|
||||
MS_CHECK_TRUE_MSG(output_info != nullptr, kLiteNullptr, "Output info is nullptr.");
|
||||
auto attr_size = output_info->size();
|
||||
if (attr_size >= buf_size) {
|
||||
MS_LOG(ERROR) << "Attr size[" << attr_size << "] is large than max size[" << buf_size << "]";
|
||||
return kLiteError;
|
||||
}
|
||||
for (uint32_t j = 0; j < attr_size; j++) {
|
||||
buf[j] = static_cast<char>(output_info->Get(j));
|
||||
}
|
||||
buf[attr_size] = 0;
|
||||
return kSuccess;
|
||||
}
|
||||
}
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
std::shared_ptr<mindspore::kernel::KernelInterface> CustomInferCreater() {
|
||||
auto infer = new (std::nothrow) CustomInterface();
|
||||
if (infer == nullptr) {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define MINDSPORE_LITE_TOOLS_CONVERTER_ACL_CUSTOM_INFER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "include/kernel_interface.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -29,6 +30,9 @@ class CustomInterface : public mindspore::kernel::KernelInterface {
|
|||
|
||||
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const mindspore::schema::Primitive *primitive) override;
|
||||
|
||||
private:
|
||||
Status GetCustomAttr(char *buf, uint32_t buf_size, const mindspore::schema::Custom *op, const std::string &attr_name);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/acl/mapper/concat_mapper.h"
|
||||
#include <string>
|
||||
#include "tools/converter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -24,6 +26,11 @@ constexpr auto kNameInputNums = "inputNums";
|
|||
}
|
||||
|
||||
STATUS ConcatMapper::Mapper(const CNodePtr &cnode) {
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
if (RenameNode(cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Concat rename failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (AddAttrForDynInputPrimitive(cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Concat mapper failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -32,7 +39,6 @@ STATUS ConcatMapper::Mapper(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
STATUS ConcatMapper::AddAttrForDynInputPrimitive(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
|
@ -48,6 +54,17 @@ STATUS ConcatMapper::AddAttrForDynInputPrimitive(const CNodePtr &cnode) {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ConcatMapper::RenameNode(const CNodePtr &cnode) {
|
||||
const std::string kNamePercent = "%";
|
||||
std::string name = cnode->fullname_with_scope();
|
||||
std::string::size_type pos = 0;
|
||||
while ((pos = name.find(kNamePercent)) != name.npos) {
|
||||
name = name.replace(pos, kNamePercent.size(), "");
|
||||
}
|
||||
cnode->set_fullname_with_scope(name);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_MAPPER(kNameConcat, ConcatMapper)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,7 @@ class ConcatMapper : public PrimitiveMapper {
|
|||
|
||||
private:
|
||||
STATUS AddAttrForDynInputPrimitive(const CNodePtr &cnode);
|
||||
STATUS RenameNode(const CNodePtr &cnode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,6 +46,8 @@ ADD_CONVERTER_TBE_OP(GlobalAveragePool)
|
|||
ADD_CONVERTER_TBE_OP(BNInference)
|
||||
|
||||
ADD_CONVERTER_TBE_OP(Deconvolution)
|
||||
|
||||
ADD_CONVERTER_TBE_OP(Upsample)
|
||||
} // namespace acl
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/mapper/upsample_mapper.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "tools/converter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr size_t kScaleMinNum = 2;
|
||||
constexpr size_t kInputNum = 3;
|
||||
} // namespace
|
||||
|
||||
STATUS UpsampleMapper::Mapper(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (cnode->inputs().size() != kInputNum) {
|
||||
MS_LOG(ERROR) << "Upsample input num should be three, real size: " << cnode->inputs().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
TypeId type_id;
|
||||
if (opt::GetDataTypeFromAnfNode(cnode->inputs()[kInputNum - 1], &type_id) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get data type failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (type_id == kNumberTypeFloat32) {
|
||||
if (AttrAdjust(src_prim, value_node) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Upsample attr adjust failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (RemoveConstInput(cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Upsample remove const input failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS UpsampleMapper::AttrAdjust(const PrimitivePtr &src_prim, const ValueNodePtr &val_node) {
|
||||
auto attr_val = src_prim->GetAttr("scale");
|
||||
CHECK_NULL_RETURN(attr_val);
|
||||
std::vector<float> scale = opt::CastToFloat(attr_val);
|
||||
if (scale.size() < kScaleMinNum) {
|
||||
MS_LOG(ERROR) << "Scale size must not be less than two, real size: " << scale.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(DEBUG) << "The scale value: " << scale[1];
|
||||
auto dst_prim = std::make_shared<acl::Upsample>();
|
||||
CHECK_NULL_RETURN(dst_prim);
|
||||
dst_prim->AddAttr("scale", MakeValue(scale[1]));
|
||||
val_node->set_value(dst_prim);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS UpsampleMapper::RemoveConstInput(const CNodePtr &cnode) {
|
||||
std::vector<AnfNodePtr> inputs{cnode->inputs().begin(), cnode->inputs().end() - 1};
|
||||
cnode->set_inputs(inputs);
|
||||
auto redundant_input = cnode->inputs()[kInputNum - 1];
|
||||
auto graph = cnode->func_graph();
|
||||
CHECK_NULL_RETURN(graph);
|
||||
graph->DropNode(redundant_input);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_MAPPER(kNameResize, UpsampleMapper)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_MAPPER_PRIMITIVE_UPSAMPLE_MAPPER_H
|
||||
#define ACL_MAPPER_PRIMITIVE_UPSAMPLE_MAPPER_H
|
||||
|
||||
#include "tools/converter/acl/mapper/primitive_mapper.h"
|
||||
#include "tools/converter/acl/mapper/tbe_op_def.h"
|
||||
#include "ops/resize.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
using mindspore::ops::kNameResize;
|
||||
|
||||
class UpsampleMapper : public PrimitiveMapper {
|
||||
public:
|
||||
UpsampleMapper() : PrimitiveMapper(acl::kNameUpsample) {}
|
||||
|
||||
~UpsampleMapper() override = default;
|
||||
|
||||
STATUS Mapper(const CNodePtr &cnode) override;
|
||||
|
||||
private:
|
||||
STATUS AttrAdjust(const PrimitivePtr &src_prim, const ValueNodePtr &val_node);
|
||||
STATUS RemoveConstInput(const CNodePtr &cnode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_MAPPER_PRIMITIVE_TRANSPOSE_MAPPER_H
|
|
@ -517,6 +517,7 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
|
|||
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
|
||||
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
|
||||
{"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
|
||||
{"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()},
|
||||
{"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)},
|
||||
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat)}};
|
||||
bool succeed_store = true;
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_ACL_OPTION_PARAM_PARSER_H
|
||||
#include <string>
|
||||
#include "tools/converter/config_parser/config_file_parser.h"
|
||||
#include "tools/converter/acl/common/acl_option_cfg.h"
|
||||
#include "tools/converter/acl/common/acl_types.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/converter/preprocess/preprocess_param.h"
|
||||
#include "tools/converter/quantizer/quant_params.h"
|
||||
#include "tools/converter/acl/common/acl_option_cfg.h"
|
||||
#include "tools/converter/acl/common/acl_types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -17,15 +17,16 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DeleteRedundantTranspose {
|
||||
class DeleteRedundantTranspose : public Pass {
|
||||
public:
|
||||
DeleteRedundantTranspose() = default;
|
||||
DeleteRedundantTranspose() : Pass("DeleteRedundantTranspose") {}
|
||||
~DeleteRedundantTranspose() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph);
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
STATUS DeleteNot4DTranspose(const FuncGraphPtr &func_graph);
|
||||
|
|
Loading…
Reference in New Issue