forked from mindspore-Ecosystem/mindspore
Ascend310 infer
This commit is contained in:
parent
111d1a9a61
commit
c8131ef8c4
|
@ -1,5 +1,9 @@
|
|||
message(STATUS "Compiling GraphEngine")
|
||||
set(GE_SOURCE_DIR ${CMAKE_SOURCE_DIR}/graphengine)
|
||||
if(NOT(BUILD_LITE))
|
||||
set(GE_SOURCE_DIR ${CMAKE_SOURCE_DIR}/graphengine)
|
||||
else()
|
||||
set(GE_SOURCE_DIR ${CMAKE_SOURCE_DIR}/../../graphengine)
|
||||
endif()
|
||||
|
||||
message(STATUS "[ME] build_path: ${BUILD_PATH}")
|
||||
|
||||
|
@ -45,7 +49,11 @@ if(ENABLE_TESTCASES)
|
|||
set(ENABLE_GITEE ${_ge_tmp_ENABLE_GITEE})
|
||||
set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS})
|
||||
elseif(MODE_ASCEND_ALL OR MODE_ASCEND_ACL)
|
||||
file(GLOB_RECURSE GE_PROTO_FILE RELATIVE ${CMAKE_SOURCE_DIR} "graphengine/metadef/proto/*.proto")
|
||||
if(NOT(BUILD_LITE))
|
||||
file(GLOB_RECURSE GE_PROTO_FILE RELATIVE ${CMAKE_SOURCE_DIR} "graphengine/metadef/proto/*.proto")
|
||||
else()
|
||||
file(GLOB_RECURSE GE_PROTO_FILE ${TOP_DIR}/graphengine/metadef/proto/*.proto)
|
||||
endif()
|
||||
set(TMP_FILE_NAME_LIST)
|
||||
foreach(file ${GE_PROTO_FILE})
|
||||
get_filename_component(file_name ${file} NAME_WE)
|
||||
|
|
|
@ -7,13 +7,14 @@ if(BUILD_LITE)
|
|||
else()
|
||||
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private")
|
||||
set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
||||
if(NOT ENABLE_GLIBCXX)
|
||||
set(glog_CXXFLAGS "${glog_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
||||
endif()
|
||||
set(glog_patch ${CMAKE_SOURCE_DIR}/third_party/patch/glog/glog.patch001)
|
||||
set(glog_lib mindspore_glog)
|
||||
endif()
|
||||
|
||||
if(NOT ENABLE_GLIBCXX)
|
||||
set(glog_CXXFLAGS "${glog_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
||||
endif()
|
||||
|
||||
if(ENABLE_GITEE)
|
||||
set(REQ_URL "https://gitee.com/mirrors/glog/repository/archive/v0.4.0.tar.gz")
|
||||
set(MD5 "22fe340ddc231e6c8e46bc295320f8ee")
|
||||
|
|
|
@ -2,6 +2,9 @@ set(protobuf_USE_STATIC_LIBS ON)
|
|||
if(BUILD_LITE)
|
||||
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter \
|
||||
-fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
|
||||
if(ENABLE_ACL)
|
||||
set(protobuf_CXXFLAGS "${protobuf_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
||||
endif()
|
||||
else()
|
||||
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC \
|
||||
|
|
|
@ -515,6 +515,11 @@ else()
|
|||
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
if(MSLITE_ENABLE_ACL)
|
||||
set(LITE_ACL_DIR ${TOP_DIR}/mindspore/lite/build/tools/converter/acl)
|
||||
install(FILES ${LITE_ACL_DIR}/mindspore_shared_lib/libmindspore_shared_lib.so
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
__install_micro_wrapper()
|
||||
__install_micro_codegen()
|
||||
endif()
|
||||
|
|
|
@ -1,8 +1,15 @@
|
|||
# build mindspore_shared_lib
|
||||
set(LOAD_MINDIR_SRC
|
||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc
|
||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc
|
||||
)
|
||||
if(NOT(BUILD_LITE))
|
||||
set(LOAD_MINDIR_SRC
|
||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc
|
||||
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc
|
||||
)
|
||||
else()
|
||||
set(MS_UTILS_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../mindspore/ccsrc/utils/config_manager.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
|
||||
|
||||
if(ENABLE_D OR ENABLE_ACL)
|
||||
|
@ -12,11 +19,15 @@ if(ENABLE_D OR ENABLE_ACL)
|
|||
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
|
||||
|
||||
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"akg_kernel_register.cc"
|
||||
"model/acl/*.cc"
|
||||
"model/model_converter_utils/*.cc"
|
||||
"graph/acl/*.cc"
|
||||
)
|
||||
|
||||
if(NOT(BUILD_LITE))
|
||||
list(APPEND API_ACL_SRC "akg_kernel_register.cc")
|
||||
endif()
|
||||
|
||||
if(NOT ENABLE_D)
|
||||
list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>)
|
||||
endif()
|
||||
|
@ -44,10 +55,13 @@ set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
|
|||
${API_MS_INFER_SRC}
|
||||
${API_ACL_SRC}
|
||||
${API_OPS_SRC}
|
||||
${LOAD_MINDIR_SRC})
|
||||
${LOAD_MINDIR_SRC}
|
||||
${MS_UTILS_SRC})
|
||||
|
||||
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
|
||||
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore)
|
||||
if(NOT(BUILD_LITE))
|
||||
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore)
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
|
@ -58,8 +72,12 @@ else()
|
|||
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_core proto_input mindspore_gvar
|
||||
mindspore::protobuf)
|
||||
else()
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
mindspore mindspore_core proto_input mindspore_gvar mindspore::protobuf)
|
||||
if(NOT(BUILD_LITE))
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
|
||||
mindspore mindspore_core proto_input mindspore_gvar mindspore::protobuf)
|
||||
else()
|
||||
target_link_libraries(mindspore_shared_lib PRIVATE ${SECUREC_LIBRARY})
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
@ -199,7 +199,7 @@ void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
|
|||
}
|
||||
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
|
||||
MS_EXCEPTION_IF_NULL(data_);
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DeviceID);
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DumpCfgPath);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class AclModelOptions {
|
|||
std::string GenAclOptionsKey() const;
|
||||
uint32_t GetDeviceID() const { return device_id_; }
|
||||
std::string GetDumpCfgPath() const { return dump_cfg_path_; }
|
||||
void RenameInput(const std::vector<std::string> &);
|
||||
void RenameInput(const std::vector<std::string> &name);
|
||||
|
||||
// return tuple<init_options, build_options>
|
||||
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const;
|
||||
|
|
|
@ -1474,6 +1474,65 @@ void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
|
|||
op_cache_[value_ptr.get()] = op;
|
||||
}
|
||||
|
||||
std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(WARNING) << "Value ptr is nullptr.";
|
||||
return {};
|
||||
}
|
||||
std::vector<int64_t> cur_value = {};
|
||||
if (utils::isa<ValueSequeuePtr>(value)) {
|
||||
auto val_seq_ptr = value->cast<ValueSequeuePtr>();
|
||||
MS_EXCEPTION_IF_NULL(val_seq_ptr);
|
||||
if (!val_seq_ptr->value().empty()) {
|
||||
auto first_val = val_seq_ptr->value().front();
|
||||
MS_EXCEPTION_IF_NULL(first_val);
|
||||
MS_EXCEPTION_IF_NULL(first_val->type());
|
||||
if (first_val->type()->number_type() == kNumberTypeInt64) {
|
||||
cur_value = GetValue<std::vector<int64_t>>(value);
|
||||
} else {
|
||||
auto origin_value = GetValue<std::vector<int>>(value);
|
||||
std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
|
||||
[](int index) { return static_cast<int64_t>(index); });
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(value->type());
|
||||
if (value->type()->number_type() == kNumberTypeInt64) {
|
||||
cur_value.push_back(GetValue<int64_t>(value));
|
||||
} else {
|
||||
cur_value.push_back(static_cast<int64_t>(GetValue<int>(value)));
|
||||
}
|
||||
}
|
||||
return cur_value;
|
||||
}
|
||||
|
||||
void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
|
||||
MS_LOG(INFO) << "Convert the second input of reshape to op attr.";
|
||||
const auto kInputNum = 3;
|
||||
if (node->size() < kInputNum) {
|
||||
MS_LOG(WARNING) << "Reshape must have two inputs.";
|
||||
return;
|
||||
}
|
||||
OpAdapterPtr adpt = FindAdapter(node, training_);
|
||||
if (adpt == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto op = adpt->generate(node);
|
||||
MS_EXCEPTION_IF_NULL(op);
|
||||
// get shape form attr
|
||||
auto value_node = node->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
MS_EXCEPTION_IF_NULL(value_node->value());
|
||||
auto primitive = value_node->value()->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto value = primitive->GetAttr("shape");
|
||||
std::vector<int64_t> list;
|
||||
list = CastToInt(value);
|
||||
|
||||
op->SetAttr("shape", list);
|
||||
op_cache_[node.get()] = op;
|
||||
}
|
||||
|
||||
AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
|
||||
const int TUPLE_GET_ITEM_INDEX = 2;
|
||||
if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs
|
||||
|
@ -1660,6 +1719,12 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
|||
return true;
|
||||
}
|
||||
|
||||
// Convert Reshape add const input to attr(shape)
|
||||
if (name == prim::kPrimReshape->name()) {
|
||||
ConvertReshape(node);
|
||||
return true;
|
||||
}
|
||||
|
||||
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
|
||||
if (name == prim::kPrimMakeTuple->name()) {
|
||||
ConvertMakeTuple(node);
|
||||
|
|
|
@ -158,6 +158,8 @@ class DfGraphConvertor {
|
|||
void ConvertTupleGetItem(const CNodePtr node);
|
||||
void ConvertMakeTuple(const CNodePtr node);
|
||||
void ConvertTopK(const CNodePtr node);
|
||||
void ConvertReshape(const CNodePtr node);
|
||||
std::vector<int64_t> CastToInt(const ValuePtr &value);
|
||||
bool CheckCNode(const std::string &name, const CNodePtr node);
|
||||
void TraceOutput(AnfNodePtr node);
|
||||
void TraceOutputFromParameter(const AnfNodePtr &anf_out);
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
|
||||
#include <sstream>
|
||||
|
||||
#ifndef ENABLE_LITE_ACL
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
#include "pipeline/jit/pipeline.h"
|
||||
#endif
|
||||
#ifndef NO_DLIB
|
||||
#include "tdt/tsd_client.h"
|
||||
#endif
|
||||
|
@ -37,11 +39,13 @@ DfGraphManager::DfGraphManager() {
|
|||
}
|
||||
|
||||
DfGraphManager::~DfGraphManager() {
|
||||
// in python fisrt destroy after atexit but in c++ destoy before atexit
|
||||
// in python first destroy after atexit but in c++ destoy before atexit
|
||||
DeleteGraphRunner();
|
||||
DeleteGeSession();
|
||||
ClearGraph();
|
||||
#ifndef ENABLE_LITE_ACL
|
||||
parse::python_adapter::set_python_env_flag(false);
|
||||
#endif
|
||||
}
|
||||
|
||||
DfGraphManager &DfGraphManager::GetInstance() {
|
||||
|
|
|
@ -19,7 +19,9 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#ifndef ENABLE_LITE_ACL
|
||||
#include "pybind11/pybind11.h"
|
||||
#endif
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "sys/time.h"
|
||||
|
@ -40,9 +42,9 @@ Session::Session(const std::map<std::string, std::string> &options) {
|
|||
Session::~Session() {}
|
||||
} // namespace ge
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_LITE_ACL
|
||||
namespace py = pybind11;
|
||||
|
||||
#endif
|
||||
namespace mindspore {
|
||||
namespace transform {
|
||||
std::shared_ptr<ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
|
||||
|
@ -189,7 +191,9 @@ Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<MeTens
|
|||
Status ret;
|
||||
{
|
||||
// Release GIL before calling into (potentially long-running) C++ code
|
||||
#ifndef ENABLE_LITE_ACL
|
||||
py::gil_scoped_release release;
|
||||
#endif
|
||||
ret = RunGraph(options, ge_inputs, &ge_outputs);
|
||||
}
|
||||
if (ret != Status::SUCCESS) {
|
||||
|
|
|
@ -313,6 +313,21 @@ constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder";
|
|||
constexpr const char kNameReverseV2[] = "ReverseV2";
|
||||
constexpr const char kNameLambApplyWeightAssign[] = "LambApplyWeightAssign";
|
||||
constexpr const char kNameLambApplyOptimizerAssign[] = "LambApplyOptimizerAssign";
|
||||
constexpr const char kNameScale[] = "Scale";
|
||||
constexpr const char kNameEltwise[] = "Eltwise";
|
||||
constexpr const char kNameFullConnection[] = "FullConnection";
|
||||
constexpr const char kNameFusedBatchNorm[] = "FusedBatchNorm";
|
||||
constexpr const char kNamePooling[] = "Pooling";
|
||||
constexpr const char kNameMaxPoolV3[] = "MaxPoolV3";
|
||||
constexpr const char kNameAvgPoolV2[] = "AvgPoolV2";
|
||||
constexpr const char kNameShape[] = "Shape";
|
||||
constexpr const char kNameGather[] = "Gather";
|
||||
constexpr const char kNameUnsqueeze[] = "Unsqueeze";
|
||||
constexpr const char kNamePadV3[] = "PadV3";
|
||||
constexpr const char kNameGlobalAvgPool[] = "GlobalAveragePool";
|
||||
constexpr const char kNameStridedSliceV2[] = "StridedSliceV2";
|
||||
constexpr const char kNameBNInference[] = "BNInference";
|
||||
constexpr const char kNameDeconvolution[] = "Deconvolution";
|
||||
|
||||
class OpAdapterMap {
|
||||
public:
|
||||
|
|
|
@ -43,6 +43,12 @@ INPUT_MAP(Data) = EMPTY_INPUT_MAP;
|
|||
ATTR_MAP(Data) = EMPTY_ATTR_MAP;
|
||||
REG_ADPT_DESC(Data, kNameParam, ADPT_DESC(Data))
|
||||
|
||||
// Shape
|
||||
INPUT_MAP(Shape) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(Shape) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(Shape) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Shape, kNameShape, ADPT_DESC(Shape))
|
||||
|
||||
// Reshape
|
||||
INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
|
||||
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
|
||||
|
@ -95,4 +101,10 @@ INPUT_MAP(EditDistance) = {{1, INPUT_DESC(hypothesis_indices)}, {2, INPUT_DESC(h
|
|||
ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}};
|
||||
REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance))
|
||||
|
||||
// Unsqueeze
|
||||
INPUT_MAP(Unsqueeze) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(Unsqueeze) = {{"axis", ATTR_DESC(axes, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}};
|
||||
OUTPUT_MAP(Unsqueeze) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Unsqueeze, kNameUnsqueeze, ADPT_DESC(Unsqueeze))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -23,6 +23,9 @@
|
|||
#include "ops/array_ops.h"
|
||||
|
||||
namespace mindspore::transform {
|
||||
DECLARE_OP_ADAPTER(Shape)
|
||||
DECLARE_OP_USE_OUTPUT(Shape)
|
||||
|
||||
DECLARE_OP_ADAPTER(Reshape)
|
||||
DECLARE_OP_USE_OUTPUT(Reshape)
|
||||
|
||||
|
@ -57,5 +60,8 @@ DECLARE_OP_USE_OUTPUT(ReverseSequence)
|
|||
|
||||
DECLARE_OP_ADAPTER(EditDistance)
|
||||
DECLARE_OP_USE_OUTPUT(EditDistance)
|
||||
|
||||
DECLARE_OP_ADAPTER(Unsqueeze)
|
||||
DECLARE_OP_USE_OUTPUT(Unsqueeze)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_
|
||||
|
|
|
@ -637,4 +637,13 @@ INPUT_MAP(LambApplyWeightAssign) = {{1, INPUT_DESC(input0)},
|
|||
ATTR_MAP(LambApplyWeightAssign) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(LambApplyWeightAssign) = {{0, OUTPUT_DESC(input_param)}};
|
||||
REG_ADPT_DESC(LambApplyWeightAssign, kNameLambApplyWeightAssign, ADPT_DESC(LambApplyWeightAssign))
|
||||
|
||||
// Eltwise
|
||||
INPUT_MAP(Eltwise) = EMPTY_INPUT_MAP;
|
||||
DYN_INPUT_MAP(Eltwise) = {{1, DYN_INPUT_DESC(x)}};
|
||||
ATTR_MAP(Eltwise) = {{"n", ATTR_DESC(N, AnyTraits<int64_t>())},
|
||||
{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())},
|
||||
{"coeff", ATTR_DESC(coeff, AnyTraits<std::vector<float>>(), AnyTraits<float>())}};
|
||||
OUTPUT_MAP(Eltwise) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Eltwise, kNameEltwise, ADPT_DESC(Eltwise))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -316,5 +316,8 @@ DECLARE_OP_USE_OUTPUT(LambApplyOptimizerAssign)
|
|||
|
||||
DECLARE_OP_ADAPTER(LambApplyWeightAssign)
|
||||
DECLARE_OP_USE_OUTPUT(LambApplyWeightAssign)
|
||||
|
||||
DECLARE_OP_ADAPTER(Eltwise)
|
||||
DECLARE_OP_USE_OUTPUT(Eltwise)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ELEWISE_CALCULATION_OPS_DECLARE_H_
|
||||
|
|
|
@ -133,4 +133,15 @@ INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
|
|||
ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(L2Loss, kNameL2Loss, ADPT_DESC(L2Loss))
|
||||
|
||||
// FullyConnection
|
||||
INPUT_MAP(FullyConnection) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(offset_w)}};
|
||||
|
||||
ATTR_MAP(FullyConnection) = {{"num_output", ATTR_DESC(num_output, AnyTraits<int64_t>())},
|
||||
{"transpose", ATTR_DESC(transpose, AnyTraits<bool>())},
|
||||
{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())},
|
||||
{"offset_x", ATTR_DESC(offset_x, AnyTraits<int64_t>())}};
|
||||
|
||||
OUTPUT_MAP(FullyConnection) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(FullyConnection, kNameFullConnection, ADPT_DESC(FullyConnection))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -79,5 +79,8 @@ DECLARE_OP_USE_OUTPUT(DiagPart)
|
|||
|
||||
DECLARE_OP_ADAPTER(L2Loss)
|
||||
DECLARE_OP_USE_OUTPUT(L2Loss)
|
||||
|
||||
DECLARE_OP_ADAPTER(FullyConnection)
|
||||
DECLARE_OP_USE_OUTPUT(FullyConnection)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATRIX_CALCULATION_OPS_DECLARE_H_
|
||||
|
|
|
@ -32,7 +32,17 @@ OUTPUT_MAP(BatchNorm) = {{0, OUTPUT_DESC(y)},
|
|||
{2, OUTPUT_DESC(batch_variance)},
|
||||
{3, OUTPUT_DESC(reserve_space_1)},
|
||||
{4, OUTPUT_DESC(reserve_space_2)}};
|
||||
// BNInference is BatchNorm for caffe
|
||||
INPUT_MAP(BNInference) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(mean)}, {3, INPUT_DESC(variance)},
|
||||
{4, INPUT_DESC(momentum)}, {5, INPUT_DESC(scale)}, {6, INPUT_DESC(offset)}};
|
||||
ATTR_MAP(BNInference) = {{"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())},
|
||||
{"use_global_stats", ATTR_DESC(use_global_stats, AnyTraits<bool>())},
|
||||
{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(BNInference) = {{0, OUTPUT_DESC(y)}};
|
||||
|
||||
REG_ADPT_DESC(BNInference, kNameBNInference, ADPT_DESC(BNInference))
|
||||
REG_ADPT_DESC(BatchNorm, kNameBatchNorm, ADPT_DESC(BatchNorm))
|
||||
REG_ADPT_DESC(FusedBatchNorm, kNameFusedBatchNorm, ADPT_DESC(BatchNorm))
|
||||
|
||||
// BatchNormGrad
|
||||
INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)},
|
||||
|
@ -65,4 +75,5 @@ ATTR_MAP(L2Normalize) = {
|
|||
{"epsilon", ATTR_DESC(eps, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(L2Normalize) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(L2Normalize, kNameL2Normalize, ADPT_DESC(L2Normalize))
|
||||
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -26,6 +26,9 @@ namespace mindspore::transform {
|
|||
DECLARE_OP_ADAPTER(BatchNorm)
|
||||
DECLARE_OP_USE_OUTPUT(BatchNorm)
|
||||
|
||||
DECLARE_OP_ADAPTER(BNInference)
|
||||
DECLARE_OP_USE_OUTPUT(BNInference)
|
||||
|
||||
DECLARE_OP_ADAPTER(BatchNormGrad)
|
||||
DECLARE_OP_USE_OUTPUT(BatchNormGrad)
|
||||
|
||||
|
|
|
@ -49,6 +49,19 @@ ATTR_MAP(Conv2DBackpropInputD) = {
|
|||
};
|
||||
OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Conv2DBackpropInputD, prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD))
|
||||
|
||||
// Deconvolution for caffe inference
|
||||
INPUT_MAP(Deconvolution) = {
|
||||
{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}, {4, INPUT_DESC(offset_w)}};
|
||||
ATTR_MAP(Deconvolution) = {
|
||||
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
|
||||
{"format", ATTR_DESC(data_format, AnyTraits<string>())},
|
||||
{"offset", ATTR_DESC(offset_x, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(Deconvolution) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Deconvolution, kNameDeconvolution, ADPT_DESC(Deconvolution))
|
||||
REG_ADPT_DESC(Conv2DTranspose, kConv2DTransposeOpName, ADPT_DESC(Conv2DBackpropInputD))
|
||||
|
||||
// Conv2DBackpropFilterD
|
||||
|
|
|
@ -69,5 +69,8 @@ DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropFilterD)
|
|||
DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropInputD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
|
||||
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
|
||||
|
||||
DECLARE_OP_ADAPTER(Deconvolution)
|
||||
DECLARE_OP_USE_OUTPUT(Deconvolution)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_CALCULATION_OPS_DECLARE_H_
|
||||
|
|
|
@ -146,4 +146,13 @@ INPUT_MAP(Centralization) = {{1, INPUT_DESC(x)}};
|
|||
ATTR_MAP(Centralization) = {{"axes", ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>())}};
|
||||
OUTPUT_MAP(Centralization) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Centralization, kNameCentralization, ADPT_DESC(Centralization))
|
||||
|
||||
// Scale
|
||||
INPUT_MAP(Scale) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(scale)}, {3, INPUT_DESC(bias)}};
|
||||
ATTR_MAP(Scale) = {{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())},
|
||||
{"num_axes", ATTR_DESC(num_axes, AnyTraits<int64_t>())},
|
||||
{"scale_from_blob", ATTR_DESC(scale_from_blob, AnyTraits<bool>())}};
|
||||
|
||||
OUTPUT_MAP(Scale) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Scale, kNameScale, ADPT_DESC(Scale))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -76,5 +76,8 @@ DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad)
|
|||
|
||||
DECLARE_OP_ADAPTER(Centralization)
|
||||
DECLARE_OP_USE_OUTPUT(Centralization)
|
||||
|
||||
DECLARE_OP_ADAPTER(Scale)
|
||||
DECLARE_OP_USE_OUTPUT(Scale)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_
|
||||
|
|
|
@ -120,4 +120,47 @@ ATTR_MAP(MaxPoolGradGradWithArgmax) = {
|
|||
{"pad_mode", ATTR_DESC(padding, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(MaxPoolGradGradWithArgmax) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(MaxPoolGradGradWithArgmax, kNameMaxPoolGradGradWithArgmax, ADPT_DESC(MaxPoolGradGradWithArgmax))
|
||||
|
||||
// Pooling
|
||||
INPUT_MAP(Pooling) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(Pooling) = {{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())},
|
||||
{"global", ATTR_DESC(global_pooling, AnyTraits<bool>())},
|
||||
{"kernel_size", ATTR_DESC(window, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"strides", ATTR_DESC(stride, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"pad", ATTR_DESC(pad, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"dilation", ATTR_DESC(dilation, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"round_mode", ATTR_DESC(ceil_mode, AnyTraits<int64_t>())},
|
||||
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(Pooling) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(Pooling, kNamePooling, ADPT_DESC(Pooling))
|
||||
|
||||
// MaxPoolV3
|
||||
INPUT_MAP(MaxPoolV3) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(MaxPoolV3) = {{"kernel_size", ATTR_DESC(ksize, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"strides", ATTR_DESC(strides, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"padding_mode", ATTR_DESC(padding_mode, AnyTraits<std::string>())},
|
||||
{"pad", ATTR_DESC(pads, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
|
||||
{"global", ATTR_DESC(global_pooling, AnyTraits<bool>())},
|
||||
{"ceil_mode", ATTR_DESC(ceil_mode, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(MaxPoolV3) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(MaxPoolV3, kNameMaxPoolV3, ADPT_DESC(MaxPoolV3))
|
||||
|
||||
// AvgPoolV2
|
||||
INPUT_MAP(AvgPoolV2) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(AvgPoolV2) = {{"kernel_size", ATTR_DESC(ksize, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"strides", ATTR_DESC(strides, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"padding_mode", ATTR_DESC(padding_mode, AnyTraits<std::string>())},
|
||||
{"pad", ATTR_DESC(pads, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
|
||||
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
|
||||
{"global", ATTR_DESC(global_pooling, AnyTraits<bool>())},
|
||||
{"ceil_mode", ATTR_DESC(ceil_mode, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(AvgPoolV2) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(AvgPoolV2, kNameAvgPoolV2, ADPT_DESC(AvgPoolV2))
|
||||
|
||||
// GlobalAveragePool
|
||||
INPUT_MAP(GlobalAveragePool) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(GlobalAveragePool) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(GlobalAveragePool) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(GlobalAveragePool, kNameGlobalAvgPool, ADPT_DESC(GlobalAveragePool))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <unordered_map>
|
||||
#include "transform/graph_ir/op_declare/op_declare_macro.h"
|
||||
#include "ops/nn_ops.h"
|
||||
#include "ops/nn_pooling_ops.h"
|
||||
|
||||
namespace mindspore::transform {
|
||||
DECLARE_OP_ADAPTER(MaxPoolWithArgmax)
|
||||
|
@ -55,5 +56,17 @@ DECLARE_OP_USE_OUTPUT(AvgPool)
|
|||
|
||||
DECLARE_OP_ADAPTER(AvgPoolGrad)
|
||||
DECLARE_OP_USE_OUTPUT(AvgPoolGrad)
|
||||
|
||||
DECLARE_OP_ADAPTER(Pooling)
|
||||
DECLARE_OP_USE_OUTPUT(Pooling)
|
||||
|
||||
DECLARE_OP_ADAPTER(MaxPoolV3)
|
||||
DECLARE_OP_USE_OUTPUT(MaxPoolV3)
|
||||
|
||||
DECLARE_OP_ADAPTER(AvgPoolV2)
|
||||
DECLARE_OP_USE_OUTPUT(AvgPoolV2)
|
||||
|
||||
DECLARE_OP_ADAPTER(GlobalAveragePool)
|
||||
DECLARE_OP_USE_OUTPUT(GlobalAveragePool)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_POOLING_OPS_DECLARE_H_
|
||||
|
|
|
@ -154,4 +154,10 @@ INPUT_MAP(FastGeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}};
|
|||
ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}};
|
||||
REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeLUGrad->name(), ADPT_DESC(FastGeluGrad))
|
||||
|
||||
// LeakyRelu
|
||||
INPUT_MAP(LeakyRelu) = {{1, INPUT_DESC(x)}};
|
||||
ATTR_MAP(LeakyRelu) = {{"alpha", ATTR_DESC(negative_slope, AnyTraits<float>())}};
|
||||
OUTPUT_MAP(LeakyRelu) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(LeakyRelu, prim::kPrimLeakyRelu->name(), ADPT_DESC(LeakyRelu))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -91,5 +91,8 @@ DECLARE_OP_USE_OUTPUT(Sigmoid)
|
|||
|
||||
DECLARE_OP_ADAPTER(SigmoidGrad)
|
||||
DECLARE_OP_USE_OUTPUT(SigmoidGrad)
|
||||
|
||||
DECLARE_OP_ADAPTER(LeakyRelu)
|
||||
DECLARE_OP_USE_OUTPUT(LeakyRelu)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NONLINEAR_FUC_OPS_DECLARE_H_
|
||||
|
|
|
@ -41,4 +41,11 @@ INPUT_MAP(FillD) = {{1, INPUT_DESC(value)}};
|
|||
ATTR_MAP(FillD) = {{"dims", ATTR_DESC(dims, AnyTraits<std::vector<int64_t>>())}};
|
||||
OUTPUT_MAP(FillD) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(FillD, kNameFillD, ADPT_DESC(FillD))
|
||||
|
||||
// PadV3
|
||||
INPUT_MAP(PadV3) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}, {3, INPUT_DESC(constant_values)}};
|
||||
ATTR_MAP(PadV3) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())},
|
||||
{"pad_contiguous", ATTR_DESC(paddings_contiguous, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(PadV3) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(PadV3, kNamePadV3, ADPT_DESC(PadV3))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -34,5 +34,8 @@ DECLARE_OP_USE_OUTPUT(Diag)
|
|||
|
||||
DECLARE_OP_ADAPTER(FillD)
|
||||
DECLARE_OP_USE_OUTPUT(FillD)
|
||||
|
||||
DECLARE_OP_ADAPTER(PadV3)
|
||||
DECLARE_OP_USE_OUTPUT(PadV3)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_PAD_OPS_DECLARE_H_
|
||||
|
|
|
@ -77,6 +77,7 @@ INPUT_ATTR_MAP(GatherV2D) = {{3, ATTR_DESC(axis, AnyTraits<int64_t>())}};
|
|||
ATTR_MAP(GatherV2D) = EMPTY_ATTR_MAP;
|
||||
OUTPUT_MAP(GatherV2D) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(GatherV2D, prim::kPrimGather->name(), ADPT_DESC(GatherV2D))
|
||||
REG_ADPT_DESC(Gather, kNameGather, ADPT_DESC(GatherV2D))
|
||||
|
||||
// ScatterNdD
|
||||
INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}};
|
||||
|
@ -151,6 +152,17 @@ ATTR_MAP(StridedSlice) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits<int64_t
|
|||
OUTPUT_MAP(StridedSlice) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(StridedSlice, kNameStridedSlice, ADPT_DESC(StridedSlice))
|
||||
|
||||
// StridedSliceV2
|
||||
INPUT_MAP(StridedSliceV2) = {
|
||||
{1, INPUT_DESC(x)}, {2, INPUT_DESC(begin)}, {3, INPUT_DESC(end)}, {4, INPUT_DESC(axes)}, {5, INPUT_DESC(strides)}};
|
||||
ATTR_MAP(StridedSliceV2) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits<int64_t>())},
|
||||
{"end_mask", ATTR_DESC(end_mask, AnyTraits<int64_t>())},
|
||||
{"ellipsis_mask", ATTR_DESC(ellipsis_mask, AnyTraits<int64_t>())},
|
||||
{"new_axis_mask", ATTR_DESC(new_axis_mask, AnyTraits<int64_t>())},
|
||||
{"shrink_axis_mask", ATTR_DESC(shrink_axis_mask, AnyTraits<int64_t>())}};
|
||||
OUTPUT_MAP(StridedSliceV2) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(StridedSliceV2, kNameStridedSliceV2, ADPT_DESC(StridedSliceV2))
|
||||
|
||||
// UnsortedSegmentSum
|
||||
INPUT_MAP(UnsortedSegmentSumD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
|
||||
INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
|
||||
|
|
|
@ -52,6 +52,9 @@ DECLARE_OP_USE_OUTPUT(StridedSliceGrad)
|
|||
DECLARE_OP_ADAPTER(StridedSlice)
|
||||
DECLARE_OP_USE_OUTPUT(StridedSlice)
|
||||
|
||||
DECLARE_OP_ADAPTER(StridedSliceV2)
|
||||
DECLARE_OP_USE_OUTPUT(StridedSliceV2)
|
||||
|
||||
DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
|
||||
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
|
||||
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)
|
||||
|
|
|
@ -219,7 +219,14 @@ void FuncGraph::AddNode(const AnfNodePtr &node) { nodes_.add(node); }
|
|||
|
||||
void FuncGraph::DropNode(const AnfNodePtr &node) {
|
||||
nodes_.erase(node);
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Node is nullptr";
|
||||
return;
|
||||
}
|
||||
auto graph = node->func_graph();
|
||||
if (node->isa<Parameter>()) {
|
||||
parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
|
||||
}
|
||||
// Remove the node from order list.
|
||||
if (graph) {
|
||||
graph->EraseUnusedNodeInOrder(node);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "ops/fusion/scale_fusion.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -254,6 +254,10 @@ constexpr auto kSplitDim = "split_dim";
|
|||
constexpr auto kPadTop = "pad_top";
|
||||
constexpr auto kTransFormat = "trans_format";
|
||||
constexpr auto kApproximate = "approximate";
|
||||
constexpr auto kNumOutput = "num_output";
|
||||
constexpr auto kUseGlobalStats = "use_global_stats";
|
||||
constexpr auto kFmkType = "fmk_type";
|
||||
|
||||
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
|
||||
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
|
||||
|
||||
|
|
|
@ -138,7 +138,11 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
|
|||
{"BinaryCrossEntropyGrad", ReductionMap},
|
||||
{"NLLLoss", ReductionMap},
|
||||
{"DepthToSpace", DataFormatMap},
|
||||
};
|
||||
{"Pooling", DataFormatMap},
|
||||
{"Deconvolution", DataFormatMap},
|
||||
{"AvgPoolV2", DataFormatMap},
|
||||
{"MaxPoolV3", DataFormatMap},
|
||||
{"FusedBatchNorm", DataFormatMap}};
|
||||
|
||||
bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
|
|
|
@ -41,6 +41,7 @@ option(MSLITE_ENABLE_MINDRT "enable mindrt use" on)
|
|||
option(MSLITE_DELEGATE_USE "enable delegate use" on)
|
||||
option(MSLITE_ENABLE_V0 "support v0 schema" on)
|
||||
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
|
||||
option(MSLITE_ENABLE_ACL "enable ACL" off)
|
||||
|
||||
#Option that can be configured through manually
|
||||
option(ENABLE_VERBOSE "" off)
|
||||
|
@ -119,6 +120,10 @@ if(DEFINED ENV{MSLITE_ENABLE_FP16})
|
|||
set(MSLITE_ENABLE_FP16 $ENV{MSLITE_ENABLE_FP16})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_ACL})
|
||||
set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL})
|
||||
endif()
|
||||
|
||||
if(PLATFORM_ARM64)
|
||||
if(MSLITE_GPU_BACKEND STREQUAL "")
|
||||
set(MSLITE_GPU_BACKEND "opencl")
|
||||
|
@ -240,6 +245,21 @@ if(ENABLE_ASAN)
|
|||
add_link_options(-fsanitize=address)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_ACL)
|
||||
set(ENABLE_ACL on)
|
||||
add_definitions(-D ENABLE_LITE_ACL)
|
||||
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--allow-shlib-undefined")
|
||||
|
||||
if(DEFINED ENV{ASCEND_CUSTOM_PATH})
|
||||
set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH})
|
||||
else()
|
||||
set(ASCEND_PATH /usr/local/Ascend)
|
||||
endif()
|
||||
set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
|
||||
set(ASCEND_TOOLKIT_RUNTIME_PATH ${ASCEND_PATH}/ascend-toolkit/latest/fwkacllib/lib64)
|
||||
endif()
|
||||
|
||||
set(PKG_NAME_PREFIX mindspore-lite-${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
|
||||
|
||||
if(SUPPORT_NPU)
|
||||
|
|
|
@ -381,9 +381,18 @@ build_aar() {
|
|||
sha256sum mindspore-lite-maven-${VERSION_STR}.zip > mindspore-lite-maven-${VERSION_STR}.zip.sha256
|
||||
}
|
||||
|
||||
update_submodule()
|
||||
{
|
||||
git submodule update --init graphengine
|
||||
cd "${BASEPATH}/graphengine"
|
||||
git submodule update --init metadef
|
||||
}
|
||||
|
||||
LITE_JAVA_PATH=${BASEPATH}/mindspore/lite/java
|
||||
LITE_BUILD_TYPE="Release"
|
||||
if [[ "${MSLITE_ENABLE_ACL}" == "on" ]]; then
|
||||
update_submodule
|
||||
fi
|
||||
if [[ "${DEBUG_MODE}" == "on" ]]; then
|
||||
LITE_BUILD_TYPE="Debug"
|
||||
fi
|
||||
|
|
|
@ -233,6 +233,11 @@ add_subdirectory(runtime/kernel/arm)
|
|||
add_library(lite_src_mid OBJECT ${LITE_SRC})
|
||||
add_dependencies(lite_src_mid fbs_src)
|
||||
|
||||
if(MSLITE_ENABLE_ACL)
|
||||
add_subdirectory(runtime/kernel/ascend310)
|
||||
link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
endif()
|
||||
|
||||
add_library(mindspore-lite SHARED $<TARGET_OBJECTS:lite_src_mid>)
|
||||
set_target_properties(mindspore-lite PROPERTIES CLEAN_DIRECT_OUTPUT 1)
|
||||
|
||||
|
@ -387,3 +392,8 @@ if(ENABLE_MODEL_OBF)
|
|||
target_link_libraries(mindspore-lite ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so)
|
||||
target_link_libraries(mindspore-lite_static ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_ACL)
|
||||
target_link_libraries(mindspore-lite ascend310_kernel_mid)
|
||||
target_link_libraries(mindspore-lite_static ascend310_kernel_mid)
|
||||
endif()
|
||||
|
|
|
@ -33,6 +33,19 @@ constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
|
|||
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
|
||||
constexpr auto kModelOptionProvider = "mindspore.option.provider";
|
||||
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
|
||||
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
|
||||
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
|
||||
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
|
||||
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
|
||||
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
|
||||
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
|
||||
constexpr auto kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
|
||||
constexpr auto kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type";
|
||||
constexpr auto kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
|
||||
constexpr auto kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
|
||||
constexpr auto KModelOptionAscend310FusionSwitchCfgPath = "mindspore.option.ascend310.fusion_switch_config_file_path";
|
||||
constexpr auto kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
|
||||
constexpr auto kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";
|
||||
|
||||
struct Context::Data {
|
||||
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
|
||||
|
@ -290,101 +303,208 @@ uint32_t Ascend910DeviceInfo::GetDeviceID() const {
|
|||
return 0;
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310DeviceID] = device_id;
|
||||
}
|
||||
|
||||
uint32_t Ascend310DeviceInfo::GetDeviceID() const {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return 0;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return 0;
|
||||
}
|
||||
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path);
|
||||
}
|
||||
|
||||
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
|
||||
std::vector<char> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DumpCfgPath);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
|
||||
}
|
||||
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
|
||||
std::vector<char> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
|
||||
std::vector<char> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
|
||||
}
|
||||
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
|
||||
std::vector<char> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
std::string batchs;
|
||||
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
|
||||
if (i != 0) {
|
||||
batchs.push_back(',');
|
||||
}
|
||||
batchs += std::to_string(dynamic_batch_size[i]);
|
||||
}
|
||||
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
|
||||
}
|
||||
|
||||
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
|
||||
std::vector<char> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
|
||||
}
|
||||
|
||||
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
|
||||
std::vector<char> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
|
||||
}
|
||||
|
||||
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
|
||||
std::vector<char> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
|
||||
}
|
||||
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
|
||||
std::vector<char> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
}
|
||||
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
|
||||
std::map<int, std::vector<int>> empty;
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return empty;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310InputShapeMap] = shape;
|
||||
}
|
||||
|
||||
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::map<int, std::vector<int>>();
|
||||
}
|
||||
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310OutputType] = output_type;
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { MS_LOG(ERROR) << "Unsupported Feature."; }
|
||||
enum DataType Ascend310DeviceInfo::GetOutputType() const {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
return DataType::kTypeUnknown;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return DataType::kTypeUnknown;
|
||||
}
|
||||
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
|
||||
}
|
||||
|
||||
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return;
|
||||
}
|
||||
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
|
||||
}
|
||||
|
||||
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
|
||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||
std::vector<char> ret;
|
||||
return ret;
|
||||
if (data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid context.";
|
||||
return std::vector<char>();
|
||||
}
|
||||
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
|
||||
return StringToChar(ref);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,31 +27,44 @@
|
|||
namespace mindspore {
|
||||
class Buffer::Impl {
|
||||
public:
|
||||
Impl() : data_() { MS_LOG(ERROR) << "Unsupported feature."; }
|
||||
Impl() : data_() {}
|
||||
~Impl() = default;
|
||||
Impl(const void *data, size_t data_len) { MS_LOG(ERROR) << "Unsupported feature."; }
|
||||
Impl(const void *data, size_t data_len) {
|
||||
if (data != nullptr) {
|
||||
(void)SetData(data, data_len);
|
||||
} else {
|
||||
ResizeData(data_len);
|
||||
}
|
||||
}
|
||||
|
||||
const void *Data() const {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return nullptr;
|
||||
}
|
||||
void *MutableData() {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return nullptr;
|
||||
}
|
||||
size_t DataSize() const {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return 0;
|
||||
}
|
||||
const void *Data() const { return data_.data(); }
|
||||
void *MutableData() { return data_.data(); }
|
||||
size_t DataSize() const { return data_.size(); }
|
||||
|
||||
bool ResizeData(size_t data_len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return false;
|
||||
data_.resize(data_len);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SetData(const void *data, size_t data_len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return false;
|
||||
ResizeData(data_len);
|
||||
if (DataSize() != data_len) {
|
||||
MS_LOG(ERROR) << "Set data failed, tensor current data size " << DataSize() << " not match data len " << data_len;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (data == nullptr) {
|
||||
return data_len == 0;
|
||||
}
|
||||
|
||||
if (MutableData() == nullptr) {
|
||||
MS_LOG(ERROR) << "Set data failed, data len " << data_len;
|
||||
return false;
|
||||
}
|
||||
|
||||
memcpy(MutableData(), data, data_len);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -343,38 +356,58 @@ void MSTensor::SetQuantParams(std::vector<QuantParam> quant_params) {
|
|||
return impl_->SetQuantParams(quant_params);
|
||||
}
|
||||
|
||||
Buffer::Buffer() : impl_(nullptr) { MS_LOG(ERROR) << "Unsupported feature."; }
|
||||
Buffer::Buffer(const void *data, size_t data_len) : impl_(nullptr) { MS_LOG(ERROR) << "Unsupported feature."; }
|
||||
Buffer::Buffer() : impl_(std::make_shared<Impl>()) {}
|
||||
Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared<Impl>(data, data_len)) {}
|
||||
Buffer::~Buffer() = default;
|
||||
|
||||
Buffer Buffer::Clone() const {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return Buffer();
|
||||
Buffer ret;
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "impl is nullptr.";
|
||||
return ret;
|
||||
}
|
||||
ret.impl_ = std::make_shared<Impl>(*impl_);
|
||||
return ret;
|
||||
}
|
||||
|
||||
const void *Buffer::Data() const {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return nullptr;
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "impl is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
return impl_->Data();
|
||||
}
|
||||
|
||||
void *Buffer::MutableData() {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return nullptr;
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "impl is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
return impl_->MutableData();
|
||||
}
|
||||
|
||||
size_t Buffer::DataSize() const {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return 0;
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "impl is nullptr.";
|
||||
return 0;
|
||||
}
|
||||
return impl_->DataSize();
|
||||
}
|
||||
|
||||
bool Buffer::ResizeData(size_t data_len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return false;
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "impl is nullptr.";
|
||||
return false;
|
||||
}
|
||||
return impl_->ResizeData(data_len);
|
||||
}
|
||||
|
||||
bool Buffer::SetData(const void *data, size_t data_len) {
|
||||
MS_LOG(ERROR) << "Unsupported feature.";
|
||||
return false;
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "impl is nullptr.";
|
||||
return false;
|
||||
}
|
||||
return impl_->SetData(data, data_len);
|
||||
}
|
||||
|
||||
std::vector<char> CharVersion() { return StringToChar(lite::Version()); }
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
include_directories(${TOP_DIR}/graphengine/inc/external)
|
||||
|
||||
find_library(ge_graph libgraph.so ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
|
||||
aux_source_directory(src ACL_SRC)
|
||||
add_library(ascend310_kernel_mid OBJECT ${ACL_SRC})
|
||||
|
||||
add_dependencies(ascend310_kernel_mid fbs_inner_src)
|
||||
|
||||
target_link_libraries(ascend310_kernel_mid ${ge_graph} ${ge_compiler}
|
||||
${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} ${libplatform}
|
||||
${libcompress} ${libopskernel} ${libaicore_utils} ${libaicpu_engine_common} ${acl})
|
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/ascend310/src/acl_env_guard.h"
|
||||
#include "common/log_adapter.h"
|
||||
#include "acl/acl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_ = nullptr;
|
||||
std::mutex AclEnvGuard::global_acl_env_mutex_;
|
||||
|
||||
AclEnvGuard::AclEnvGuard(std::string_view cfg_file) : errno_(ACL_ERROR_NONE) {
|
||||
errno_ = aclInit(cfg_file.data());
|
||||
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) {
|
||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Acl init success";
|
||||
}
|
||||
|
||||
AclEnvGuard::~AclEnvGuard() {
|
||||
errno_ = aclFinalize();
|
||||
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_FINALIZE) {
|
||||
MS_LOG(ERROR) << "Finalize acl failed";
|
||||
}
|
||||
MS_LOG(INFO) << "Acl finalize success";
|
||||
}
|
||||
|
||||
std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
|
||||
std::shared_ptr<AclEnvGuard> acl_env;
|
||||
|
||||
std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
|
||||
acl_env = global_acl_env_;
|
||||
if (acl_env != nullptr) {
|
||||
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
||||
if (!cfg_file.empty()) {
|
||||
MS_LOG(WARNING) << "Dump config file option " << cfg_file << " is ignored.";
|
||||
}
|
||||
} else {
|
||||
acl_env = std::make_shared<AclEnvGuard>(cfg_file);
|
||||
aclError ret = acl_env->GetErrno();
|
||||
if (ret != ACL_ERROR_NONE && ret != ACL_ERROR_REPEAT_INITIALIZE) {
|
||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||
return nullptr;
|
||||
}
|
||||
global_acl_env_ = acl_env;
|
||||
MS_LOG(INFO) << "Acl init success";
|
||||
}
|
||||
return acl_env;
|
||||
}
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_ACL_ENV_GUARD_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_ACL_ENV_GUARD_H_
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include "acl/acl_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
class __attribute__((visibility("default"))) AclEnvGuard {
|
||||
public:
|
||||
explicit AclEnvGuard(std::string_view cfg_file);
|
||||
~AclEnvGuard();
|
||||
aclError GetErrno() const { return errno_; }
|
||||
static std::shared_ptr<AclEnvGuard> GetAclEnv(std::string_view cfg_file);
|
||||
|
||||
private:
|
||||
static std::shared_ptr<AclEnvGuard> global_acl_env_;
|
||||
static std::mutex global_acl_env_mutex_;
|
||||
|
||||
aclError errno_;
|
||||
};
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_ACL_ENV_GUARD_H
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/ascend310/src/custom_interface.h"
|
||||
#include <memory>
|
||||
#include "include/errorcode.h"
|
||||
#include "include/registry/register_kernel_interface.h"
|
||||
#include "common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
Status CustomInterface::Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const mindspore::schema::Primitive *primitive) {
|
||||
if (inputs == nullptr || (*inputs).empty()) {
|
||||
MS_LOG(ERROR) << "Inputs is invalid.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (outputs == nullptr || (*outputs).empty()) {
|
||||
MS_LOG(ERROR) << "Outputs is invalid.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "Primitive is nullptr.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (primitive->value_type() != schema::PrimitiveType_Custom) {
|
||||
MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom.";
|
||||
return kLiteError;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
std::shared_ptr<mindspore::kernel::KernelInterface> CustomInferCreater() {
|
||||
auto infer = new (std::nothrow) CustomInterface();
|
||||
if (infer == nullptr) {
|
||||
MS_LOG(ERROR) << "New custom infer is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
return std::shared_ptr<mindspore::kernel::KernelInterface>(infer);
|
||||
}
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
REGISTER_CUSTOM_KERNEL_INTERFACE(ACL, ACL, acl::CustomInferCreater);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_ACL_CUSTOM_INTERFACE_H_
|
||||
#define MINDSPORE_LITE_ACL_CUSTOM_INTERFACE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/kernel_interface.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
class CustomInterface : public mindspore::kernel::KernelInterface {
|
||||
public:
|
||||
CustomInterface() {}
|
||||
~CustomInterface() = default;
|
||||
|
||||
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
const mindspore::schema::Primitive *primitive) override;
|
||||
};
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_ACL_CUSTOM_INTERFACE_H_
|
|
@ -0,0 +1,134 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/ascend310/src/custom_kernel.h"
|
||||
#include "include/registry/register_kernel.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "src/runtime/kernel/ascend310/src/model_infer.h"
|
||||
#include "common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
CustomAscend310Kernel::CustomAscend310Kernel(const std::vector<mindspore::MSTensor> &inputs,
|
||||
const std::vector<mindspore::MSTensor> &outputs,
|
||||
const schema::Primitive *primitive, const mindspore::Context *ctx)
|
||||
: Kernel(inputs, outputs, primitive, ctx), load_model_(false), model_infer_(nullptr) {}
|
||||
|
||||
CustomAscend310Kernel::~CustomAscend310Kernel() {
|
||||
if (load_model_) {
|
||||
int ret = model_infer_->Finalize();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Model finalize failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
STATUS CustomAscend310Kernel::PrepareModelInfer() {
|
||||
if (inputs_.size() < 1) {
|
||||
MS_LOG(ERROR) << "Inputs size should not less than 1.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// last input is om data tensor
|
||||
int idx = inputs_.size() - 1;
|
||||
Buffer om_data(inputs_[idx].Data().get(), inputs_[idx].DataSize());
|
||||
if (model_infer_ == nullptr) {
|
||||
model_infer_.reset(new ModelInfer(om_data, 0));
|
||||
}
|
||||
int ret = model_infer_->Init();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Model infer init failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ret = model_infer_->Load();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Load om data failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "Load om data success.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS CustomAscend310Kernel::Prepare() {
|
||||
if (load_model_) {
|
||||
MS_LOG(INFO) << "Custom kernel has been prepared.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (PrepareModelInfer() != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Model infer prepare is not ok.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
load_model_ = true;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS CustomAscend310Kernel::ReSize() {
|
||||
if (load_model_) {
|
||||
int ret = model_infer_->Finalize();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Model finalize failed.";
|
||||
}
|
||||
load_model_ = false;
|
||||
}
|
||||
return Prepare();
|
||||
}
|
||||
|
||||
STATUS CustomAscend310Kernel::Execute() {
|
||||
if (!load_model_) {
|
||||
MS_LOG(WARNING) << "Custom kernel has not been prepared.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
std::vector<mindspore::MSTensor> inputs(inputs_.begin(), inputs_.end() - 1);
|
||||
if (model_infer_->Inference(inputs, &outputs_) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Custom kernel execute failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
std::shared_ptr<kernel::Kernel> CustomCreateKernel(const std::vector<mindspore::MSTensor> &inputs,
|
||||
const std::vector<mindspore::MSTensor> &outputs,
|
||||
const schema::Primitive *primitive, const mindspore::Context *ctx) {
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "Primitive is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
if (primitive->value_type() != schema::PrimitiveType_Custom) {
|
||||
MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto kernel = std::make_shared<CustomAscend310Kernel>(inputs, outputs, primitive, ctx);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "New custom kernel is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
||||
namespace mindspore {
|
||||
namespace registry {
|
||||
namespace {
|
||||
const auto kFloat32 = DataType::kNumberTypeFloat32;
|
||||
const auto kInt8 = DataType::kNumberTypeInt8;
|
||||
const auto kUInt8 = DataType::kNumberTypeUInt8;
|
||||
} // namespace
|
||||
REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kFloat32, ACL, acl::CustomCreateKernel)
|
||||
REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kInt8, ACL, acl::CustomCreateKernel)
|
||||
REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kUInt8, ACL, acl::CustomCreateKernel)
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND310_KERNEL_CUSTOM_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND310_KERNEL_CUSTOM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/kernel.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/kernel/ascend310/src/model_infer.h"
|
||||
|
||||
using mindspore::lite::STATUS;
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
class CustomAscend310Kernel : public kernel::Kernel {
|
||||
public:
|
||||
CustomAscend310Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
|
||||
const mindspore::schema::Primitive *primitive, const mindspore::Context *ctx);
|
||||
~CustomAscend310Kernel() override;
|
||||
|
||||
STATUS Prepare() override;
|
||||
STATUS ReSize() override;
|
||||
STATUS Execute() override;
|
||||
|
||||
private:
|
||||
STATUS PrepareModelInfer();
|
||||
|
||||
bool load_model_;
|
||||
std::shared_ptr<ModelInfer> model_infer_;
|
||||
};
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND310_FP32_CUSTOM_H_
|
|
@ -0,0 +1,162 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/ascend310/src/model_infer.h"
|
||||
#include "common/log_adapter.h"
|
||||
#include "acl/acl.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
ModelInfer::ModelInfer(const Buffer &om_data, int32_t device_id)
|
||||
: init_flag_(false),
|
||||
load_flag_(false),
|
||||
device_type_("AscendCL"),
|
||||
device_id_(device_id),
|
||||
context_(nullptr),
|
||||
om_data_(om_data),
|
||||
model_process_(),
|
||||
acl_env_(nullptr) {}
|
||||
|
||||
STATUS ModelInfer::Init() {
|
||||
if (init_flag_) {
|
||||
MS_LOG(INFO) << "Acl has been initialized, skip.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
acl_env_ = AclEnvGuard::GetAclEnv("");
|
||||
if (acl_env_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Acl init failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
aclError ret = aclrtSetDevice(device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "Open device " << device_id_ << " success.";
|
||||
|
||||
ret = aclrtCreateContext(&context_, device_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl create context failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "Create context success.";
|
||||
|
||||
aclrtRunMode run_mode;
|
||||
ret = aclrtGetRunMode(&run_mode);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl get run mode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
bool is_device = (run_mode == ACL_DEVICE);
|
||||
model_process_.SetIsDevice(is_device);
|
||||
MS_LOG(INFO) << "Get run mode success is device input/output " << is_device;
|
||||
|
||||
MS_LOG(INFO) << "Init acl success, device id " << device_id_;
|
||||
init_flag_ = true;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelInfer::Finalize() {
|
||||
if (!init_flag_) {
|
||||
MS_LOG(WARNING) << "Init is not ok, no need to finalize.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set the ascend device context failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
int ret = model_process_.UnLoad();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Unload model inner failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (context_ != nullptr) {
|
||||
rt_ret = aclrtDestroyContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Destroy context failed.";
|
||||
}
|
||||
context_ = nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "End to destroy context.";
|
||||
|
||||
rt_ret = aclrtResetDevice(device_id_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Reset device " << device_id_ << " failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "End to reset device " << device_id_;
|
||||
init_flag_ = false;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelInfer::Load() {
|
||||
if (!load_flag_) {
|
||||
int ret = LoadAclModel(om_data_);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Load acl model failed.";
|
||||
return ret;
|
||||
}
|
||||
load_flag_ = true;
|
||||
}
|
||||
|
||||
aclError rt_ret = aclrtSetCurrentContext(context_);
|
||||
if (rt_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set the ascend device context failed, ret = " << rt_ret;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelInfer::LoadAclModel(const Buffer &om_data) {
|
||||
MS_LOG(INFO) << "Start load acl model.";
|
||||
// acl load model
|
||||
uint32_t acl_model_id;
|
||||
auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
|
||||
if (acl_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed, ret = " << acl_ret;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
// acl init model resource
|
||||
model_process_.set_model_id(acl_model_id);
|
||||
int ret = model_process_.PreInitModelResource();
|
||||
if (ret != lite::RET_OK) {
|
||||
(void)aclmdlUnload(acl_model_id);
|
||||
MS_LOG(ERROR) << "Pre init model resource failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Load acl model success.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelInfer::Inference(const std::vector<mindspore::MSTensor> &inputs,
|
||||
std::vector<mindspore::MSTensor> *outputs) {
|
||||
if (Load() != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare model resource failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
return model_process_.PredictFromHost(inputs, outputs);
|
||||
}
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_INFER_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_INFER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "src/runtime/kernel/ascend310/src/model_process.h"
|
||||
#include "src/runtime/kernel/ascend310/src/acl_env_guard.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::STATUS;
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
class ModelInfer {
|
||||
public:
|
||||
ModelInfer(const Buffer &om_data, int32_t device_id);
|
||||
~ModelInfer() = default;
|
||||
|
||||
STATUS Init();
|
||||
STATUS Finalize();
|
||||
STATUS Load();
|
||||
STATUS Inference(const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs);
|
||||
|
||||
private:
|
||||
STATUS LoadAclModel(const Buffer &om_data);
|
||||
|
||||
bool init_flag_;
|
||||
bool load_flag_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_;
|
||||
aclrtContext context_;
|
||||
Buffer om_data_;
|
||||
ModelProcess model_process_;
|
||||
std::shared_ptr<AclEnvGuard> acl_env_;
|
||||
};
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_INFER_H_
|
|
@ -0,0 +1,576 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/ascend310/src/model_process.h"
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
namespace {
|
||||
constexpr size_t kDynamicBatchSize = 1;
|
||||
constexpr size_t kDynamicImageSize = 2;
|
||||
} // namespace
|
||||
static DataType TransToDataType(aclDataType data_type) {
|
||||
static const std::map<aclDataType, enum DataType> data_type_map = {
|
||||
{ACL_FLOAT16, DataType::kNumberTypeFloat16}, {ACL_FLOAT, DataType::kNumberTypeFloat32},
|
||||
{ACL_DOUBLE, DataType::kNumberTypeFloat64}, {ACL_INT8, DataType::kNumberTypeInt8},
|
||||
{ACL_INT16, DataType::kNumberTypeInt16}, {ACL_INT32, DataType::kNumberTypeInt32},
|
||||
{ACL_INT64, DataType::kNumberTypeInt64}, {ACL_UINT8, DataType::kNumberTypeUInt8},
|
||||
{ACL_UINT16, DataType::kNumberTypeUInt16}, {ACL_UINT32, DataType::kNumberTypeUInt32},
|
||||
{ACL_UINT64, DataType::kNumberTypeUInt64}, {ACL_BOOL, DataType::kNumberTypeBool},
|
||||
};
|
||||
auto it = data_type_map.find(data_type);
|
||||
if (it == data_type_map.end()) {
|
||||
return DataType::kNumberTypeEnd;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline static void ClearIfNotNull(T *vec) {
|
||||
if (vec != nullptr) {
|
||||
vec->clear();
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class U = std::vector<T>>
|
||||
inline static void PushbackIfNotNull(U *vec, T &&item) {
|
||||
if (vec != nullptr) {
|
||||
vec->emplace_back(item);
|
||||
}
|
||||
}
|
||||
|
||||
static STATUS ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names,
|
||||
std::vector<std::vector<int64_t>> *shapes, std::vector<enum DataType> *data_types,
|
||||
std::vector<size_t> *mem_sizes) {
|
||||
ClearIfNotNull(names);
|
||||
ClearIfNotNull(shapes);
|
||||
ClearIfNotNull(data_types);
|
||||
ClearIfNotNull(mem_sizes);
|
||||
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
|
||||
const auto &info = acl_tensor_list[i];
|
||||
PushbackIfNotNull(names, info.name);
|
||||
PushbackIfNotNull(shapes, info.dims);
|
||||
PushbackIfNotNull(data_types, TransToDataType(info.data_type));
|
||||
PushbackIfNotNull(mem_sizes, info.buffer_size);
|
||||
}
|
||||
|
||||
if (names->size() != acl_tensor_list.size() || shapes->size() != acl_tensor_list.size() ||
|
||||
data_types->size() != acl_tensor_list.size() || mem_sizes->size() != acl_tensor_list.size()) {
|
||||
MS_LOG(ERROR) << "Inner error, size do not match: names size " << names->size() << " shapes size " << shapes->size()
|
||||
<< " data types size " << data_types->size() << " mem sizes size " << mem_sizes->size()
|
||||
<< " acl_tensor_list size " << acl_tensor_list.size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
static std::string ShapeToString(const std::vector<int64_t> &shape) {
|
||||
std::string result = "[";
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
result += std::to_string(shape[i]);
|
||||
if (i + 1 < shape.size()) {
|
||||
result += ", ";
|
||||
}
|
||||
}
|
||||
result += "]";
|
||||
return result;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::PreInitModelResource() {
|
||||
model_desc_ = aclmdlCreateDesc();
|
||||
aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_);
|
||||
if (acl_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Read model desc failed, ret = " << acl_ret;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
STATUS ret = InitInputsBuffer();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Create input buffer failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = InitOutputsBuffer();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Create output buffer failed.";
|
||||
return ret;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::InitInputsBuffer() {
|
||||
aclError ret;
|
||||
size_t input_size = aclmdlGetNumInputs(model_desc_);
|
||||
MS_LOG(INFO) << "input_size = " << input_size;
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
auto buffer_size = aclmdlGetInputSizeByIndex(model_desc_, i);
|
||||
void *data_mem_buffer = nullptr;
|
||||
if (!is_run_on_device_) { // need to copy input/output to/from device
|
||||
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Malloc device input buffer failed , input size " << buffer_size;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
aclmdlIODims dims;
|
||||
ret = aclmdlGetInputDims(model_desc_, i, &dims);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Get input shape failed, ret = " << ret;
|
||||
if (!is_run_on_device_) {
|
||||
aclrtFree(data_mem_buffer);
|
||||
}
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
|
||||
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
||||
std::string input_name = aclmdlGetInputNameByIndex(model_desc_, i);
|
||||
if (input_name.empty()) {
|
||||
MS_LOG(WARNING) << "Get name of input " << i << " failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "Name of input " << i << " is " << input_name;
|
||||
input_infos_.emplace_back(
|
||||
AclTensorInfo{data_mem_buffer, data_mem_buffer, buffer_size, data_type, shape, input_name});
|
||||
}
|
||||
MS_LOG(INFO) << "Create model inputs success";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) {
|
||||
if (data_mem_buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "Data mem buffer is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
aclError ret;
|
||||
auto free_data_buffer = [this](void *dataMemBuffer) {
|
||||
if (!is_run_on_device_) {
|
||||
aclrtFree(dataMemBuffer);
|
||||
} else {
|
||||
aclrtFreeHost(dataMemBuffer);
|
||||
}
|
||||
};
|
||||
|
||||
if (!is_run_on_device_) {
|
||||
ret = aclrtMalloc(data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
ret = aclrtMallocHost(data_mem_buffer, buffer_size);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Malloc host buffer failed , buffer size " << buffer_size;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
auto data_buffer = aclCreateDataBuffer(*data_mem_buffer, buffer_size);
|
||||
if (data_buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "Create Data Buffer failed";
|
||||
free_data_buffer(*data_mem_buffer);
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ret = aclmdlAddDatasetBuffer(dataset, data_buffer);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "add data buffer failed";
|
||||
free_data_buffer(*data_mem_buffer);
|
||||
aclDestroyDataBuffer(data_buffer);
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::InitOutputsBuffer() {
|
||||
aclError ret;
|
||||
outputs_ = aclmdlCreateDataset();
|
||||
if (outputs_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create input dataset failed";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
size_t output_size = aclmdlGetNumOutputs(model_desc_);
|
||||
MS_LOG(INFO) << "Output_size = " << output_size;
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
|
||||
|
||||
void *data_mem_buffer = nullptr;
|
||||
if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Add output data buffer failed, buffer size " << buffer_size;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
aclmdlIODims dims;
|
||||
ret = aclmdlGetOutputDims(model_desc_, i, &dims);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Get input shape failed";
|
||||
if (!is_run_on_device_) {
|
||||
aclrtFree(data_mem_buffer);
|
||||
} else {
|
||||
aclrtFreeHost(data_mem_buffer);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
aclFormat format = aclmdlGetOutputFormat(model_desc_, i);
|
||||
if (format != aclFormat::ACL_FORMAT_NCHW) {
|
||||
MS_LOG(WARNING) << "The output format of om should be nchw, but now is " << format;
|
||||
}
|
||||
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
|
||||
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
|
||||
std::string output_name = aclmdlGetOutputNameByIndex(model_desc_, i);
|
||||
if (output_name.empty()) {
|
||||
MS_LOG(WARNING) << "Get name of output " << i << " failed.";
|
||||
}
|
||||
MS_LOG(INFO) << "Name of input " << i << " is " << output_name;
|
||||
output_infos_.emplace_back(
|
||||
AclTensorInfo{data_mem_buffer, data_mem_buffer, buffer_size, data_type, shape, output_name});
|
||||
}
|
||||
MS_LOG(INFO) << "Create model output success.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyInputsDataset() {
|
||||
if (inputs_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(inputs_); i++) {
|
||||
auto dataBuffer = aclmdlGetDatasetBuffer(inputs_, i);
|
||||
aclDestroyDataBuffer(dataBuffer);
|
||||
}
|
||||
aclmdlDestroyDataset(inputs_);
|
||||
inputs_ = nullptr;
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyInputsDataMem() {
|
||||
if (!is_run_on_device_) {
|
||||
for (const auto &item : input_infos_) {
|
||||
aclrtFree(item.device_data);
|
||||
}
|
||||
}
|
||||
input_infos_.clear();
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyInputsBuffer() {
|
||||
DestroyInputsDataMem();
|
||||
DestroyInputsDataset();
|
||||
}
|
||||
|
||||
void ModelProcess::DestroyOutputsBuffer() {
|
||||
for (const auto &item : output_infos_) {
|
||||
if (!is_run_on_device_) {
|
||||
aclrtFree(item.device_data);
|
||||
} else {
|
||||
aclrtFreeHost(item.device_data);
|
||||
}
|
||||
}
|
||||
output_infos_.clear();
|
||||
|
||||
if (outputs_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(outputs_); i++) {
|
||||
auto dataBuffer = aclmdlGetDatasetBuffer(outputs_, i);
|
||||
aclDestroyDataBuffer(dataBuffer);
|
||||
}
|
||||
aclmdlDestroyDataset(outputs_);
|
||||
outputs_ = nullptr;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::UnLoad() {
|
||||
auto ret = aclmdlUnload(model_id_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Unload model failed, ret = " << ret;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (model_desc_ != nullptr) {
|
||||
ret = aclmdlDestroyDesc(model_desc_);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Unload model failed, ret = " << ret;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
model_desc_ = nullptr;
|
||||
}
|
||||
DestroyInputsBuffer();
|
||||
DestroyOutputsBuffer();
|
||||
MS_LOG(INFO) << "End unload model " << model_id_;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
size_t ModelProcess::GetDynamicDims(const std::vector<AclTensorInfo> &inputs) {
|
||||
size_t max_num = 0;
|
||||
for (auto input : inputs) {
|
||||
size_t cur_num = std::count(input.dims.begin(), input.dims.end(), -1);
|
||||
if (cur_num > max_num) {
|
||||
max_num = cur_num;
|
||||
}
|
||||
}
|
||||
return max_num;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::SetBatchSize(const std::vector<mindspore::MSTensor> &inputs) {
|
||||
size_t index;
|
||||
aclError ret;
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
input_infos_[i].buffer_size = inputs[i].DataSize();
|
||||
}
|
||||
auto *p = reinterpret_cast<const float *>(inputs[inputs.size() - 1].Data().get());
|
||||
if (p == nullptr) {
|
||||
MS_LOG(ERROR) << "Pointer is nullptr.";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto dynamicBatchSize = p[0];
|
||||
ret = aclmdlGetInputIndexByName(model_desc_, ACL_DYNAMIC_TENSOR_NAME, &index);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Get index failed";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ret = aclmdlSetDynamicBatchSize(model_id_, inputs_, index, dynamicBatchSize);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Set dynamic batch size failed, model_id is " << model_id_;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::CheckTensorByTensorInfo(const std::vector<mindspore::MSTensor> &tensor,
|
||||
const std::vector<AclTensorInfo> &tensor_info, size_t dynamic_nums) {
|
||||
if (dynamic_nums == 0) {
|
||||
for (size_t i = 0; i < tensor_info.size(); ++i) {
|
||||
if (tensor[i].Shape() != tensor_info[i].dims) {
|
||||
MS_LOG(ERROR) << "Note: input " << i << " shape not match, required " << ShapeToString(tensor_info[i].dims)
|
||||
<< ", given " << ShapeToString(tensor[i].Shape());
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (tensor[i].DataType() != TransToDataType(tensor_info[i].data_type)) {
|
||||
MS_LOG(ERROR) << "Note: input " << i << " data type not match, required "
|
||||
<< static_cast<int>(TransToDataType(tensor_info[i].data_type)) << ", given "
|
||||
<< static_cast<int>(tensor[i].DataType());
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (tensor[i].DataSize() != tensor_info[i].buffer_size) {
|
||||
MS_LOG(ERROR) << "Input " << i << " data size not match, required size " << tensor_info[i].buffer_size
|
||||
<< ", given count " << tensor[i].DataSize();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::ProcDynamicShape(const std::vector<mindspore::MSTensor> &inputs, size_t dynamic_nums) {
|
||||
if (dynamic_nums == kDynamicBatchSize) {
|
||||
if (SetBatchSize(inputs) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Failed to convert dynamic batch size";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (ResetOutputSize() != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Reset output size failed";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else if (dynamic_nums == kDynamicImageSize) {
|
||||
MS_LOG(ERROR) << "Only dynamic batch size is supported";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::CheckAndInitInput(const std::vector<mindspore::MSTensor> &inputs) {
|
||||
aclError ret;
|
||||
inputs_ = aclmdlCreateDataset();
|
||||
size_t dynamic_nums = GetDynamicDims(input_infos_);
|
||||
// check inputs
|
||||
if (CheckTensorByTensorInfo(inputs, input_infos_, dynamic_nums) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Check input tensor failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// copy inputs
|
||||
for (size_t i = 0; i < input_infos_.size(); ++i) {
|
||||
auto &info = input_infos_[i];
|
||||
auto input = inputs[i];
|
||||
void *data = input.MutableData();
|
||||
void *input_buffer = nullptr;
|
||||
if (!is_run_on_device_) {
|
||||
info.cur_device_data = info.device_data;
|
||||
ret = aclrtMemcpy(info.cur_device_data, info.buffer_size, data, input.DataSize(), ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
input_buffer = info.cur_device_data;
|
||||
} else {
|
||||
input_buffer = data;
|
||||
}
|
||||
auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
|
||||
if (data_buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "Create Data Buffer failed";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Add data buffer failed";
|
||||
aclDestroyDataBuffer(data_buffer);
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (ProcDynamicShape(inputs, dynamic_nums) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Proc input dynamic shape failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::ResetOutputSize() {
|
||||
aclDataType output_type;
|
||||
aclError ret;
|
||||
size_t output_size = aclmdlGetNumOutputs(model_desc_);
|
||||
for (size_t index = 0; index < output_size; index++) {
|
||||
size_t dims = 1;
|
||||
struct aclmdlIODims output_dims;
|
||||
ret = aclmdlGetCurOutputDims(model_desc_, index, &output_dims);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "get output dim error.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < output_dims.dimCount; i++) {
|
||||
dims *= output_dims.dims[i];
|
||||
}
|
||||
output_type = aclmdlGetOutputDataType(model_desc_, index);
|
||||
output_infos_[index].buffer_size = dims * aclDataTypeSize(output_type);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::SortTensorInfoByName(const std::vector<mindspore::MSTensor> &tensor,
|
||||
std::vector<AclTensorInfo> *tensor_info) {
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor info is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (tensor.size() != tensor_info->size()) {
|
||||
MS_LOG(ERROR) << "Actual tensor count not match, required count " << tensor_info->size() << ", given count "
|
||||
<< tensor.size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
size_t size = tensor.size();
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
std::string name = tensor[i].Name();
|
||||
size_t j;
|
||||
for (j = 0; j < size; j++) {
|
||||
if (name.find((*tensor_info)[j].name) != std::string::npos) {
|
||||
std::swap((*tensor_info)[i], (*tensor_info)[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (j == size) {
|
||||
MS_LOG(ERROR) << "Input[" << i << "] " << name << " can't be found in acl om.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::PredictFromHost(const std::vector<mindspore::MSTensor> &inputs,
|
||||
std::vector<mindspore::MSTensor> *outputs) {
|
||||
if (SortTensorInfoByName(inputs, &input_infos_) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Sort input tensor info failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
STATUS ret = CheckAndInitInput(inputs);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Check or init input failed";
|
||||
DestroyInputsDataset();
|
||||
return ret; // forward status error
|
||||
}
|
||||
aclError acl_ret = aclmdlExecute(model_id_, inputs_, outputs_);
|
||||
DestroyInputsDataset();
|
||||
if (acl_ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Execute Model Failed, ret = " << acl_ret;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ret = GetOutputs(outputs);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Build outputs failed";
|
||||
return ret;
|
||||
}
|
||||
MS_LOG(INFO) << "Execute model success";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::GetOutputs(std::vector<mindspore::MSTensor> *outputs) {
|
||||
if (outputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Ms tensor output is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
if (ConstructTensor(outputs) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Construct ms tensor failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS ModelProcess::ConstructTensor(std::vector<mindspore::MSTensor> *outputs) {
|
||||
if (outputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Ms tensor output is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (outputs->size() != output_infos_.size()) {
|
||||
MS_LOG(ERROR) << "Actual tensor count not match, required count " << output_infos_.size() << ", given count "
|
||||
<< outputs->size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<std::string> names;
|
||||
std::vector<std::vector<int64_t>> shapes;
|
||||
std::vector<enum DataType> data_types;
|
||||
std::vector<size_t> mem_sizes;
|
||||
if (ConstructTensorDesc(output_infos_, &names, &shapes, &data_types, &mem_sizes) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Construct tensor desc failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// set output info and malloc data size
|
||||
for (size_t i = 0; i < output_infos_.size(); ++i) {
|
||||
std::string lite_output_name = (*outputs)[i].Name();
|
||||
if (lite_output_name != names[i]) {
|
||||
MS_LOG(INFO) << "Lite output name: " << lite_output_name << "; Om output name: " << names[i];
|
||||
}
|
||||
(*outputs)[i].SetFormat(Format::NCHW);
|
||||
(*outputs)[i].SetDataType(data_types[i]);
|
||||
(*outputs)[i].SetShape(shapes[i]);
|
||||
(*outputs)[i].MutableData();
|
||||
if ((*outputs)[i].DataSize() != mem_sizes[i]) {
|
||||
MS_LOG(ERROR) << "Ms tensor size " << (*outputs)[i].DataSize() << " not match acl tensor size " << mem_sizes[i];
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
|
||||
for (size_t i = 0; i < output_infos_.size(); ++i) {
|
||||
if (output_infos_[i].cur_device_data == nullptr) {
|
||||
// when run on device, cur_device_data is nullptr before first execute
|
||||
continue;
|
||||
}
|
||||
auto ret = aclrtMemcpy((*outputs)[i].MutableData(), (*outputs)[i].DataSize(), output_infos_[i].cur_device_data,
|
||||
output_infos_[i].buffer_size, kind);
|
||||
if (ret != ACL_ERROR_NONE) {
|
||||
MS_LOG(ERROR) << "Memcpy input " << i << " from " << (is_run_on_device_ ? "host" : "device")
|
||||
<< " to host failed, memory size " << output_infos_[i].buffer_size;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_PROCESS_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_PROCESS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "acl/acl.h"
|
||||
#include "acl/acl_mdl.h"
|
||||
#include "acl/acl_rt.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::lite::STATUS;
|
||||
|
||||
namespace mindspore {
|
||||
namespace acl {
|
||||
struct AclTensorInfo {
|
||||
void *cur_device_data;
|
||||
void *device_data;
|
||||
size_t buffer_size;
|
||||
aclDataType data_type;
|
||||
std::vector<int64_t> dims;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
class ModelProcess {
|
||||
public:
|
||||
ModelProcess()
|
||||
: model_id_(0xffffffff),
|
||||
is_run_on_device_(false),
|
||||
model_desc_(nullptr),
|
||||
inputs_(nullptr),
|
||||
outputs_(nullptr),
|
||||
input_infos_(),
|
||||
output_infos_() {}
|
||||
~ModelProcess() {}
|
||||
|
||||
STATUS UnLoad();
|
||||
STATUS PredictFromHost(const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs);
|
||||
STATUS PreInitModelResource();
|
||||
|
||||
// override this method to avoid request/reply data copy
|
||||
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
|
||||
|
||||
void set_model_id(uint32_t model_id) { model_id_ = model_id; }
|
||||
uint32_t model_id() const { return model_id_; }
|
||||
|
||||
private:
|
||||
STATUS CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
|
||||
STATUS CheckAndInitInput(const std::vector<mindspore::MSTensor> &inputs);
|
||||
STATUS SortTensorInfoByName(const std::vector<mindspore::MSTensor> &tensor, std::vector<AclTensorInfo> *tensor_info);
|
||||
STATUS CheckTensorByTensorInfo(const std::vector<mindspore::MSTensor> &tensor,
|
||||
const std::vector<AclTensorInfo> &tensor_info, size_t dynamic_nums);
|
||||
STATUS GetOutputs(std::vector<mindspore::MSTensor> *outputs);
|
||||
STATUS ConstructTensor(std::vector<mindspore::MSTensor> *outputs);
|
||||
STATUS SetBatchSize(const std::vector<mindspore::MSTensor> &inputs);
|
||||
STATUS InitInputsBuffer();
|
||||
STATUS InitOutputsBuffer();
|
||||
STATUS ResetOutputSize();
|
||||
size_t GetDynamicDims(const std::vector<AclTensorInfo> &);
|
||||
STATUS ProcDynamicShape(const std::vector<mindspore::MSTensor> &inputs, size_t dynamic_nums);
|
||||
|
||||
void DestroyInputsDataset();
|
||||
void DestroyInputsDataMem();
|
||||
void DestroyInputsBuffer();
|
||||
void DestroyOutputsBuffer();
|
||||
|
||||
uint32_t model_id_;
|
||||
// if run one device(AICPU), there is no need to alloc device memory and copy inputs to(/outputs from) device
|
||||
bool is_run_on_device_;
|
||||
aclmdlDesc *model_desc_;
|
||||
aclmdlDataset *inputs_;
|
||||
aclmdlDataset *outputs_;
|
||||
std::vector<AclTensorInfo> input_infos_;
|
||||
std::vector<AclTensorInfo> output_infos_;
|
||||
};
|
||||
} // namespace acl
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_PROCESS_H_
|
|
@ -6,7 +6,10 @@ set(CCSRC_SRC
|
|||
${CCSRC_DIR}/backend/optimizer/common/visit.cc
|
||||
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
|
||||
)
|
||||
set(ENABLE_GLIBCXX ON)
|
||||
if(NOT MSLITE_ENABLE_ACL)
|
||||
set(ENABLE_GLIBCXX ON)
|
||||
endif()
|
||||
|
||||
include(${TOP_DIR}/cmake/external_libs/opencv.cmake)
|
||||
include(${TOP_DIR}/cmake/external_libs/glog.cmake)
|
||||
include_directories(${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
|
||||
|
@ -136,6 +139,14 @@ add_subdirectory(registry)
|
|||
add_subdirectory(preprocess)
|
||||
add_subdirectory(${CORE_DIR} mindspore_core)
|
||||
|
||||
if(MSLITE_ENABLE_ACL)
|
||||
set(MODE_ASCEND_ACL ON)
|
||||
include(${TOP_DIR}/cmake/dependency_graphengine.cmake)
|
||||
|
||||
add_subdirectory(acl)
|
||||
link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||
endif()
|
||||
|
||||
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
|
||||
set(API_SRC ${SRC_DIR}/cxx_api/context.cc)
|
||||
set(LITE_SRC
|
||||
|
@ -222,6 +233,12 @@ target_link_libraries(converter_lite PRIVATE
|
|||
preprocess_mid
|
||||
)
|
||||
|
||||
if(MSLITE_ENABLE_ACL)
|
||||
target_link_libraries(converter_lite PRIVATE
|
||||
lite_acl_mid
|
||||
mindspore_shared_lib)
|
||||
endif()
|
||||
|
||||
if(NOT MSVC)
|
||||
target_link_libraries(converter_lite PRIVATE pthread)
|
||||
endif()
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
include_directories(${TOP_DIR}/graphengine/metadef/inc/external)
|
||||
include_directories(${TOP_DIR}/graphengine/inc)
|
||||
include_directories(${TOP_DIR}/graphengine/inc/external)
|
||||
include_directories(${TOP_DIR}/graphengine/ge)
|
||||
include_directories(${TOP_DIR}/graphengine/metadef/inc)
|
||||
include_directories(${TOP_DIR}/graphengine/inc/framework)
|
||||
include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc)
|
||||
include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
file(GLOB ACL_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/deparser/*.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/infer/*.cc
|
||||
)
|
||||
|
||||
add_subdirectory(${TOP_DIR}/mindspore/ccsrc/transform/graph_ir _mindspore_transform_graph_ir_obj)
|
||||
add_subdirectory(${TOP_DIR}/mindspore/ccsrc/cxx_api mindspore_shared_lib)
|
||||
|
||||
set_property(SOURCE ${ACL_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||
add_library(lite_acl_mid OBJECT ${ACL_SRC})
|
||||
|
||||
add_dependencies(lite_acl_mid fbs_inner_src)
|
|
@ -0,0 +1,405 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/acl_pass.h"
|
||||
#include <set>
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "tools/converter/acl/deparser/spatial_node_adapter.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include "include/registry/pass_registry.h"
|
||||
#include "common/utils.h"
|
||||
#include "ops/custom.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "cxx_api/model/acl/model_converter.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kMakeTuple = "MakeTuple";
|
||||
constexpr auto kOutputNames = "outputs_names";
|
||||
constexpr auto kCustomPrimTypeACL = "ACL";
|
||||
constexpr auto kCustomNodeName = "Custom";
|
||||
} // namespace
|
||||
|
||||
ParameterPtr AclPass::CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om_data) {
|
||||
ParameterPtr om_parameter = func_graph->add_parameter();
|
||||
om_parameter->set_name("ACL_om_data");
|
||||
|
||||
auto type_ptr = TypeIdToType(kNumberTypeUInt8);
|
||||
ShapeVector shape_vector = {static_cast<int64_t>(om_data.DataSize())};
|
||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
||||
om_parameter->set_abstract(abstract_tensor);
|
||||
|
||||
auto param_value =
|
||||
std::make_shared<tensor::Tensor>(kNumberTypeUInt8, ShapeVector({static_cast<int64_t>(om_data.DataSize())}));
|
||||
auto tensor_data = param_value->data_c();
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << "New Tensor failed.";
|
||||
return nullptr;
|
||||
}
|
||||
if (param_value->Size() < om_data.DataSize()) {
|
||||
MS_LOG(ERROR) << "Dst buff size " << param_value->Size() << " should be greater than src buff size "
|
||||
<< om_data.DataSize();
|
||||
return nullptr;
|
||||
}
|
||||
if (memcpy_s(tensor_data, param_value->Size(), om_data.Data(), om_data.DataSize()) != EOK) {
|
||||
MS_LOG(ERROR) << "Memcpy om data failed.";
|
||||
return nullptr;
|
||||
}
|
||||
om_parameter->set_default_param(param_value);
|
||||
return om_parameter;
|
||||
}
|
||||
|
||||
// now build the whole graph, not split
|
||||
STATUS AclPass::BuildGraph(const FuncGraphPtr &func_graph) {
|
||||
Buffer om_data;
|
||||
if (ConvertGraphToOm(func_graph, &om_data) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert graph to om failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
om_parameter_ = CreateOmParameter(func_graph, om_data);
|
||||
if (om_parameter_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert graph to om failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPass::RunPrimitiveDeparser(const FuncGraphPtr &func_graph) {
|
||||
MS_LOG(INFO) << "Deparser graph start.";
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
lite::GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
for (auto graph : all_func_graphs) {
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto prim = GetCNodePrimitive(cnode);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Prim is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto name = prim->name();
|
||||
auto deparser = lite::PrimitiveDeparserRegister::GetInstance().GetPrimitiveDeparser(name);
|
||||
if (deparser == nullptr) {
|
||||
MS_LOG(DEBUG) << "Name: " << name << " not need to deparser.";
|
||||
continue;
|
||||
}
|
||||
MS_LOG(INFO) << "Deparser cnode: " << name;
|
||||
auto status = deparser->Deparser(cnode);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Deparser primitive failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPass::DeparseGraph(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
|
||||
if (fmk_type_ == converter::kFmkTypeMs) {
|
||||
MS_LOG(INFO) << "MindIr no need to deparser graph";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (RunPrimitiveDeparser(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Run deparser primitive failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
if (lite::AdapteSpatialNode(func_graph, manager) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adapter spatial node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPass::PreProcGraph(const FuncGraphPtr &func_graph) {
|
||||
if (fmk_type_ == converter::kFmkTypeMs) {
|
||||
MS_LOG(INFO) << "MindIr no need to pre proc graph";
|
||||
return lite::RET_OK;
|
||||
}
|
||||
// The format of nodes (cnode, parameter, val) must be nchw due to interface of convert om
|
||||
if (!lite::RunOptimizerPass(func_graph, {"ToNCHWFormat", "DeleteRedundantTranspose"})) {
|
||||
MS_LOG(ERROR) << "To nchw format success.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPass::PostProcGraph(const FuncGraphPtr &func_graph) {
|
||||
// The format must be nhwc due to ms model
|
||||
if (!lite::RunOptimizerPass(func_graph, {"ToNHWCFormat"})) {
|
||||
MS_LOG(ERROR) << "To NHWC Format failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
bool AclPass::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_LOG(INFO) << "Acl pass run start.";
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Func_graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto manager = Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "Manager is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (PreProcGraph(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Pre proc graph failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DeparseGraph(func_graph, manager) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Deparse graph failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (BuildGraph(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Build graph failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
custom_node_ = CreateCustomNode(func_graph);
|
||||
if (custom_node_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create custom node failed.";
|
||||
return false;
|
||||
}
|
||||
// prepare graph for export create
|
||||
if (ModifyGraphByCustomNode(func_graph, manager, custom_node_) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Modify func graph by custom failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (PostProcGraph(func_graph) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Post proc graph failed.";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Acl pass run end.";
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS AclPass::ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_data) {
|
||||
if (om_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Om data is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
SetAclModelOptions(func_graph);
|
||||
// call interface of cloud
|
||||
ModelConverter model_converter;
|
||||
model_converter.set_options(options_.get());
|
||||
*om_data = model_converter.LoadMindIR(func_graph);
|
||||
if (om_data->Data() == nullptr || om_data->DataSize() == 0) {
|
||||
MS_LOG(ERROR) << "Model converter load mindir failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void AclPass::SetAclModelOptions(const FuncGraphPtr &func_graph) {
|
||||
MS_LOG(INFO) << "Set acl model options start.";
|
||||
auto model_context = std::make_shared<mindspore::Context>();
|
||||
auto ascend310_info = std::make_shared<Ascend310DeviceInfo>();
|
||||
ascend310_info->SetDeviceID(0);
|
||||
model_context->MutableDeviceInfo().emplace_back(ascend310_info);
|
||||
// set options
|
||||
options_ = std::make_unique<AclModelOptions>(model_context);
|
||||
if (options_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Acl option make shared failed.";
|
||||
return;
|
||||
}
|
||||
auto inputs = func_graph->get_inputs();
|
||||
std::vector<std::string> input_names;
|
||||
for (auto node : inputs) {
|
||||
if (node == nullptr) {
|
||||
MS_LOG(ERROR) << "Node is nullptr.";
|
||||
return;
|
||||
}
|
||||
auto para = node->cast<ParameterPtr>();
|
||||
if (para == nullptr) {
|
||||
MS_LOG(ERROR) << "Parameter is nullptr.";
|
||||
return;
|
||||
}
|
||||
std::string name = para->name();
|
||||
for (auto pos = name.find(':'); pos != std::string::npos; pos = name.find(':')) {
|
||||
name = name.substr(0, pos) + "_" + name.substr(pos + 1);
|
||||
MS_LOG(INFO) << name;
|
||||
}
|
||||
para->set_name(name);
|
||||
input_names.push_back(name);
|
||||
}
|
||||
options_->RenameInput(input_names);
|
||||
MS_LOG(INFO) << "Set acl model options end.";
|
||||
}
|
||||
|
||||
STATUS AclPass::GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph, AnfNodePtrList *graph_outputs,
|
||||
std::vector<std::string> *graph_output_names,
|
||||
std::vector<std::vector<int64_t>> *graph_output_dims) {
|
||||
CHECK_NULL_RETURN(graph_outputs);
|
||||
CHECK_NULL_RETURN(graph_output_names);
|
||||
CHECK_NULL_RETURN(graph_output_dims);
|
||||
AnfNodePtr return_input = func_graph->output();
|
||||
CHECK_NULL_RETURN(return_input);
|
||||
auto input_cnode = return_input->cast<CNodePtr>();
|
||||
CHECK_NULL_RETURN(input_cnode);
|
||||
auto primitive = mindspore::GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "Primitive is nullptr, node: " << input_cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// not consider custom op
|
||||
std::string primitive_type = primitive->name();
|
||||
if (primitive_type == kMakeTuple) {
|
||||
for (size_t j = 1; j < input_cnode->inputs().size(); j++) {
|
||||
auto item = input_cnode->input(j);
|
||||
MS_ASSERT(item != nullptr);
|
||||
graph_outputs->emplace_back(item);
|
||||
graph_output_names->emplace_back(item->fullname_with_scope());
|
||||
auto item_cnode = item->cast<CNodePtr>();
|
||||
if (item_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Input of MakeTuple is not a cnode for input_id: " << j;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<int64_t> dims;
|
||||
if (lite::acl::GetShapeVectorFromCNode(item_cnode, &dims) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get node shape failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
graph_output_dims->emplace_back(dims);
|
||||
}
|
||||
} else {
|
||||
graph_outputs->emplace_back(input_cnode);
|
||||
graph_output_names->emplace_back(input_cnode->fullname_with_scope());
|
||||
std::vector<int64_t> dims;
|
||||
if (lite::acl::GetShapeVectorFromCNode(input_cnode, &dims) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get node shape failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
graph_output_dims->emplace_back(dims);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPass::SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type) {
|
||||
AbstractBasePtrList abstract_list;
|
||||
for (size_t j = 0; j < graph_outputs_.size(); j++) {
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_outputs_dims_[j], data_type);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract tensor is nullptr for output " << j;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
abstract_list.emplace_back(abstract_tensor);
|
||||
}
|
||||
new_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AclPass::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
|
||||
STATUS ret = GetFuncGraphOutputInfo(func_graph, &graph_outputs_, &graph_output_names_, &graph_outputs_dims_);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get output info of graph failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (graph_outputs_.empty() || graph_outputs_.size() != graph_outputs_dims_.size()) {
|
||||
MS_LOG(ERROR) << "Graph output size is error, num size: " << graph_outputs_.size()
|
||||
<< " dim size: " << graph_outputs_dims_.size();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
custom_node->AddAttr(kOutputNames, MakeValue(graph_output_names_));
|
||||
|
||||
TypeId type = lite::acl::GetTypeFromNode(graph_outputs_[0]);
|
||||
if (graph_outputs_.size() == 1) {
|
||||
auto abstract_tensor = lite::CreateTensorAbstract(graph_outputs_dims_[0], type);
|
||||
if (abstract_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract_tensor is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
custom_node->set_abstract(abstract_tensor);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (SetMultiOutputs(custom_node, type) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Set multi graph output failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
CNodePtr AclPass::CreateCustomNode(const FuncGraphPtr &func_graph) {
|
||||
auto prim = std::make_unique<mindspore::ops::Custom>();
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "New custom op failed.";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_type(kCustomPrimTypeACL);
|
||||
auto graph_input = func_graph->get_inputs();
|
||||
CNodePtr custom_node = func_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(prim.release()), graph_input);
|
||||
if (custom_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Custom cnode failed.";
|
||||
return nullptr;
|
||||
}
|
||||
custom_node->set_fullname_with_scope(kCustomNodeName);
|
||||
custom_node->add_input(om_parameter_);
|
||||
|
||||
if (SetCustomOutputs(func_graph, custom_node) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Set custom outputs failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return custom_node;
|
||||
}
|
||||
|
||||
STATUS AclPass::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
|
||||
const CNodePtr &custom_node) {
|
||||
if (graph_outputs_.size() == 1) {
|
||||
if (!manager->Replace(graph_outputs_[0], custom_node)) {
|
||||
MS_LOG(ERROR) << "Replace node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
for (size_t j = 0; j < graph_outputs_.size(); ++j) {
|
||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
||||
if (tuple_get_item_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "New TupleGetItem failed for output " << j;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||
auto get_item_value = NewValueNode(MakeValue<int>(j));
|
||||
AnfNodePtrList inputs{tuple_get_item_prim, custom_node, get_item_value};
|
||||
CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
|
||||
if (get_item_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "New get item cnode failed for output " << j;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
get_item_cnode->set_fullname_with_scope(custom_node->fullname_with_scope() + "_getitem_" + std::to_string(j));
|
||||
if (!manager->Replace(graph_outputs_[j], get_item_cnode)) {
|
||||
MS_LOG(ERROR) << "Replace node failed for output " << j;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ACL_ACL_PASS_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_ACL_ACL_PASS_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "cxx_api/model/acl/acl_model_options.h"
|
||||
|
||||
using mindspore::converter::FmkType;
|
||||
using mindspore::lite::STATUS;
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class AclPass : public Pass {
|
||||
public:
|
||||
explicit AclPass(FmkType fmk_type) : Pass("Acl"), fmk_type_(fmk_type) {}
|
||||
~AclPass() override = default;
|
||||
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
STATUS PreProcGraph(const FuncGraphPtr &func_graph);
|
||||
STATUS PostProcGraph(const FuncGraphPtr &func_graph);
|
||||
STATUS DeparseGraph(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager);
|
||||
STATUS RunPrimitiveDeparser(const FuncGraphPtr &func_graph);
|
||||
STATUS BuildGraph(const FuncGraphPtr &func_graph);
|
||||
STATUS ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_data);
|
||||
ParameterPtr CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om);
|
||||
CNodePtr CreateCustomNode(const FuncGraphPtr &func_graph);
|
||||
STATUS SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node);
|
||||
STATUS SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type);
|
||||
STATUS ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
|
||||
const CNodePtr &custom_node);
|
||||
void SetAclModelOptions(const FuncGraphPtr &func_graph);
|
||||
STATUS GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph, AnfNodePtrList *graph_outputs,
|
||||
std::vector<std::string> *graph_output_names,
|
||||
std::vector<std::vector<int64_t>> *graph_output_dims);
|
||||
|
||||
FmkType fmk_type_;
|
||||
ParameterPtr om_parameter_ = nullptr;
|
||||
CNodePtr custom_node_ = nullptr;
|
||||
std::unique_ptr<AclModelOptions> options_;
|
||||
AnfNodePtrList graph_outputs_;
|
||||
std::vector<std::string> graph_output_names_;
|
||||
std::vector<std::vector<int64_t>> graph_outputs_dims_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ACL_ACL_PASS_H
|
|
@ -0,0 +1,125 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/common/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "base/base_ref.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace acl {
|
||||
namespace {
|
||||
constexpr size_t kTupleGetItemInputSize = 3;
|
||||
constexpr size_t kSecondIndex = 1;
|
||||
constexpr size_t kInvalidSize = SIZE_MAX;
|
||||
} // namespace
|
||||
|
||||
static size_t GetTupleGetItemOutIndex(const mindspore::CNodePtr &tuple_get_item) {
|
||||
MS_ASSERT(tuple_get_item != nullptr);
|
||||
if (tuple_get_item->size() != mindspore::kTupleGetItemInputSize) {
|
||||
MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
|
||||
return kInvalidSize;
|
||||
}
|
||||
auto output_index_value_node = tuple_get_item->input(mindspore::kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_ASSERT(output_index_value_node != nullptr);
|
||||
auto value_node = output_index_value_node->cast<mindspore::ValueNodePtr>();
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
return IntToSize(opt::CastToInt(value_node->value()).front());
|
||||
}
|
||||
|
||||
static bool CheckPrimitiveType(const mindspore::AnfNodePtr &node, const mindspore::PrimitivePtr &primitive_type) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (node->isa<mindspore::CNode>()) {
|
||||
auto cnode = node->cast<mindspore::CNodePtr>();
|
||||
return IsPrimitive(cnode->input(0), primitive_type);
|
||||
} else if (node->isa<mindspore::ValueNode>()) {
|
||||
return IsPrimitive(node, primitive_type);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int64_t> *shape_vector) {
|
||||
mindspore::AbstractBasePtr cnode_abstract;
|
||||
if (CheckPrimitiveType(cnode, mindspore::prim::kPrimTupleGetItem)) {
|
||||
auto tuple_inputs = cnode->inputs();
|
||||
MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
|
||||
auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
|
||||
MS_ASSERT(get_item_input_cnode != nullptr);
|
||||
auto idx = GetTupleGetItemOutIndex(cnode);
|
||||
if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
|
||||
MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto abstract_tuple =
|
||||
mindspore::utils::cast<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
|
||||
auto abstract_list = abstract_tuple->elements();
|
||||
if (abstract_list.size() <= idx) {
|
||||
MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
cnode_abstract = abstract_list[idx];
|
||||
} else {
|
||||
cnode_abstract = cnode->abstract();
|
||||
}
|
||||
if (cnode_abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract cnode is nullptr. " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (!mindspore::utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
|
||||
MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
|
||||
if (!mindspore::utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
|
||||
MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto shape_ptr = mindspore::utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
|
||||
|
||||
if (shape_ptr->shape().empty()) {
|
||||
MS_LOG(WARNING) << "Shape is empty " << cnode->fullname_with_scope();
|
||||
}
|
||||
|
||||
*shape_vector = shape_ptr->shape();
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
TypeId GetTypeFromNode(const AnfNodePtr &node) {
|
||||
TypeId type = kNumberTypeFloat32;
|
||||
if (utils::isa<CNodePtr>(node)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
|
||||
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
|
||||
MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
|
||||
return type;
|
||||
}
|
||||
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
||||
type = type_ptr->type_id();
|
||||
}
|
||||
MS_LOG(INFO) << "node type id is " << type;
|
||||
}
|
||||
return type;
|
||||
}
|
||||
} // namespace acl
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ACL_COMMON_UTILS_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_ACL_COMMON_UTILS_H
|
||||
|
||||
#include <vector>
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace acl {
|
||||
STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int64_t> *shape_vector);
|
||||
|
||||
TypeId GetTypeFromNode(const AnfNodePtr &node);
|
||||
} // namespace acl
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ACL_ACL_PASS_H
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/activation_deparser.h"
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "ops/elu.h"
|
||||
#include "ops/gelu.h"
|
||||
#include "ops/leaky_relu.h"
|
||||
#include "ops/relu.h"
|
||||
#include "ops/relu6.h"
|
||||
#include "ops/sigmoid.h"
|
||||
#include "ops/tanh.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS ActivationDeparser::Deparser(const CNodePtr &cnode) {
|
||||
static std::map<ActivationType, PrimitivePtr> activation_type_map = {
|
||||
{mindspore::ELU, std::make_shared<ops::Elu>()},
|
||||
{mindspore::GELU, std::make_shared<ops::GeLU>()},
|
||||
{mindspore::RELU, std::make_shared<ops::ReLU>()},
|
||||
{mindspore::RELU6, std::make_shared<ops::ReLU6>()},
|
||||
{mindspore::SIGMOID, std::make_shared<ops::Sigmoid>()},
|
||||
{mindspore::TANH, std::make_shared<ops::Tanh>()},
|
||||
{mindspore::LEAKY_RELU, std::make_shared<ops::LeakyRelu>()}};
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto activate_prim = dynamic_cast<ops::Activation *>(src_prim.get());
|
||||
if (activate_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Dynamic cast activation failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
PrimitivePtr dst_prim = nullptr;
|
||||
ActivationType type = activate_prim->get_activation_type();
|
||||
if (activation_type_map.find(type) != activation_type_map.end()) {
|
||||
dst_prim = activation_type_map[type];
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Type " << static_cast<int>(type) << " is unsupported.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_ASSERT(dst_prim != nullptr);
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameActivation, ActivationDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_ACTIVATION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_ACTIVATION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/activation.h"
|
||||
|
||||
using mindspore::ops::kNameActivation;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ActivationDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
ActivationDeparser() : PrimitiveDeparser(kNameActivation) {}
|
||||
|
||||
~ActivationDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_ACTIVATION_DEPARSER_H
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/add_fusion_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS AddFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
auto dst_prim = std::make_shared<ops::Add>();
|
||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AddFusion deparser failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameAddFusion, AddFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_ADDFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_ADDFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/add_fusion.h"
|
||||
|
||||
using mindspore::ops::kNameAddFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class AddFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
AddFusionDeparser() : PrimitiveDeparser(kNameAddFusion) {}
|
||||
|
||||
~AddFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_ADDFUSION_DEPARSER_H
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/avgpool_fusion_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "tools/converter/acl/deparser/tbe_op_def.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS AvgPoolFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
auto attr_val = src_prim->GetAttr(ops::kFmkType);
|
||||
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
|
||||
PrimitivePtr dst_prim = nullptr;
|
||||
if (fmk_type == converter::kFmkTypeCaffe) {
|
||||
dst_prim = std::make_shared<acl::Pooling>();
|
||||
} else if (fmk_type == converter::kFmkTypeOnnx) {
|
||||
ValuePtr val_ptr = src_prim->GetAttr(ops::kKernelSize);
|
||||
if (val_ptr == nullptr) {
|
||||
dst_prim = std::make_shared<acl::GlobalAveragePool>();
|
||||
} else {
|
||||
dst_prim = std::make_shared<acl::AvgPoolV2>();
|
||||
}
|
||||
} else {
|
||||
dst_prim = std::make_shared<ops::AvgPool>();
|
||||
}
|
||||
if (dst_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Get primitive by fmk type failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
if (AdjustPoolAttr(fmk_type, kNameAvgPoolFusion, dst_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adjust pool attr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameAvgPoolFusion, AvgPoolFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_AVGPOOLFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_AVGPOOLFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/avg_pool_fusion.h"
|
||||
|
||||
using mindspore::ops::kNameAvgPoolFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class AvgPoolFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
AvgPoolFusionDeparser() : PrimitiveDeparser(kNameAvgPoolFusion) {}
|
||||
~AvgPoolFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_AVGPOOLFUSION_DEPARSER_H
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/batchnorm_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "tools/converter/acl/deparser/tbe_op_def.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS BatchNormDeparser::Deparser(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
auto attr_val = src_prim->GetAttr(ops::kFmkType);
|
||||
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
|
||||
if (fmk_type == converter::kFmkTypeCaffe) {
|
||||
auto dst_prim = std::make_shared<acl::BNInference>();
|
||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||
MS_LOG(ERROR) << "BatchNorm deparser failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameBatchNorm, BatchNormDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_BATCHNORM_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_BATCHNORM_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/batch_norm.h"
|
||||
|
||||
using mindspore::ops::kNameBatchNorm;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class BatchNormDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
BatchNormDeparser() : PrimitiveDeparser(kNameBatchNorm) {}
|
||||
|
||||
~BatchNormDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_BATCHNORM_DEPARSER_H
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/cast_deparser.h"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr size_t kNameCastInputNum = 3;
|
||||
} // namespace
|
||||
|
||||
STATUS CastDeparser::Deparser(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Cnode is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (cnode->size() != kNameCastInputNum) {
|
||||
MS_LOG(ERROR) << "Input size of gather must be three.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// convert last parameter to const value node
|
||||
auto to_input = cnode->input(kNameCastInputNum - 1);
|
||||
if (!utils::isa<ParameterPtr>(to_input)) {
|
||||
MS_LOG(ERROR) << "The to node is not parameter.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ParameterPtr to_param = to_input->cast<ParameterPtr>();
|
||||
auto data = opt::GetIntParameterData(to_param);
|
||||
int dst_type = data.empty() ? kNumberTypeInt32 : data.front();
|
||||
TypePtr type_ptr = TypeIdToType(TypeId(dst_type));
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "New type ptr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ValueNodePtr value_node = NewValueNode(type_ptr);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "New value node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
cnode->set_input(kNameCastInputNum - 1, value_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameCast, CastDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_CAST_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_CAST_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/cast.h"
|
||||
|
||||
using mindspore::ops::kNameCast;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class CastDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
CastDeparser() : PrimitiveDeparser(kNameCast) {}
|
||||
|
||||
~CastDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_CAST_DEPARSER_H
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/concat_deparser.h"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr auto kNameInputNums = "inputNums";
|
||||
}
|
||||
|
||||
STATUS ConcatDeparser::Deparser(const CNodePtr &cnode) {
|
||||
if (AddAttrForDynInputPrimitive(cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Concat deparser failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS ConcatDeparser::AddAttrForDynInputPrimitive(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Value node is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// add attr input num for dynamic input op
|
||||
int64_t num = static_cast<int64_t>(cnode->size());
|
||||
if (num > 1) {
|
||||
prim->AddAttr(kNameInputNums, MakeValue(num - 1));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameConcat, ConcatDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_CONCAT_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_CONCAT_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/concat.h"
|
||||
|
||||
using mindspore::ops::kNameConcat;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ConcatDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
ConcatDeparser() : PrimitiveDeparser(kNameConcat) {}
|
||||
|
||||
~ConcatDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
|
||||
private:
|
||||
STATUS AddAttrForDynInputPrimitive(const CNodePtr &cnode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_CONCAT_DEPARSER_H
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/conv2d_fusion_deparser.h"
|
||||
#include "memory"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS Conv2DFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto dst_prim = std::make_shared<ops::Conv2D>();
|
||||
MS_ASSERT(dst_prim != nullptr);
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
|
||||
auto status = AttrAdjust(dst_prim, ops::kStride);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "adjust stride failed.";
|
||||
return status;
|
||||
}
|
||||
status = AttrAdjust(dst_prim, ops::kDilation);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "adjust dilation failed.";
|
||||
return status;
|
||||
}
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameConv2DFusion, Conv2DFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_CONV2DFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_CONV2DFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/conv2d_fusion.h"
|
||||
|
||||
using mindspore::ops::kNameConv2DFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Conv2DFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
Conv2DFusionDeparser() : PrimitiveDeparser(kNameConv2DFusion) {}
|
||||
~Conv2DFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_CONV2DFUSION_DEPARSER_H
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/conv2d_transpose_fusion_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "tools/converter/acl/deparser/tbe_op_def.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS Conv2dTransposeFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
auto attr_val = src_prim->GetAttr(ops::kFmkType);
|
||||
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
|
||||
PrimitivePtr dst_prim = nullptr;
|
||||
if (fmk_type == converter::kFmkTypeCaffe) {
|
||||
dst_prim = std::make_shared<acl::Deconvolution>();
|
||||
} else {
|
||||
dst_prim = std::make_shared<ops::Conv2DTranspose>();
|
||||
}
|
||||
MS_ASSERT(dst_prim != nullptr);
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
auto status = AttrAdjust(dst_prim, ops::kDilation);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "adjust failed.";
|
||||
return status;
|
||||
}
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameConv2dTransposeFusion, Conv2dTransposeFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_CONV2DTRANSPOSEFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_CONV2DTRANSPOSEFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/conv2d_transpose_fusion.h"
|
||||
|
||||
using mindspore::ops::kNameConv2dTransposeFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Conv2dTransposeFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
Conv2dTransposeFusionDeparser() : PrimitiveDeparser(kNameConv2dTransposeFusion) {}
|
||||
~Conv2dTransposeFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_CONV2DTRANSPOSEFUSION_DEPARSER_H
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/eltwise_deparser.h"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS EltWiseDeparser::Deparser(const CNodePtr &cnode) {
|
||||
if (AddAttrForDynInputPrimitive(cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "EltWise deparser failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS EltWiseDeparser::AddAttrForDynInputPrimitive(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Value node is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// add attr input num for dynamic input op
|
||||
int64_t num = static_cast<int64_t>(cnode->size());
|
||||
if (num > 1) {
|
||||
prim->AddAttr(ops::kN, MakeValue(num - 1));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameEltwise, EltWiseDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_ELTWISE_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_ELTWISE_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/eltwise.h"
|
||||
|
||||
using mindspore::ops::kNameEltwise;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class EltWiseDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
EltWiseDeparser() : PrimitiveDeparser(kNameEltwise) {}
|
||||
|
||||
~EltWiseDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
|
||||
private:
|
||||
STATUS AddAttrForDynInputPrimitive(const CNodePtr &cnode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_ELTWISE_DEPARSER_H
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/fused_batchnorm_deparser.h"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS FusedBatchNormDeparser::Deparser(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
src_prim->AddAttr(ops::kIsTraining, MakeValue(false));
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameFusedBatchNorm, FusedBatchNormDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_FUSEDBATCHNORM_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_FUSEDBATCHNORM_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fused_batch_norm.h"
|
||||
|
||||
using mindspore::ops::kNameFusedBatchNorm;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class FusedBatchNormDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
FusedBatchNormDeparser() : PrimitiveDeparser(kNameFusedBatchNorm) {}
|
||||
~FusedBatchNormDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_FUSEDBATCHNORM_DEPARSER_H
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/gather_fusion_deparser.h"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr size_t kNameGatherInputNum = 4;
|
||||
}
|
||||
|
||||
STATUS GatherDeparser::Deparser(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Cnode is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (cnode->size() != kNameGatherInputNum) {
|
||||
MS_LOG(ERROR) << "Input size of gather must be four.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// convert last parameter to const value node
|
||||
auto axis_input = cnode->input(kNameGatherInputNum - 1);
|
||||
if (!utils::isa<ParameterPtr>(axis_input)) {
|
||||
MS_LOG(ERROR) << "The axis node is not parameter.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
ParameterPtr axis_param = axis_input->cast<ParameterPtr>();
|
||||
auto data = opt::GetIntParameterData(axis_param);
|
||||
int64_t axis = data.empty() ? 0 : static_cast<int64_t>(data.front());
|
||||
ValueNodePtr value_node = NewValueNode<int64_t>(axis);
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "New value node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
cnode->set_input(kNameGatherInputNum - 1, value_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameGather, GatherDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_GATHER_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_GATHER_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/gather.h"
|
||||
|
||||
using mindspore::ops::kNameGather;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class GatherDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
GatherDeparser() : PrimitiveDeparser(kNameGather) {}
|
||||
~GatherDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_GATHER_DEPARSER_H
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/maxpool_fusion_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "tools/converter/acl/deparser/tbe_op_def.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS MaxPoolFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
auto attr_val = src_prim->GetAttr(ops::kFmkType);
|
||||
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
|
||||
PrimitivePtr dst_prim = nullptr;
|
||||
if (fmk_type == converter::kFmkTypeCaffe) {
|
||||
dst_prim = std::make_shared<acl::Pooling>();
|
||||
} else if (fmk_type == converter::kFmkTypeOnnx) {
|
||||
dst_prim = std::make_shared<acl::MaxPoolV3>();
|
||||
} else {
|
||||
dst_prim = std::make_shared<ops::MaxPool>();
|
||||
}
|
||||
if (dst_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Get primitive by fmk type failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
if (AdjustPoolAttr(fmk_type, kNameMaxPoolFusion, dst_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adjust pool attr failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameMaxPoolFusion, MaxPoolFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_MAXPOOLFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_MAXPOOLFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/max_pool_fusion.h"
|
||||
|
||||
using mindspore::ops::kNameMaxPoolFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class MaxPoolFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
MaxPoolFusionDeparser() : PrimitiveDeparser(kNameMaxPoolFusion) {}
|
||||
~MaxPoolFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_MAXPOOLFUSION_DEPARSER_H
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/mul_fusion_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS MulFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
auto dst_prim = std::make_shared<ops::Mul>();
|
||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||
MS_LOG(ERROR) << "MulFusion deparser failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameMulFusion, MulFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_MULFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_MULFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/mul_fusion.h"
|
||||
|
||||
using mindspore::ops::kNameMulFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class MulFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
MulFusionDeparser() : PrimitiveDeparser(kNameMulFusion) {}
|
||||
|
||||
~MulFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_MULFUSION_DEPARSER_H
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/pad_fusion_deparser.h"
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "tools/converter/acl/deparser/tbe_op_def.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr size_t kNamePadInputNum = 3;
|
||||
constexpr auto kNamePadContiguous = "pad_contiguous";
|
||||
} // namespace
|
||||
|
||||
STATUS PadFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto dst_prim = std::make_shared<acl::PadV3>();
|
||||
MS_ASSERT(dst_prim != nullptr);
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
AdjustPadAttr(dst_prim);
|
||||
|
||||
if (cnode->size() != kNamePadInputNum) {
|
||||
MS_LOG(INFO) << "No need to add attr to input, input num: " << cnode->size();
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
auto func_graph = cnode->func_graph();
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Func graph is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
int status = AddAttrToInput(func_graph, cnode, dst_prim, ops::kConstantValue, 2);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Add constant value to input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void PadFusionDeparser::AdjustPadAttr(const PrimitivePtr &dst_prim) {
|
||||
static std::map<int64_t, std::string> kPadModeToStrMap = {
|
||||
{PaddingMode::CONSTANT, "constant"},
|
||||
{PaddingMode::REFLECT, "reflect"},
|
||||
{PaddingMode::SYMMETRIC, "edge"},
|
||||
};
|
||||
auto pad_mode_value = dst_prim->GetAttr(ops::kPaddingMode);
|
||||
if (pad_mode_value != nullptr) {
|
||||
auto pad_mode = GetValue<int64_t>(pad_mode_value);
|
||||
if (kPadModeToStrMap.find(pad_mode) != kPadModeToStrMap.end()) {
|
||||
dst_prim->AddAttr(ops::kMode, MakeValue(kPadModeToStrMap[pad_mode]));
|
||||
dst_prim->DelAttr(ops::kPaddingMode);
|
||||
}
|
||||
}
|
||||
|
||||
dst_prim->AddAttr(kNamePadContiguous, MakeValue(true));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNamePadFusion, PadFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_PADFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_PADFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/pad_fusion.h"
|
||||
|
||||
using mindspore::ops::kNamePadFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PadFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
PadFusionDeparser() : PrimitiveDeparser(kNamePadFusion) {}
|
||||
~PadFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
|
||||
private:
|
||||
void AdjustPadAttr(const PrimitivePtr &dst_prim);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_PADFUSION_DEPARSER_H
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/prelu_fusion_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS PReluFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
auto dst_prim = std::make_shared<ops::PReLU>();
|
||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||
MS_LOG(ERROR) << "PReluFusion deparser failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNamePReLUFusion, PReluFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_PRELUFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_PRELUFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/prelu_fusion.h"
|
||||
|
||||
using mindspore::ops::kNamePReLUFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PReluFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
PReluFusionDeparser() : PrimitiveDeparser(kNamePReLUFusion) {}
|
||||
~PReluFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_PRELUFUSION_DEPARSER_H
|
|
@ -0,0 +1,198 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/acl/common/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/fusion/avg_pool_fusion.h"
|
||||
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr auto kCommonAttrValueNum = 2;
|
||||
constexpr auto kNamePaddingMode = "padding_mode";
|
||||
constexpr auto kNameCeilMode = "ceil_mode";
|
||||
} // namespace
|
||||
|
||||
STATUS PrimitiveDeparser::Deparser(const CNodePtr &cnode) { return lite::RET_OK; }
|
||||
|
||||
STATUS PrimitiveDeparser::GetValueNodeAndPrimFromCnode(const CNodePtr &cnode, ValueNodePtr *value_node,
|
||||
PrimitivePtr *prim_ptr) {
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
CHECK_NULL_RETURN(value_node);
|
||||
CHECK_NULL_RETURN(prim_ptr);
|
||||
|
||||
*value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
if (*value_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Value node[" << cnode->fullname_with_scope() << "] is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
*prim_ptr = GetValueNode<PrimitivePtr>(*value_node);
|
||||
if (*prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "Value node[" << cnode->fullname_with_scope() << "] cast to primitive failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS PrimitiveDeparser::AttrAdjust(const PrimitivePtr &prim, const std::string &name) {
|
||||
auto value_ptr = prim->GetAttr(name);
|
||||
if (value_ptr == nullptr) {
|
||||
MS_LOG(WARNING) << prim->name() << " has no attr " << name;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
if (utils::isa<ValueSequeuePtr>(value_ptr)) {
|
||||
auto val_seq_ptr = value_ptr->cast<ValueSequeuePtr>();
|
||||
CHECK_NULL_RETURN(val_seq_ptr);
|
||||
ValuePtr first_val = nullptr;
|
||||
if (!val_seq_ptr->value().empty()) {
|
||||
first_val = val_seq_ptr->value().front();
|
||||
}
|
||||
CHECK_NULL_RETURN(first_val);
|
||||
CHECK_NULL_RETURN(first_val->type());
|
||||
if (first_val->type()->number_type() != kNumberTypeInt64) {
|
||||
MS_LOG(ERROR) << "Value number type of name: " << prim->name() << " ,please check the attr name: " << name;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
CHECK_NULL_RETURN(value_ptr->type());
|
||||
if (value_ptr->type()->number_type() != kNumberTypeInt64) {
|
||||
MS_LOG(ERROR) << "Value number type of name: " << prim->name() << " ,please check the attr name: " << name;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
auto origin_value = opt::CastToInt(value_ptr);
|
||||
if (origin_value.size() != kCommonAttrValueNum) {
|
||||
MS_LOG(ERROR) << name << " Value num must be two.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::vector<int64_t> new_value;
|
||||
new_value.push_back(1);
|
||||
new_value.push_back(1);
|
||||
new_value.push_back(static_cast<int64_t>(origin_value[0]));
|
||||
new_value.push_back(static_cast<int64_t>(origin_value[1]));
|
||||
prim->AddAttr(name, MakeValue(new_value));
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
void PrimitiveDeparser::AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) {
|
||||
int64_t mode = src_prim_name == ops::kNameAvgPoolFusion ? 1 : 0;
|
||||
dst_prim->AddAttr(ops::kMode, MakeValue(mode));
|
||||
|
||||
auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
|
||||
auto run_mode = GetValue<int64_t>(run_mode_val);
|
||||
int64_t run_mode_ge = run_mode == RoundMode::FLOOR ? 1 : 0;
|
||||
dst_prim->set_attr(ops::kRoundMode, MakeValue(run_mode_ge));
|
||||
}
|
||||
|
||||
void PrimitiveDeparser::AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim) {
|
||||
static std::map<int64_t, std::string> kPadModToStrMap = {
|
||||
{PadMode::PAD, "CALCULATED"},
|
||||
{PadMode::SAME, "SAME"},
|
||||
{PadMode::VALID, "VALID"},
|
||||
};
|
||||
auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode);
|
||||
auto pad_mode = GetValue<int64_t>(pad_mode_val);
|
||||
std::string padding_mode = "CALCULATED";
|
||||
if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) {
|
||||
padding_mode = kPadModToStrMap[pad_mode];
|
||||
}
|
||||
dst_prim->AddAttr(kNamePaddingMode, MakeValue(padding_mode));
|
||||
|
||||
auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
|
||||
int64_t run_mode = GetValue<int64_t>(run_mode_val);
|
||||
bool ceil_mode = run_mode == RoundMode::CEIL;
|
||||
dst_prim->AddAttr(kNameCeilMode, MakeValue(ceil_mode));
|
||||
}
|
||||
|
||||
STATUS PrimitiveDeparser::AdjustPoolAttr(int fmk_type, const std::string &src_prim_name, const PrimitivePtr &dst_prim) {
|
||||
if (fmk_type == converter::kFmkTypeCaffe) {
|
||||
AdjustCaffePoolAttr(src_prim_name, dst_prim);
|
||||
return lite::RET_OK;
|
||||
} else if (fmk_type == converter::kFmkTypeOnnx) {
|
||||
AdjustOnnxPoolAttr(dst_prim);
|
||||
}
|
||||
// adjust common attr
|
||||
auto status = AttrAdjust(dst_prim, ops::kKernelSize);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adjust kernel size failed.";
|
||||
return status;
|
||||
}
|
||||
status = AttrAdjust(dst_prim, ops::kStrides);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "adjust strides failed.";
|
||||
return status;
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS PrimitiveDeparser::MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &dst_prim) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (dst_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Primitive is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS PrimitiveDeparser::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const PrimitivePtr &dst_prim, const std::string &attr_name, int flag) {
|
||||
auto attr_val = dst_prim->GetAttr(attr_name);
|
||||
if (attr_val == nullptr) {
|
||||
MS_LOG(INFO) << "There is no attr: " << attr_name;
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
auto inputs = cnode->inputs();
|
||||
switch (flag) {
|
||||
case (1): {
|
||||
auto value_data = opt::CastToVec2DInt(attr_val);
|
||||
auto param_node =
|
||||
opt::BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
|
||||
inputs.push_back(param_node);
|
||||
break;
|
||||
}
|
||||
case (2): {
|
||||
auto value_data = GetValue<float>(attr_val);
|
||||
auto param_node =
|
||||
opt::BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
|
||||
inputs.push_back(param_node);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Invalid flag for attr: " << flag;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
cnode->set_inputs(inputs);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_DEPARSER_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "base/base.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PrimitiveDeparser {
|
||||
public:
|
||||
explicit PrimitiveDeparser(const std::string &name) : name_(name) {}
|
||||
|
||||
virtual ~PrimitiveDeparser() = default;
|
||||
|
||||
virtual STATUS Deparser(const CNodePtr &cnode);
|
||||
|
||||
protected:
|
||||
STATUS AttrAdjust(const PrimitivePtr &prim, const std::string &name);
|
||||
|
||||
STATUS MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &dst_prim);
|
||||
|
||||
STATUS GetValueNodeAndPrimFromCnode(const CNodePtr &cnode, ValueNodePtr *value_node, PrimitivePtr *prim_ptr);
|
||||
|
||||
STATUS AdjustPoolAttr(int fmk_type, const std::string &src_prim_name, const PrimitivePtr &dst_prim);
|
||||
|
||||
STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const PrimitivePtr &dst_prim,
|
||||
const std::string &attr_name, int flag);
|
||||
|
||||
private:
|
||||
void AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim);
|
||||
|
||||
void AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim);
|
||||
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
using PrimitiveDeparserPtr = std::shared_ptr<PrimitiveDeparser>;
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_DEPARSER_H
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
PrimitiveDeparserRegister &PrimitiveDeparserRegister::GetInstance() {
|
||||
static PrimitiveDeparserRegister instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void PrimitiveDeparserRegister::InsertPrimitiveDeparser(const std::string &name, const PrimitiveDeparserPtr &deparser) {
|
||||
deparser_[name] = deparser;
|
||||
}
|
||||
|
||||
PrimitiveDeparserPtr PrimitiveDeparserRegister::GetPrimitiveDeparser(const std::string &name) {
|
||||
if (deparser_.find(name) != deparser_.end()) {
|
||||
return deparser_[name];
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Unsupported primitive name : " << name;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
RegisterPrimitiveDeparser::RegisterPrimitiveDeparser(const std::string &name, const PrimitiveDeparserPtr &deparser) {
|
||||
PrimitiveDeparserRegister::GetInstance().InsertPrimitiveDeparser(name, deparser);
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_DEPARSER_REGISTER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_DEPARSER_REGISTER_H
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "ir/anf.h"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PrimitiveDeparserRegister {
|
||||
public:
|
||||
static PrimitiveDeparserRegister &GetInstance();
|
||||
|
||||
void InsertPrimitiveDeparser(const std::string &name, const PrimitiveDeparserPtr &deparser);
|
||||
|
||||
PrimitiveDeparserPtr GetPrimitiveDeparser(const std::string &name);
|
||||
|
||||
private:
|
||||
PrimitiveDeparserRegister() = default;
|
||||
~PrimitiveDeparserRegister() = default;
|
||||
|
||||
std::map<std::string, PrimitiveDeparserPtr> deparser_;
|
||||
};
|
||||
|
||||
class RegisterPrimitiveDeparser {
|
||||
public:
|
||||
RegisterPrimitiveDeparser(const std::string &name, const PrimitiveDeparserPtr &deparser);
|
||||
|
||||
~RegisterPrimitiveDeparser() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_PRIMITIVE_DEPARSER(name, deparser) \
|
||||
static RegisterPrimitiveDeparser g_##name##PrimDeparser(name, std::make_shared<deparser>());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_DEPARSER_REGISTER_H
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/scale_fusion_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS ScaleFusionDeparser::Deparser(const CNodePtr &cnode) {
|
||||
auto dst_prim = std::make_shared<ops::Scale>();
|
||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||
MS_LOG(ERROR) << "ScaleFusion deparser failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameScaleFusion, ScaleFusionDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_SCALEFUSION_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_SCALEFUSION_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/fusion/scale_fusion.h"
|
||||
|
||||
using mindspore::ops::kNameScaleFusion;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ScaleFusionDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
ScaleFusionDeparser() : PrimitiveDeparser(kNameScaleFusion) {}
|
||||
~ScaleFusionDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_SCALEFUSION_DEPARSER_H
|
|
@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/spatial_node_adapter.h"
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "tools/converter/acl/common/utils.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "base/base.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ops/concat.h"
|
||||
#include "ops/batch_norm.h"
|
||||
#include "ops/fused_batch_norm.h"
|
||||
#include "ops/stack.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr size_t kCnodeInputMinNum = 2;
|
||||
constexpr auto kAnfPrimitiveIndex = 0;
|
||||
constexpr auto kNamewiEltwise = "Eltwise";
|
||||
const std::set<std::string> kCNodeWithMultiOutputs = {ops::kNameBatchNorm, ops::kNameFusedBatchNorm};
|
||||
const std::set<std::string> kCNodeWithDynamicInput = {kNamewiEltwise, ops::kNameConcat, ops::kNameStack};
|
||||
} // namespace
|
||||
|
||||
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode) {
|
||||
CNodePtr get_item_cnode = nullptr;
|
||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
||||
if (tuple_get_item_prim_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "New TupleGetItem failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||
auto get_item_value = NewValueNode(MakeValue<int64_t>(0));
|
||||
AnfNodePtrList inputs{tuple_get_item_prim, input_cnode, get_item_value};
|
||||
get_item_cnode = func_graph->NewCNode(inputs);
|
||||
if (get_item_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "New get item cnode failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<int64_t> shape;
|
||||
if (acl::GetShapeVectorFromCNode(input_cnode, &shape) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get shape failed.";
|
||||
return nullptr;
|
||||
}
|
||||
TypeId type = acl::GetTypeFromNode(input_cnode);
|
||||
auto get_item_abstract = CreateTensorAbstract(shape, type);
|
||||
if (get_item_abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tensor abstract failed.";
|
||||
return nullptr;
|
||||
}
|
||||
get_item_cnode->set_abstract(get_item_abstract);
|
||||
get_item_cnode->set_fullname_with_scope(input_cnode->fullname_with_scope() + "_getitem");
|
||||
return get_item_cnode;
|
||||
}
|
||||
|
||||
static STATUS AdapteNodeWithMultiOutputs(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const FuncGraphManagerPtr &manager) {
|
||||
std::string cnode_func_name = GetCNodeFuncName(cnode);
|
||||
if (cnode_func_name == prim::kTupleGetItem || cnode_func_name == kNameReturn) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
auto input = cnode->input(i);
|
||||
if (!utils::isa<CNode>(input)) {
|
||||
continue;
|
||||
}
|
||||
auto input_cnode = input->cast<CNodePtr>();
|
||||
std::string input_func_name = GetCNodeFuncName(input_cnode);
|
||||
if (kCNodeWithMultiOutputs.find(input_func_name) != kCNodeWithMultiOutputs.end()) {
|
||||
MS_LOG(INFO) << "Adapter cnode with multioutputs: " << cnode_func_name;
|
||||
CNodePtr get_item_cnode = CreateTupleGetItemNode(func_graph, input_cnode);
|
||||
if (get_item_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Create tuple item for " << cnode_func_name << " failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (!manager->Replace(input_cnode, get_item_cnode)) {
|
||||
MS_LOG(ERROR) << "Replace " << cnode_func_name << " failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
static STATUS AdapteNodeWithDynamicInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
std::string cnode_func_name = GetCNodeFuncName(cnode);
|
||||
if (kCNodeWithDynamicInput.find(cnode_func_name) == kCNodeWithDynamicInput.end()) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
MS_LOG(INFO) << "Adapter cnode with dynamic input: " << cnode_func_name;
|
||||
auto make_tuple_val_node = NewValueNode(prim::kPrimMakeTuple);
|
||||
if (make_tuple_val_node == nullptr) {
|
||||
MS_LOG(ERROR) << "New make tuple val node failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
AnfNodePtrList new_inputs = {make_tuple_val_node};
|
||||
auto cnode_inputs = cnode->inputs();
|
||||
if (cnode_inputs.size() >= kCnodeInputMinNum) {
|
||||
new_inputs.insert(new_inputs.end(), cnode_inputs.begin() + 1, cnode_inputs.end());
|
||||
}
|
||||
auto make_tuple_cnode = func_graph->NewCNode(new_inputs);
|
||||
if (make_tuple_cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "New make tuple cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
const std::vector<AnfNodePtr> replace_node = {cnode_inputs[0], make_tuple_cnode};
|
||||
cnode->set_inputs(replace_node);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
STATUS AdapteSpatialNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
for (const auto &cnode : cnodes) {
|
||||
if (cnode == nullptr) {
|
||||
MS_LOG(ERROR) << "Cnode is nullptr.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (AdapteNodeWithMultiOutputs(func_graph, cnode, manager) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adapter node with multioutput failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (AdapteNodeWithDynamicInput(func_graph, cnode) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Adapter node with dynamic input failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_SPATIAL_NODE_ADAPTER_PASS_H
|
||||
#define ACL_DEPARSER_SPATIAL_NODE_ADAPTER_PASS_H
|
||||
|
||||
#include "ir/func_graph.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS AdapteSpatialNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // ACL_DEPARSER_SPATIAL_NODE_ADAPTER_PASS_H
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/stack_deparser.h"
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr auto kNameNum = "num";
|
||||
}
|
||||
|
||||
STATUS StackDeparser::Deparser(const CNodePtr &cnode) {
|
||||
if (AddAttrForDynInputPrimitive(cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Stack deparser failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS StackDeparser::AddAttrForDynInputPrimitive(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_ASSERT(value_node != nullptr);
|
||||
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Value node is invalid.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
// add attr input num for dynamic input op
|
||||
int64_t num = static_cast<int64_t>(cnode->size());
|
||||
if (num > 1) {
|
||||
prim->AddAttr(kNameNum, MakeValue(num - 1));
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameStack, StackDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_STACK_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_STACK_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/stack.h"
|
||||
|
||||
using mindspore::ops::kNameStack;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class StackDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
StackDeparser() : PrimitiveDeparser(kNameStack) {}
|
||||
|
||||
~StackDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
|
||||
private:
|
||||
STATUS AddAttrForDynInputPrimitive(const CNodePtr &cnode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_STACK_DEPARSER_H
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/acl/deparser/stridedslice_deparser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
|
||||
#include "tools/converter/acl/deparser/tbe_op_def.h"
|
||||
#include "include/registry/parser_context.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS StridedSliceDeparser::Deparser(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
auto attr_val = src_prim->GetAttr(ops::kFmkType);
|
||||
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
|
||||
if (fmk_type == converter::kFmkTypeOnnx) {
|
||||
auto dst_prim = std::make_shared<acl::StridedSliceV2>();
|
||||
MS_ASSERT(dst_prim != nullptr);
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
value_node->set_value(dst_prim);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_DEPARSER(kNameStridedSlice, StridedSliceDeparser)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ACL_DEPARSER_PRIMITIVE_STRIDEDSLICE_DEPARSER_H
|
||||
#define ACL_DEPARSER_PRIMITIVE_STRIDEDSLICE_DEPARSER_H
|
||||
|
||||
#include "tools/converter/acl/deparser/primitive_deparser.h"
|
||||
#include "ops/strided_slice.h"
|
||||
|
||||
using mindspore::ops::kNameStridedSlice;
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class StridedSliceDeparser : public PrimitiveDeparser {
|
||||
public:
|
||||
StridedSliceDeparser() : PrimitiveDeparser(kNameStridedSlice) {}
|
||||
~StridedSliceDeparser() override = default;
|
||||
|
||||
STATUS Deparser(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // ACL_DEPARSER_PRIMITIVE_STRIDEDSLICE_DEPARSER_H
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue