Ascend310 infer

This commit is contained in:
zhengyuanhua 2021-08-17 16:17:58 +08:00
parent 111d1a9a61
commit c8131ef8c4
114 changed files with 4973 additions and 102 deletions

View File

@ -1,5 +1,9 @@
message(STATUS "Compiling GraphEngine") message(STATUS "Compiling GraphEngine")
set(GE_SOURCE_DIR ${CMAKE_SOURCE_DIR}/graphengine) if(NOT(BUILD_LITE))
set(GE_SOURCE_DIR ${CMAKE_SOURCE_DIR}/graphengine)
else()
set(GE_SOURCE_DIR ${CMAKE_SOURCE_DIR}/../../graphengine)
endif()
message(STATUS "[ME] build_path: ${BUILD_PATH}") message(STATUS "[ME] build_path: ${BUILD_PATH}")
@ -45,7 +49,11 @@ if(ENABLE_TESTCASES)
set(ENABLE_GITEE ${_ge_tmp_ENABLE_GITEE}) set(ENABLE_GITEE ${_ge_tmp_ENABLE_GITEE})
set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS}) set(CMAKE_CXX_FLAGS ${_ge_tmp_CMAKE_CXX_FLAGS})
elseif(MODE_ASCEND_ALL OR MODE_ASCEND_ACL) elseif(MODE_ASCEND_ALL OR MODE_ASCEND_ACL)
file(GLOB_RECURSE GE_PROTO_FILE RELATIVE ${CMAKE_SOURCE_DIR} "graphengine/metadef/proto/*.proto") if(NOT(BUILD_LITE))
file(GLOB_RECURSE GE_PROTO_FILE RELATIVE ${CMAKE_SOURCE_DIR} "graphengine/metadef/proto/*.proto")
else()
file(GLOB_RECURSE GE_PROTO_FILE ${TOP_DIR}/graphengine/metadef/proto/*.proto)
endif()
set(TMP_FILE_NAME_LIST) set(TMP_FILE_NAME_LIST)
foreach(file ${GE_PROTO_FILE}) foreach(file ${GE_PROTO_FILE})
get_filename_component(file_name ${file} NAME_WE) get_filename_component(file_name ${file} NAME_WE)

View File

@ -7,13 +7,14 @@ if(BUILD_LITE)
else() else()
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private") set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private")
set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
if(NOT ENABLE_GLIBCXX)
set(glog_CXXFLAGS "${glog_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
set(glog_patch ${CMAKE_SOURCE_DIR}/third_party/patch/glog/glog.patch001) set(glog_patch ${CMAKE_SOURCE_DIR}/third_party/patch/glog/glog.patch001)
set(glog_lib mindspore_glog) set(glog_lib mindspore_glog)
endif() endif()
if(NOT ENABLE_GLIBCXX)
set(glog_CXXFLAGS "${glog_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
if(ENABLE_GITEE) if(ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/glog/repository/archive/v0.4.0.tar.gz") set(REQ_URL "https://gitee.com/mirrors/glog/repository/archive/v0.4.0.tar.gz")
set(MD5 "22fe340ddc231e6c8e46bc295320f8ee") set(MD5 "22fe340ddc231e6c8e46bc295320f8ee")

View File

@ -2,6 +2,9 @@ set(protobuf_USE_STATIC_LIBS ON)
if(BUILD_LITE) if(BUILD_LITE)
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter \ set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter \
-fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2")
if(ENABLE_ACL)
set(protobuf_CXXFLAGS "${protobuf_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
else() else()
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC \ set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC \

View File

@ -515,6 +515,11 @@ else()
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so
COMPONENT ${RUNTIME_COMPONENT_NAME}) COMPONENT ${RUNTIME_COMPONENT_NAME})
endif() endif()
if(MSLITE_ENABLE_ACL)
set(LITE_ACL_DIR ${TOP_DIR}/mindspore/lite/build/tools/converter/acl)
install(FILES ${LITE_ACL_DIR}/mindspore_shared_lib/libmindspore_shared_lib.so
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
__install_micro_wrapper() __install_micro_wrapper()
__install_micro_codegen() __install_micro_codegen()
endif() endif()

View File

@ -1,8 +1,15 @@
# build mindspore_shared_lib # build mindspore_shared_lib
set(LOAD_MINDIR_SRC if(NOT(BUILD_LITE))
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc set(LOAD_MINDIR_SRC
${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/load_model.cc
) ${CMAKE_SOURCE_DIR}/mindspore/core/load_mindir/anf_model_parser.cc
)
else()
set(MS_UTILS_SRC
${CMAKE_CURRENT_SOURCE_DIR}/../../../mindspore/ccsrc/utils/config_manager.cc
)
endif()
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc") file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
if(ENABLE_D OR ENABLE_ACL) if(ENABLE_D OR ENABLE_ACL)
@ -12,11 +19,15 @@ if(ENABLE_D OR ENABLE_ACL)
include_directories(${CMAKE_BINARY_DIR}/proto/ge) include_directories(${CMAKE_BINARY_DIR}/proto/ge)
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"akg_kernel_register.cc"
"model/acl/*.cc" "model/acl/*.cc"
"model/model_converter_utils/*.cc" "model/model_converter_utils/*.cc"
"graph/acl/*.cc" "graph/acl/*.cc"
) )
if(NOT(BUILD_LITE))
list(APPEND API_ACL_SRC "akg_kernel_register.cc")
endif()
if(NOT ENABLE_D) if(NOT ENABLE_D)
list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>) list(APPEND API_ACL_SRC $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>)
endif() endif()
@ -44,10 +55,13 @@ set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
${API_MS_INFER_SRC} ${API_MS_INFER_SRC}
${API_ACL_SRC} ${API_ACL_SRC}
${API_OPS_SRC} ${API_OPS_SRC}
${LOAD_MINDIR_SRC}) ${LOAD_MINDIR_SRC}
${MS_UTILS_SRC})
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC}) add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore) if(NOT(BUILD_LITE))
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore)
endif()
if(CMAKE_SYSTEM_NAME MATCHES "Darwin") if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
@ -58,8 +72,12 @@ else()
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_core proto_input mindspore_gvar -Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_core proto_input mindspore_gvar
mindspore::protobuf) mindspore::protobuf)
else() else()
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} if(NOT(BUILD_LITE))
mindspore mindspore_core proto_input mindspore_gvar mindspore::protobuf) target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
mindspore mindspore_core proto_input mindspore_gvar mindspore::protobuf)
else()
target_link_libraries(mindspore_shared_lib PRIVATE ${SECUREC_LIBRARY})
endif()
endif() endif()
endif() endif()

View File

@ -199,7 +199,7 @@ void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
} }
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const { std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_); MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DeviceID); const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DumpCfgPath);
return StringToChar(ref); return StringToChar(ref);
} }

View File

@ -33,7 +33,7 @@ class AclModelOptions {
std::string GenAclOptionsKey() const; std::string GenAclOptionsKey() const;
uint32_t GetDeviceID() const { return device_id_; } uint32_t GetDeviceID() const { return device_id_; }
std::string GetDumpCfgPath() const { return dump_cfg_path_; } std::string GetDumpCfgPath() const { return dump_cfg_path_; }
void RenameInput(const std::vector<std::string> &); void RenameInput(const std::vector<std::string> &name);
// return tuple<init_options, build_options> // return tuple<init_options, build_options>
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const; std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> GenAclOptions() const;

View File

@ -1474,6 +1474,65 @@ void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
op_cache_[value_ptr.get()] = op; op_cache_[value_ptr.get()] = op;
} }
std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) {
if (value == nullptr) {
MS_LOG(WARNING) << "Value ptr is nullptr.";
return {};
}
std::vector<int64_t> cur_value = {};
if (utils::isa<ValueSequeuePtr>(value)) {
auto val_seq_ptr = value->cast<ValueSequeuePtr>();
MS_EXCEPTION_IF_NULL(val_seq_ptr);
if (!val_seq_ptr->value().empty()) {
auto first_val = val_seq_ptr->value().front();
MS_EXCEPTION_IF_NULL(first_val);
MS_EXCEPTION_IF_NULL(first_val->type());
if (first_val->type()->number_type() == kNumberTypeInt64) {
cur_value = GetValue<std::vector<int64_t>>(value);
} else {
auto origin_value = GetValue<std::vector<int>>(value);
std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
[](int index) { return static_cast<int64_t>(index); });
}
}
} else {
MS_EXCEPTION_IF_NULL(value->type());
if (value->type()->number_type() == kNumberTypeInt64) {
cur_value.push_back(GetValue<int64_t>(value));
} else {
cur_value.push_back(static_cast<int64_t>(GetValue<int>(value)));
}
}
return cur_value;
}
void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
MS_LOG(INFO) << "Convert the second input of reshape to op attr.";
const auto kInputNum = 3;
if (node->size() < kInputNum) {
MS_LOG(WARNING) << "Reshape must have two inputs.";
return;
}
OpAdapterPtr adpt = FindAdapter(node, training_);
if (adpt == nullptr) {
return;
}
auto op = adpt->generate(node);
MS_EXCEPTION_IF_NULL(op);
// get shape form attr
auto value_node = node->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(value_node->value());
auto primitive = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(primitive);
auto value = primitive->GetAttr("shape");
std::vector<int64_t> list;
list = CastToInt(value);
op->SetAttr("shape", list);
op_cache_[node.get()] = op;
}
AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) { AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
const int TUPLE_GET_ITEM_INDEX = 2; const int TUPLE_GET_ITEM_INDEX = 2;
if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs
@ -1660,6 +1719,12 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
return true; return true;
} }
// Convert Reshape add const input to attr(shape)
if (name == prim::kPrimReshape->name()) {
ConvertReshape(node);
return true;
}
// make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
if (name == prim::kPrimMakeTuple->name()) { if (name == prim::kPrimMakeTuple->name()) {
ConvertMakeTuple(node); ConvertMakeTuple(node);

View File

@ -158,6 +158,8 @@ class DfGraphConvertor {
void ConvertTupleGetItem(const CNodePtr node); void ConvertTupleGetItem(const CNodePtr node);
void ConvertMakeTuple(const CNodePtr node); void ConvertMakeTuple(const CNodePtr node);
void ConvertTopK(const CNodePtr node); void ConvertTopK(const CNodePtr node);
void ConvertReshape(const CNodePtr node);
std::vector<int64_t> CastToInt(const ValuePtr &value);
bool CheckCNode(const std::string &name, const CNodePtr node); bool CheckCNode(const std::string &name, const CNodePtr node);
void TraceOutput(AnfNodePtr node); void TraceOutput(AnfNodePtr node);
void TraceOutputFromParameter(const AnfNodePtr &anf_out); void TraceOutputFromParameter(const AnfNodePtr &anf_out);

View File

@ -18,8 +18,10 @@
#include <sstream> #include <sstream>
#ifndef ENABLE_LITE_ACL
#include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/pipeline.h" #include "pipeline/jit/pipeline.h"
#endif
#ifndef NO_DLIB #ifndef NO_DLIB
#include "tdt/tsd_client.h" #include "tdt/tsd_client.h"
#endif #endif
@ -37,11 +39,13 @@ DfGraphManager::DfGraphManager() {
} }
DfGraphManager::~DfGraphManager() { DfGraphManager::~DfGraphManager() {
// in python fisrt destroy after atexit but in c++ destoy before atexit // in python first destroy after atexit but in c++ destoy before atexit
DeleteGraphRunner(); DeleteGraphRunner();
DeleteGeSession(); DeleteGeSession();
ClearGraph(); ClearGraph();
#ifndef ENABLE_LITE_ACL
parse::python_adapter::set_python_env_flag(false); parse::python_adapter::set_python_env_flag(false);
#endif
} }
DfGraphManager &DfGraphManager::GetInstance() { DfGraphManager &DfGraphManager::GetInstance() {

View File

@ -19,7 +19,9 @@
#include <string> #include <string>
#include <memory> #include <memory>
#ifndef ENABLE_LITE_ACL
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#endif
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/config_manager.h" #include "utils/config_manager.h"
#include "sys/time.h" #include "sys/time.h"
@ -40,9 +42,9 @@ Session::Session(const std::map<std::string, std::string> &options) {
Session::~Session() {} Session::~Session() {}
} // namespace ge } // namespace ge
#endif #endif
#ifndef ENABLE_LITE_ACL
namespace py = pybind11; namespace py = pybind11;
#endif
namespace mindspore { namespace mindspore {
namespace transform { namespace transform {
std::shared_ptr<ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) { std::shared_ptr<ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
@ -189,7 +191,9 @@ Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<MeTens
Status ret; Status ret;
{ {
// Release GIL before calling into (potentially long-running) C++ code // Release GIL before calling into (potentially long-running) C++ code
#ifndef ENABLE_LITE_ACL
py::gil_scoped_release release; py::gil_scoped_release release;
#endif
ret = RunGraph(options, ge_inputs, &ge_outputs); ret = RunGraph(options, ge_inputs, &ge_outputs);
} }
if (ret != Status::SUCCESS) { if (ret != Status::SUCCESS) {

View File

@ -313,6 +313,21 @@ constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder";
constexpr const char kNameReverseV2[] = "ReverseV2"; constexpr const char kNameReverseV2[] = "ReverseV2";
constexpr const char kNameLambApplyWeightAssign[] = "LambApplyWeightAssign"; constexpr const char kNameLambApplyWeightAssign[] = "LambApplyWeightAssign";
constexpr const char kNameLambApplyOptimizerAssign[] = "LambApplyOptimizerAssign"; constexpr const char kNameLambApplyOptimizerAssign[] = "LambApplyOptimizerAssign";
constexpr const char kNameScale[] = "Scale";
constexpr const char kNameEltwise[] = "Eltwise";
constexpr const char kNameFullConnection[] = "FullConnection";
constexpr const char kNameFusedBatchNorm[] = "FusedBatchNorm";
constexpr const char kNamePooling[] = "Pooling";
constexpr const char kNameMaxPoolV3[] = "MaxPoolV3";
constexpr const char kNameAvgPoolV2[] = "AvgPoolV2";
constexpr const char kNameShape[] = "Shape";
constexpr const char kNameGather[] = "Gather";
constexpr const char kNameUnsqueeze[] = "Unsqueeze";
constexpr const char kNamePadV3[] = "PadV3";
constexpr const char kNameGlobalAvgPool[] = "GlobalAveragePool";
constexpr const char kNameStridedSliceV2[] = "StridedSliceV2";
constexpr const char kNameBNInference[] = "BNInference";
constexpr const char kNameDeconvolution[] = "Deconvolution";
class OpAdapterMap { class OpAdapterMap {
public: public:

View File

@ -43,6 +43,12 @@ INPUT_MAP(Data) = EMPTY_INPUT_MAP;
ATTR_MAP(Data) = EMPTY_ATTR_MAP; ATTR_MAP(Data) = EMPTY_ATTR_MAP;
REG_ADPT_DESC(Data, kNameParam, ADPT_DESC(Data)) REG_ADPT_DESC(Data, kNameParam, ADPT_DESC(Data))
// Shape
INPUT_MAP(Shape) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Shape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Shape) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Shape, kNameShape, ADPT_DESC(Shape))
// Reshape // Reshape
INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}}; INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP; ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
@ -95,4 +101,10 @@ INPUT_MAP(EditDistance) = {{1, INPUT_DESC(hypothesis_indices)}, {2, INPUT_DESC(h
ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits<bool>())}}; ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits<bool>())}};
OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}}; OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}};
REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance)) REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance))
// Unsqueeze
INPUT_MAP(Unsqueeze) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Unsqueeze) = {{"axis", ATTR_DESC(axes, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(Unsqueeze) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Unsqueeze, kNameUnsqueeze, ADPT_DESC(Unsqueeze))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -23,6 +23,9 @@
#include "ops/array_ops.h" #include "ops/array_ops.h"
namespace mindspore::transform { namespace mindspore::transform {
DECLARE_OP_ADAPTER(Shape)
DECLARE_OP_USE_OUTPUT(Shape)
DECLARE_OP_ADAPTER(Reshape) DECLARE_OP_ADAPTER(Reshape)
DECLARE_OP_USE_OUTPUT(Reshape) DECLARE_OP_USE_OUTPUT(Reshape)
@ -57,5 +60,8 @@ DECLARE_OP_USE_OUTPUT(ReverseSequence)
DECLARE_OP_ADAPTER(EditDistance) DECLARE_OP_ADAPTER(EditDistance)
DECLARE_OP_USE_OUTPUT(EditDistance) DECLARE_OP_USE_OUTPUT(EditDistance)
DECLARE_OP_ADAPTER(Unsqueeze)
DECLARE_OP_USE_OUTPUT(Unsqueeze)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_

View File

@ -637,4 +637,13 @@ INPUT_MAP(LambApplyWeightAssign) = {{1, INPUT_DESC(input0)},
ATTR_MAP(LambApplyWeightAssign) = EMPTY_ATTR_MAP; ATTR_MAP(LambApplyWeightAssign) = EMPTY_ATTR_MAP;
OUTPUT_MAP(LambApplyWeightAssign) = {{0, OUTPUT_DESC(input_param)}}; OUTPUT_MAP(LambApplyWeightAssign) = {{0, OUTPUT_DESC(input_param)}};
REG_ADPT_DESC(LambApplyWeightAssign, kNameLambApplyWeightAssign, ADPT_DESC(LambApplyWeightAssign)) REG_ADPT_DESC(LambApplyWeightAssign, kNameLambApplyWeightAssign, ADPT_DESC(LambApplyWeightAssign))
// Eltwise
INPUT_MAP(Eltwise) = EMPTY_INPUT_MAP;
DYN_INPUT_MAP(Eltwise) = {{1, DYN_INPUT_DESC(x)}};
ATTR_MAP(Eltwise) = {{"n", ATTR_DESC(N, AnyTraits<int64_t>())},
{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())},
{"coeff", ATTR_DESC(coeff, AnyTraits<std::vector<float>>(), AnyTraits<float>())}};
OUTPUT_MAP(Eltwise) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Eltwise, kNameEltwise, ADPT_DESC(Eltwise))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -316,5 +316,8 @@ DECLARE_OP_USE_OUTPUT(LambApplyOptimizerAssign)
DECLARE_OP_ADAPTER(LambApplyWeightAssign) DECLARE_OP_ADAPTER(LambApplyWeightAssign)
DECLARE_OP_USE_OUTPUT(LambApplyWeightAssign) DECLARE_OP_USE_OUTPUT(LambApplyWeightAssign)
DECLARE_OP_ADAPTER(Eltwise)
DECLARE_OP_USE_OUTPUT(Eltwise)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ELEWISE_CALCULATION_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ELEWISE_CALCULATION_OPS_DECLARE_H_

View File

@ -133,4 +133,15 @@ INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP; ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP;
OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(L2Loss, kNameL2Loss, ADPT_DESC(L2Loss)) REG_ADPT_DESC(L2Loss, kNameL2Loss, ADPT_DESC(L2Loss))
// FullyConnection
INPUT_MAP(FullyConnection) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(w)}, {3, INPUT_DESC(b)}, {4, INPUT_DESC(offset_w)}};
ATTR_MAP(FullyConnection) = {{"num_output", ATTR_DESC(num_output, AnyTraits<int64_t>())},
{"transpose", ATTR_DESC(transpose, AnyTraits<bool>())},
{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())},
{"offset_x", ATTR_DESC(offset_x, AnyTraits<int64_t>())}};
OUTPUT_MAP(FullyConnection) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(FullyConnection, kNameFullConnection, ADPT_DESC(FullyConnection))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -79,5 +79,8 @@ DECLARE_OP_USE_OUTPUT(DiagPart)
DECLARE_OP_ADAPTER(L2Loss) DECLARE_OP_ADAPTER(L2Loss)
DECLARE_OP_USE_OUTPUT(L2Loss) DECLARE_OP_USE_OUTPUT(L2Loss)
DECLARE_OP_ADAPTER(FullyConnection)
DECLARE_OP_USE_OUTPUT(FullyConnection)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATRIX_CALCULATION_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_MATRIX_CALCULATION_OPS_DECLARE_H_

View File

@ -32,7 +32,17 @@ OUTPUT_MAP(BatchNorm) = {{0, OUTPUT_DESC(y)},
{2, OUTPUT_DESC(batch_variance)}, {2, OUTPUT_DESC(batch_variance)},
{3, OUTPUT_DESC(reserve_space_1)}, {3, OUTPUT_DESC(reserve_space_1)},
{4, OUTPUT_DESC(reserve_space_2)}}; {4, OUTPUT_DESC(reserve_space_2)}};
// BNInference is BatchNorm for caffe
INPUT_MAP(BNInference) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(mean)}, {3, INPUT_DESC(variance)},
{4, INPUT_DESC(momentum)}, {5, INPUT_DESC(scale)}, {6, INPUT_DESC(offset)}};
ATTR_MAP(BNInference) = {{"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())},
{"use_global_stats", ATTR_DESC(use_global_stats, AnyTraits<bool>())},
{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())}};
OUTPUT_MAP(BNInference) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(BNInference, kNameBNInference, ADPT_DESC(BNInference))
REG_ADPT_DESC(BatchNorm, kNameBatchNorm, ADPT_DESC(BatchNorm)) REG_ADPT_DESC(BatchNorm, kNameBatchNorm, ADPT_DESC(BatchNorm))
REG_ADPT_DESC(FusedBatchNorm, kNameFusedBatchNorm, ADPT_DESC(BatchNorm))
// BatchNormGrad // BatchNormGrad
INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)}, INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)},
@ -65,4 +75,5 @@ ATTR_MAP(L2Normalize) = {
{"epsilon", ATTR_DESC(eps, AnyTraits<float>())}}; {"epsilon", ATTR_DESC(eps, AnyTraits<float>())}};
OUTPUT_MAP(L2Normalize) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(L2Normalize) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(L2Normalize, kNameL2Normalize, ADPT_DESC(L2Normalize)) REG_ADPT_DESC(L2Normalize, kNameL2Normalize, ADPT_DESC(L2Normalize))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -26,6 +26,9 @@ namespace mindspore::transform {
DECLARE_OP_ADAPTER(BatchNorm) DECLARE_OP_ADAPTER(BatchNorm)
DECLARE_OP_USE_OUTPUT(BatchNorm) DECLARE_OP_USE_OUTPUT(BatchNorm)
DECLARE_OP_ADAPTER(BNInference)
DECLARE_OP_USE_OUTPUT(BNInference)
DECLARE_OP_ADAPTER(BatchNormGrad) DECLARE_OP_ADAPTER(BatchNormGrad)
DECLARE_OP_USE_OUTPUT(BatchNormGrad) DECLARE_OP_USE_OUTPUT(BatchNormGrad)

View File

@ -49,6 +49,19 @@ ATTR_MAP(Conv2DBackpropInputD) = {
}; };
OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Conv2DBackpropInputD, prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)) REG_ADPT_DESC(Conv2DBackpropInputD, prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD))
// Deconvolution for caffe inference
INPUT_MAP(Deconvolution) = {
{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}, {3, INPUT_DESC(bias)}, {4, INPUT_DESC(offset_w)}};
ATTR_MAP(Deconvolution) = {
{"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
{"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
{"format", ATTR_DESC(data_format, AnyTraits<string>())},
{"offset", ATTR_DESC(offset_x, AnyTraits<int64_t>())}};
OUTPUT_MAP(Deconvolution) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Deconvolution, kNameDeconvolution, ADPT_DESC(Deconvolution))
REG_ADPT_DESC(Conv2DTranspose, kConv2DTransposeOpName, ADPT_DESC(Conv2DBackpropInputD)) REG_ADPT_DESC(Conv2DTranspose, kConv2DTransposeOpName, ADPT_DESC(Conv2DBackpropInputD))
// Conv2DBackpropFilterD // Conv2DBackpropFilterD

View File

@ -69,5 +69,8 @@ DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropFilterD)
DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropInputD) DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropInputD)
DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD) DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD) DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
DECLARE_OP_ADAPTER(Deconvolution)
DECLARE_OP_USE_OUTPUT(Deconvolution)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_CALCULATION_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_CALCULATION_OPS_DECLARE_H_

View File

@ -146,4 +146,13 @@ INPUT_MAP(Centralization) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Centralization) = {{"axes", ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>())}}; ATTR_MAP(Centralization) = {{"axes", ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(Centralization) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Centralization) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Centralization, kNameCentralization, ADPT_DESC(Centralization)) REG_ADPT_DESC(Centralization, kNameCentralization, ADPT_DESC(Centralization))
// Scale
INPUT_MAP(Scale) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(scale)}, {3, INPUT_DESC(bias)}};
ATTR_MAP(Scale) = {{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())},
{"num_axes", ATTR_DESC(num_axes, AnyTraits<int64_t>())},
{"scale_from_blob", ATTR_DESC(scale_from_blob, AnyTraits<bool>())}};
OUTPUT_MAP(Scale) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Scale, kNameScale, ADPT_DESC(Scale))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -76,5 +76,8 @@ DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad)
DECLARE_OP_ADAPTER(Centralization) DECLARE_OP_ADAPTER(Centralization)
DECLARE_OP_USE_OUTPUT(Centralization) DECLARE_OP_USE_OUTPUT(Centralization)
DECLARE_OP_ADAPTER(Scale)
DECLARE_OP_USE_OUTPUT(Scale)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_

View File

@ -120,4 +120,47 @@ ATTR_MAP(MaxPoolGradGradWithArgmax) = {
{"pad_mode", ATTR_DESC(padding, AnyTraits<std::string>())}}; {"pad_mode", ATTR_DESC(padding, AnyTraits<std::string>())}};
OUTPUT_MAP(MaxPoolGradGradWithArgmax) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(MaxPoolGradGradWithArgmax) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(MaxPoolGradGradWithArgmax, kNameMaxPoolGradGradWithArgmax, ADPT_DESC(MaxPoolGradGradWithArgmax)) REG_ADPT_DESC(MaxPoolGradGradWithArgmax, kNameMaxPoolGradGradWithArgmax, ADPT_DESC(MaxPoolGradGradWithArgmax))
// Pooling
INPUT_MAP(Pooling) = {{1, INPUT_DESC(x)}};
ATTR_MAP(Pooling) = {{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())},
{"global", ATTR_DESC(global_pooling, AnyTraits<bool>())},
{"kernel_size", ATTR_DESC(window, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"strides", ATTR_DESC(stride, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"pad", ATTR_DESC(pad, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"dilation", ATTR_DESC(dilation, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"round_mode", ATTR_DESC(ceil_mode, AnyTraits<int64_t>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
OUTPUT_MAP(Pooling) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(Pooling, kNamePooling, ADPT_DESC(Pooling))
// MaxPoolV3
INPUT_MAP(MaxPoolV3) = {{1, INPUT_DESC(x)}};
ATTR_MAP(MaxPoolV3) = {{"kernel_size", ATTR_DESC(ksize, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"strides", ATTR_DESC(strides, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"padding_mode", ATTR_DESC(padding_mode, AnyTraits<std::string>())},
{"pad", ATTR_DESC(pads, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"global", ATTR_DESC(global_pooling, AnyTraits<bool>())},
{"ceil_mode", ATTR_DESC(ceil_mode, AnyTraits<bool>())}};
OUTPUT_MAP(MaxPoolV3) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(MaxPoolV3, kNameMaxPoolV3, ADPT_DESC(MaxPoolV3))
// AvgPoolV2
INPUT_MAP(AvgPoolV2) = {{1, INPUT_DESC(x)}};
ATTR_MAP(AvgPoolV2) = {{"kernel_size", ATTR_DESC(ksize, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"strides", ATTR_DESC(strides, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"padding_mode", ATTR_DESC(padding_mode, AnyTraits<std::string>())},
{"pad", ATTR_DESC(pads, AnyTraits<int64_t>(), AnyTraits<std::vector<int64_t>>())},
{"format", ATTR_DESC(data_format, AnyTraits<std::string>())},
{"global", ATTR_DESC(global_pooling, AnyTraits<bool>())},
{"ceil_mode", ATTR_DESC(ceil_mode, AnyTraits<bool>())}};
OUTPUT_MAP(AvgPoolV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(AvgPoolV2, kNameAvgPoolV2, ADPT_DESC(AvgPoolV2))
// GlobalAveragePool
INPUT_MAP(GlobalAveragePool) = {{1, INPUT_DESC(x)}};
ATTR_MAP(GlobalAveragePool) = EMPTY_ATTR_MAP;
OUTPUT_MAP(GlobalAveragePool) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(GlobalAveragePool, kNameGlobalAvgPool, ADPT_DESC(GlobalAveragePool))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -21,6 +21,7 @@
#include <unordered_map> #include <unordered_map>
#include "transform/graph_ir/op_declare/op_declare_macro.h" #include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "ops/nn_ops.h" #include "ops/nn_ops.h"
#include "ops/nn_pooling_ops.h"
namespace mindspore::transform { namespace mindspore::transform {
DECLARE_OP_ADAPTER(MaxPoolWithArgmax) DECLARE_OP_ADAPTER(MaxPoolWithArgmax)
@ -55,5 +56,17 @@ DECLARE_OP_USE_OUTPUT(AvgPool)
DECLARE_OP_ADAPTER(AvgPoolGrad) DECLARE_OP_ADAPTER(AvgPoolGrad)
DECLARE_OP_USE_OUTPUT(AvgPoolGrad) DECLARE_OP_USE_OUTPUT(AvgPoolGrad)
DECLARE_OP_ADAPTER(Pooling)
DECLARE_OP_USE_OUTPUT(Pooling)
DECLARE_OP_ADAPTER(MaxPoolV3)
DECLARE_OP_USE_OUTPUT(MaxPoolV3)
DECLARE_OP_ADAPTER(AvgPoolV2)
DECLARE_OP_USE_OUTPUT(AvgPoolV2)
DECLARE_OP_ADAPTER(GlobalAveragePool)
DECLARE_OP_USE_OUTPUT(GlobalAveragePool)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_POOLING_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NN_POOLING_OPS_DECLARE_H_

View File

@ -154,4 +154,10 @@ INPUT_MAP(FastGeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}};
ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP; ATTR_MAP(FastGeluGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}}; OUTPUT_MAP(FastGeluGrad) = {{0, OUTPUT_DESC(z)}};
REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeLUGrad->name(), ADPT_DESC(FastGeluGrad)) REG_ADPT_DESC(FastGeluGrad, prim::kPrimFastGeLUGrad->name(), ADPT_DESC(FastGeluGrad))
// LeakyRelu
INPUT_MAP(LeakyRelu) = {{1, INPUT_DESC(x)}};
ATTR_MAP(LeakyRelu) = {{"alpha", ATTR_DESC(negative_slope, AnyTraits<float>())}};
OUTPUT_MAP(LeakyRelu) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(LeakyRelu, prim::kPrimLeakyRelu->name(), ADPT_DESC(LeakyRelu))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -91,5 +91,8 @@ DECLARE_OP_USE_OUTPUT(Sigmoid)
DECLARE_OP_ADAPTER(SigmoidGrad) DECLARE_OP_ADAPTER(SigmoidGrad)
DECLARE_OP_USE_OUTPUT(SigmoidGrad) DECLARE_OP_USE_OUTPUT(SigmoidGrad)
DECLARE_OP_ADAPTER(LeakyRelu)
DECLARE_OP_USE_OUTPUT(LeakyRelu)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NONLINEAR_FUC_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_NONLINEAR_FUC_OPS_DECLARE_H_

View File

@ -41,4 +41,11 @@ INPUT_MAP(FillD) = {{1, INPUT_DESC(value)}};
ATTR_MAP(FillD) = {{"dims", ATTR_DESC(dims, AnyTraits<std::vector<int64_t>>())}}; ATTR_MAP(FillD) = {{"dims", ATTR_DESC(dims, AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(FillD) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(FillD) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(FillD, kNameFillD, ADPT_DESC(FillD)) REG_ADPT_DESC(FillD, kNameFillD, ADPT_DESC(FillD))
// PadV3
INPUT_MAP(PadV3) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}, {3, INPUT_DESC(constant_values)}};
ATTR_MAP(PadV3) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())},
{"pad_contiguous", ATTR_DESC(paddings_contiguous, AnyTraits<bool>())}};
OUTPUT_MAP(PadV3) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(PadV3, kNamePadV3, ADPT_DESC(PadV3))
} // namespace mindspore::transform } // namespace mindspore::transform

View File

@ -34,5 +34,8 @@ DECLARE_OP_USE_OUTPUT(Diag)
DECLARE_OP_ADAPTER(FillD) DECLARE_OP_ADAPTER(FillD)
DECLARE_OP_USE_OUTPUT(FillD) DECLARE_OP_USE_OUTPUT(FillD)
DECLARE_OP_ADAPTER(PadV3)
DECLARE_OP_USE_OUTPUT(PadV3)
} // namespace mindspore::transform } // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_PAD_OPS_DECLARE_H_ #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_PAD_OPS_DECLARE_H_

View File

@ -77,6 +77,7 @@ INPUT_ATTR_MAP(GatherV2D) = {{3, ATTR_DESC(axis, AnyTraits<int64_t>())}};
ATTR_MAP(GatherV2D) = EMPTY_ATTR_MAP; ATTR_MAP(GatherV2D) = EMPTY_ATTR_MAP;
OUTPUT_MAP(GatherV2D) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(GatherV2D) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(GatherV2D, prim::kPrimGather->name(), ADPT_DESC(GatherV2D)) REG_ADPT_DESC(GatherV2D, prim::kPrimGather->name(), ADPT_DESC(GatherV2D))
REG_ADPT_DESC(Gather, kNameGather, ADPT_DESC(GatherV2D))
// ScatterNdD // ScatterNdD
INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}}; INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}};
@ -151,6 +152,17 @@ ATTR_MAP(StridedSlice) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits<int64_t
OUTPUT_MAP(StridedSlice) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(StridedSlice) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(StridedSlice, kNameStridedSlice, ADPT_DESC(StridedSlice)) REG_ADPT_DESC(StridedSlice, kNameStridedSlice, ADPT_DESC(StridedSlice))
// StridedSliceV2
INPUT_MAP(StridedSliceV2) = {
{1, INPUT_DESC(x)}, {2, INPUT_DESC(begin)}, {3, INPUT_DESC(end)}, {4, INPUT_DESC(axes)}, {5, INPUT_DESC(strides)}};
ATTR_MAP(StridedSliceV2) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits<int64_t>())},
{"end_mask", ATTR_DESC(end_mask, AnyTraits<int64_t>())},
{"ellipsis_mask", ATTR_DESC(ellipsis_mask, AnyTraits<int64_t>())},
{"new_axis_mask", ATTR_DESC(new_axis_mask, AnyTraits<int64_t>())},
{"shrink_axis_mask", ATTR_DESC(shrink_axis_mask, AnyTraits<int64_t>())}};
OUTPUT_MAP(StridedSliceV2) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(StridedSliceV2, kNameStridedSliceV2, ADPT_DESC(StridedSliceV2))
// UnsortedSegmentSum // UnsortedSegmentSum
INPUT_MAP(UnsortedSegmentSumD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}}; INPUT_MAP(UnsortedSegmentSumD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}}; INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};

View File

@ -52,6 +52,9 @@ DECLARE_OP_USE_OUTPUT(StridedSliceGrad)
DECLARE_OP_ADAPTER(StridedSlice) DECLARE_OP_ADAPTER(StridedSlice)
DECLARE_OP_USE_OUTPUT(StridedSlice) DECLARE_OP_USE_OUTPUT(StridedSlice)
DECLARE_OP_ADAPTER(StridedSliceV2)
DECLARE_OP_USE_OUTPUT(StridedSliceV2)
DECLARE_OP_ADAPTER(UnsortedSegmentSumD) DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)

View File

@ -219,7 +219,14 @@ void FuncGraph::AddNode(const AnfNodePtr &node) { nodes_.add(node); }
void FuncGraph::DropNode(const AnfNodePtr &node) { void FuncGraph::DropNode(const AnfNodePtr &node) {
nodes_.erase(node); nodes_.erase(node);
if (node == nullptr) {
MS_LOG(ERROR) << "Node is nullptr";
return;
}
auto graph = node->func_graph(); auto graph = node->func_graph();
if (node->isa<Parameter>()) {
parameters_.erase(std::remove(parameters_.begin(), parameters_.end(), node), parameters_.end());
}
// Remove the node from order list. // Remove the node from order list.
if (graph) { if (graph) {
graph->EraseUnusedNodeInOrder(node); graph->EraseUnusedNodeInOrder(node);

View File

@ -16,6 +16,7 @@
#include "ops/fusion/scale_fusion.h" #include "ops/fusion/scale_fusion.h"
#include <string> #include <string>
#include <memory>
#include "ops/op_utils.h" #include "ops/op_utils.h"
namespace mindspore { namespace mindspore {

View File

@ -254,6 +254,10 @@ constexpr auto kSplitDim = "split_dim";
constexpr auto kPadTop = "pad_top"; constexpr auto kPadTop = "pad_top";
constexpr auto kTransFormat = "trans_format"; constexpr auto kTransFormat = "trans_format";
constexpr auto kApproximate = "approximate"; constexpr auto kApproximate = "approximate";
constexpr auto kNumOutput = "num_output";
constexpr auto kUseGlobalStats = "use_global_stats";
constexpr auto kFmkType = "fmk_type";
const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, const std::set<TypePtr> common_valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};

View File

@ -138,7 +138,11 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
{"BinaryCrossEntropyGrad", ReductionMap}, {"BinaryCrossEntropyGrad", ReductionMap},
{"NLLLoss", ReductionMap}, {"NLLLoss", ReductionMap},
{"DepthToSpace", DataFormatMap}, {"DepthToSpace", DataFormatMap},
}; {"Pooling", DataFormatMap},
{"Deconvolution", DataFormatMap},
{"AvgPoolV2", DataFormatMap},
{"MaxPoolV3", DataFormatMap},
{"FusedBatchNorm", DataFormatMap}};
bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) { bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);

View File

@ -41,6 +41,7 @@ option(MSLITE_ENABLE_MINDRT "enable mindrt use" on)
option(MSLITE_DELEGATE_USE "enable delegate use" on) option(MSLITE_DELEGATE_USE "enable delegate use" on)
option(MSLITE_ENABLE_V0 "support v0 schema" on) option(MSLITE_ENABLE_V0 "support v0 schema" on)
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off) option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
option(MSLITE_ENABLE_ACL "enable ACL" off)
#Option that can be configured through manually #Option that can be configured through manually
option(ENABLE_VERBOSE "" off) option(ENABLE_VERBOSE "" off)
@ -119,6 +120,10 @@ if(DEFINED ENV{MSLITE_ENABLE_FP16})
set(MSLITE_ENABLE_FP16 $ENV{MSLITE_ENABLE_FP16}) set(MSLITE_ENABLE_FP16 $ENV{MSLITE_ENABLE_FP16})
endif() endif()
if(DEFINED ENV{MSLITE_ENABLE_ACL})
set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL})
endif()
if(PLATFORM_ARM64) if(PLATFORM_ARM64)
if(MSLITE_GPU_BACKEND STREQUAL "") if(MSLITE_GPU_BACKEND STREQUAL "")
set(MSLITE_GPU_BACKEND "opencl") set(MSLITE_GPU_BACKEND "opencl")
@ -240,6 +245,21 @@ if(ENABLE_ASAN)
add_link_options(-fsanitize=address) add_link_options(-fsanitize=address)
endif() endif()
if(MSLITE_ENABLE_ACL)
set(ENABLE_ACL on)
add_definitions(-D ENABLE_LITE_ACL)
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--allow-shlib-undefined")
if(DEFINED ENV{ASCEND_CUSTOM_PATH})
set(ASCEND_PATH $ENV{ASCEND_CUSTOM_PATH})
else()
set(ASCEND_PATH /usr/local/Ascend)
endif()
set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
set(ASCEND_TOOLKIT_RUNTIME_PATH ${ASCEND_PATH}/ascend-toolkit/latest/fwkacllib/lib64)
endif()
set(PKG_NAME_PREFIX mindspore-lite-${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION}) set(PKG_NAME_PREFIX mindspore-lite-${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
if(SUPPORT_NPU) if(SUPPORT_NPU)

View File

@ -381,9 +381,18 @@ build_aar() {
sha256sum mindspore-lite-maven-${VERSION_STR}.zip > mindspore-lite-maven-${VERSION_STR}.zip.sha256 sha256sum mindspore-lite-maven-${VERSION_STR}.zip > mindspore-lite-maven-${VERSION_STR}.zip.sha256
} }
update_submodule()
{
git submodule update --init graphengine
cd "${BASEPATH}/graphengine"
git submodule update --init metadef
}
LITE_JAVA_PATH=${BASEPATH}/mindspore/lite/java LITE_JAVA_PATH=${BASEPATH}/mindspore/lite/java
LITE_BUILD_TYPE="Release" LITE_BUILD_TYPE="Release"
if [[ "${MSLITE_ENABLE_ACL}" == "on" ]]; then
update_submodule
fi
if [[ "${DEBUG_MODE}" == "on" ]]; then if [[ "${DEBUG_MODE}" == "on" ]]; then
LITE_BUILD_TYPE="Debug" LITE_BUILD_TYPE="Debug"
fi fi

View File

@ -233,6 +233,11 @@ add_subdirectory(runtime/kernel/arm)
add_library(lite_src_mid OBJECT ${LITE_SRC}) add_library(lite_src_mid OBJECT ${LITE_SRC})
add_dependencies(lite_src_mid fbs_src) add_dependencies(lite_src_mid fbs_src)
if(MSLITE_ENABLE_ACL)
add_subdirectory(runtime/kernel/ascend310)
link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
endif()
add_library(mindspore-lite SHARED $<TARGET_OBJECTS:lite_src_mid>) add_library(mindspore-lite SHARED $<TARGET_OBJECTS:lite_src_mid>)
set_target_properties(mindspore-lite PROPERTIES CLEAN_DIRECT_OUTPUT 1) set_target_properties(mindspore-lite PROPERTIES CLEAN_DIRECT_OUTPUT 1)
@ -387,3 +392,8 @@ if(ENABLE_MODEL_OBF)
target_link_libraries(mindspore-lite ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so) target_link_libraries(mindspore-lite ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so)
target_link_libraries(mindspore-lite_static ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so) target_link_libraries(mindspore-lite_static ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so)
endif() endif()
if(MSLITE_ENABLE_ACL)
target_link_libraries(mindspore-lite ascend310_kernel_mid)
target_link_libraries(mindspore-lite_static ascend310_kernel_mid)
endif()

View File

@ -33,6 +33,19 @@ constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency"; constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionProvider = "mindspore.option.provider"; constexpr auto kModelOptionProvider = "mindspore.option.provider";
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device"; constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
constexpr auto kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
constexpr auto kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type";
constexpr auto kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
constexpr auto kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
constexpr auto KModelOptionAscend310FusionSwitchCfgPath = "mindspore.option.ascend310.fusion_switch_config_file_path";
constexpr auto kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
constexpr auto kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";
struct Context::Data { struct Context::Data {
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list; std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
@ -290,101 +303,208 @@ uint32_t Ascend910DeviceInfo::GetDeviceID() const {
return 0; return 0;
} }
void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; } void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310DeviceID] = device_id;
}
uint32_t Ascend310DeviceInfo::GetDeviceID() const { uint32_t Ascend310DeviceInfo::GetDeviceID() const {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
return 0; MS_LOG(ERROR) << "Invalid context.";
return 0;
}
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
} }
void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) { void Ascend310DeviceInfo::SetDumpConfigPath(const std::vector<char> &cfg_path) {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310DumpCfgPath] = CharToString(cfg_path);
} }
std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const { std::vector<char> Ascend310DeviceInfo::GetDumpConfigPathChar() const {
std::vector<char> empty; if (data_ == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Invalid context.";
return empty; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DumpCfgPath);
return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) { void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
} }
std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const { std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
std::vector<char> empty; if (data_ == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Invalid context.";
return empty; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
return StringToChar(ref);
}
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
} }
void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) { MS_LOG(ERROR) << "Unsupported Feature."; }
std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const { std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
std::vector<char> empty; if (data_ == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Invalid context.";
return empty; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) { MS_LOG(ERROR) << "Unsupported Feature."; } void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
}
std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const { std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
std::vector<char> empty; if (data_ == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Invalid context.";
return empty; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) { void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
std::string batchs;
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
if (i != 0) {
batchs.push_back(',');
}
batchs += std::to_string(dynamic_batch_size[i]);
}
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
} }
std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const { std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
std::vector<char> empty; if (data_ == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Invalid context.";
return empty; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) { void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const { std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
std::vector<char> empty; if (data_ == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Invalid context.";
return empty; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) { void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const { std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
std::vector<char> empty; if (data_ == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Invalid context.";
return empty; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) { void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
} }
std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const { std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
std::vector<char> empty; if (data_ == nullptr) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Invalid context.";
return empty; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
return StringToChar(ref);
} }
void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) { void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
} MS_LOG(ERROR) << "Invalid context.";
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const { return;
std::map<int, std::vector<int>> empty; }
MS_LOG(ERROR) << "Unsupported Feature."; data_->params[kModelOptionAscend310InputShapeMap] = shape;
return empty; }
std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::map<int, std::vector<int>>();
}
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
}
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310OutputType] = output_type;
} }
void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) { MS_LOG(ERROR) << "Unsupported Feature."; }
enum DataType Ascend310DeviceInfo::GetOutputType() const { enum DataType Ascend310DeviceInfo::GetOutputType() const {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
return DataType::kTypeUnknown; MS_LOG(ERROR) << "Invalid context.";
return DataType::kTypeUnknown;
}
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
} }
void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) { void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
} }
std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const { std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
MS_LOG(ERROR) << "Unsupported Feature."; if (data_ == nullptr) {
std::vector<char> ret; MS_LOG(ERROR) << "Invalid context.";
return ret; return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
return StringToChar(ref);
} }
} // namespace mindspore } // namespace mindspore

View File

@ -27,31 +27,44 @@
namespace mindspore { namespace mindspore {
class Buffer::Impl { class Buffer::Impl {
public: public:
Impl() : data_() { MS_LOG(ERROR) << "Unsupported feature."; } Impl() : data_() {}
~Impl() = default; ~Impl() = default;
Impl(const void *data, size_t data_len) { MS_LOG(ERROR) << "Unsupported feature."; } Impl(const void *data, size_t data_len) {
if (data != nullptr) {
(void)SetData(data, data_len);
} else {
ResizeData(data_len);
}
}
const void *Data() const { const void *Data() const { return data_.data(); }
MS_LOG(ERROR) << "Unsupported feature."; void *MutableData() { return data_.data(); }
return nullptr; size_t DataSize() const { return data_.size(); }
}
void *MutableData() {
MS_LOG(ERROR) << "Unsupported feature.";
return nullptr;
}
size_t DataSize() const {
MS_LOG(ERROR) << "Unsupported feature.";
return 0;
}
bool ResizeData(size_t data_len) { bool ResizeData(size_t data_len) {
MS_LOG(ERROR) << "Unsupported feature."; data_.resize(data_len);
return false; return true;
} }
bool SetData(const void *data, size_t data_len) { bool SetData(const void *data, size_t data_len) {
MS_LOG(ERROR) << "Unsupported feature."; ResizeData(data_len);
return false; if (DataSize() != data_len) {
MS_LOG(ERROR) << "Set data failed, tensor current data size " << DataSize() << " not match data len " << data_len;
return false;
}
if (data == nullptr) {
return data_len == 0;
}
if (MutableData() == nullptr) {
MS_LOG(ERROR) << "Set data failed, data len " << data_len;
return false;
}
memcpy(MutableData(), data, data_len);
return true;
} }
protected: protected:
@ -343,38 +356,58 @@ void MSTensor::SetQuantParams(std::vector<QuantParam> quant_params) {
return impl_->SetQuantParams(quant_params); return impl_->SetQuantParams(quant_params);
} }
Buffer::Buffer() : impl_(nullptr) { MS_LOG(ERROR) << "Unsupported feature."; } Buffer::Buffer() : impl_(std::make_shared<Impl>()) {}
Buffer::Buffer(const void *data, size_t data_len) : impl_(nullptr) { MS_LOG(ERROR) << "Unsupported feature."; } Buffer::Buffer(const void *data, size_t data_len) : impl_(std::make_shared<Impl>(data, data_len)) {}
Buffer::~Buffer() = default; Buffer::~Buffer() = default;
Buffer Buffer::Clone() const { Buffer Buffer::Clone() const {
MS_LOG(ERROR) << "Unsupported feature."; Buffer ret;
return Buffer(); if (impl_ == nullptr) {
MS_LOG(ERROR) << "impl is nullptr.";
return ret;
}
ret.impl_ = std::make_shared<Impl>(*impl_);
return ret;
} }
const void *Buffer::Data() const { const void *Buffer::Data() const {
MS_LOG(ERROR) << "Unsupported feature."; if (impl_ == nullptr) {
return nullptr; MS_LOG(ERROR) << "impl is nullptr.";
return nullptr;
}
return impl_->Data();
} }
void *Buffer::MutableData() { void *Buffer::MutableData() {
MS_LOG(ERROR) << "Unsupported feature."; if (impl_ == nullptr) {
return nullptr; MS_LOG(ERROR) << "impl is nullptr.";
return nullptr;
}
return impl_->MutableData();
} }
size_t Buffer::DataSize() const { size_t Buffer::DataSize() const {
MS_LOG(ERROR) << "Unsupported feature."; if (impl_ == nullptr) {
return 0; MS_LOG(ERROR) << "impl is nullptr.";
return 0;
}
return impl_->DataSize();
} }
bool Buffer::ResizeData(size_t data_len) { bool Buffer::ResizeData(size_t data_len) {
MS_LOG(ERROR) << "Unsupported feature."; if (impl_ == nullptr) {
return false; MS_LOG(ERROR) << "impl is nullptr.";
return false;
}
return impl_->ResizeData(data_len);
} }
bool Buffer::SetData(const void *data, size_t data_len) { bool Buffer::SetData(const void *data, size_t data_len) {
MS_LOG(ERROR) << "Unsupported feature."; if (impl_ == nullptr) {
return false; MS_LOG(ERROR) << "impl is nullptr.";
return false;
}
return impl_->SetData(data, data_len);
} }
std::vector<char> CharVersion() { return StringToChar(lite::Version()); } std::vector<char> CharVersion() { return StringToChar(lite::Version()); }

View File

@ -0,0 +1,12 @@
include_directories(${TOP_DIR}/graphengine/inc/external)
find_library(ge_graph libgraph.so ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
aux_source_directory(src ACL_SRC)
add_library(ascend310_kernel_mid OBJECT ${ACL_SRC})
add_dependencies(ascend310_kernel_mid fbs_inner_src)
target_link_libraries(ascend310_kernel_mid ${ge_graph} ${ge_compiler}
${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} ${libplatform}
${libcompress} ${libopskernel} ${libaicore_utils} ${libaicpu_engine_common} ${acl})

View File

@ -0,0 +1,66 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/ascend310/src/acl_env_guard.h"
#include "common/log_adapter.h"
#include "acl/acl.h"
namespace mindspore {
namespace acl {
std::shared_ptr<AclEnvGuard> AclEnvGuard::global_acl_env_ = nullptr;
std::mutex AclEnvGuard::global_acl_env_mutex_;
AclEnvGuard::AclEnvGuard(std::string_view cfg_file) : errno_(ACL_ERROR_NONE) {
errno_ = aclInit(cfg_file.data());
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) {
MS_LOG(ERROR) << "Execute aclInit Failed";
return;
}
MS_LOG(INFO) << "Acl init success";
}
AclEnvGuard::~AclEnvGuard() {
errno_ = aclFinalize();
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_FINALIZE) {
MS_LOG(ERROR) << "Finalize acl failed";
}
MS_LOG(INFO) << "Acl finalize success";
}
std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
std::shared_ptr<AclEnvGuard> acl_env;
std::lock_guard<std::mutex> lock(global_acl_env_mutex_);
acl_env = global_acl_env_;
if (acl_env != nullptr) {
MS_LOG(INFO) << "Acl has been initialized, skip.";
if (!cfg_file.empty()) {
MS_LOG(WARNING) << "Dump config file option " << cfg_file << " is ignored.";
}
} else {
acl_env = std::make_shared<AclEnvGuard>(cfg_file);
aclError ret = acl_env->GetErrno();
if (ret != ACL_ERROR_NONE && ret != ACL_ERROR_REPEAT_INITIALIZE) {
MS_LOG(ERROR) << "Execute aclInit Failed";
return nullptr;
}
global_acl_env_ = acl_env;
MS_LOG(INFO) << "Acl init success";
}
return acl_env;
}
} // namespace acl
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_ACL_ENV_GUARD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_ACL_ENV_GUARD_H_
#include <memory>
#include <mutex>
#include "acl/acl_base.h"
namespace mindspore {
namespace acl {
class __attribute__((visibility("default"))) AclEnvGuard {
public:
explicit AclEnvGuard(std::string_view cfg_file);
~AclEnvGuard();
aclError GetErrno() const { return errno_; }
static std::shared_ptr<AclEnvGuard> GetAclEnv(std::string_view cfg_file);
private:
static std::shared_ptr<AclEnvGuard> global_acl_env_;
static std::mutex global_acl_env_mutex_;
aclError errno_;
};
} // namespace acl
} // namespace mindspore
#endif // LITE_ACL_ENV_GUARD_H

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/ascend310/src/custom_interface.h"
#include <memory>
#include "include/errorcode.h"
#include "include/registry/register_kernel_interface.h"
#include "common/log_adapter.h"
namespace mindspore {
namespace acl {
Status CustomInterface::Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const mindspore::schema::Primitive *primitive) {
if (inputs == nullptr || (*inputs).empty()) {
MS_LOG(ERROR) << "Inputs is invalid.";
return kLiteError;
}
if (outputs == nullptr || (*outputs).empty()) {
MS_LOG(ERROR) << "Outputs is invalid.";
return kLiteError;
}
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr.";
return kLiteError;
}
if (primitive->value_type() != schema::PrimitiveType_Custom) {
MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom.";
return kLiteError;
}
return kSuccess;
}
std::shared_ptr<mindspore::kernel::KernelInterface> CustomInferCreater() {
auto infer = new (std::nothrow) CustomInterface();
if (infer == nullptr) {
MS_LOG(ERROR) << "New custom infer is nullptr";
return nullptr;
}
return std::shared_ptr<mindspore::kernel::KernelInterface>(infer);
}
} // namespace acl
} // namespace mindspore
namespace mindspore {
namespace kernel {
REGISTER_CUSTOM_KERNEL_INTERFACE(ACL, ACL, acl::CustomInferCreater);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_ACL_CUSTOM_INTERFACE_H_
#define MINDSPORE_LITE_ACL_CUSTOM_INTERFACE_H_
#include <vector>
#include "include/kernel_interface.h"
namespace mindspore {
namespace acl {
class CustomInterface : public mindspore::kernel::KernelInterface {
public:
CustomInterface() {}
~CustomInterface() = default;
Status Infer(std::vector<mindspore::MSTensor> *inputs, std::vector<mindspore::MSTensor> *outputs,
const mindspore::schema::Primitive *primitive) override;
};
} // namespace acl
} // namespace mindspore
#endif // MINDSPORE_LITE_ACL_CUSTOM_INTERFACE_H_

View File

@ -0,0 +1,134 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/ascend310/src/custom_kernel.h"
#include "include/registry/register_kernel.h"
#include "include/api/types.h"
#include "include/api/data_type.h"
#include "src/runtime/kernel/ascend310/src/model_infer.h"
#include "common/log_adapter.h"
namespace mindspore {
namespace acl {
CustomAscend310Kernel::CustomAscend310Kernel(const std::vector<mindspore::MSTensor> &inputs,
const std::vector<mindspore::MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx)
: Kernel(inputs, outputs, primitive, ctx), load_model_(false), model_infer_(nullptr) {}
CustomAscend310Kernel::~CustomAscend310Kernel() {
if (load_model_) {
int ret = model_infer_->Finalize();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Model finalize failed.";
}
}
}
STATUS CustomAscend310Kernel::PrepareModelInfer() {
if (inputs_.size() < 1) {
MS_LOG(ERROR) << "Inputs size should not less than 1.";
return lite::RET_ERROR;
}
// last input is om data tensor
int idx = inputs_.size() - 1;
Buffer om_data(inputs_[idx].Data().get(), inputs_[idx].DataSize());
if (model_infer_ == nullptr) {
model_infer_.reset(new ModelInfer(om_data, 0));
}
int ret = model_infer_->Init();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Model infer init failed.";
return lite::RET_ERROR;
}
ret = model_infer_->Load();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Load om data failed.";
return lite::RET_ERROR;
}
MS_LOG(INFO) << "Load om data success.";
return lite::RET_OK;
}
STATUS CustomAscend310Kernel::Prepare() {
if (load_model_) {
MS_LOG(INFO) << "Custom kernel has been prepared.";
return lite::RET_OK;
}
if (PrepareModelInfer() != lite::RET_OK) {
MS_LOG(ERROR) << "Model infer prepare is not ok.";
return lite::RET_ERROR;
}
load_model_ = true;
return lite::RET_OK;
}
STATUS CustomAscend310Kernel::ReSize() {
if (load_model_) {
int ret = model_infer_->Finalize();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Model finalize failed.";
}
load_model_ = false;
}
return Prepare();
}
STATUS CustomAscend310Kernel::Execute() {
if (!load_model_) {
MS_LOG(WARNING) << "Custom kernel has not been prepared.";
return lite::RET_OK;
}
std::vector<mindspore::MSTensor> inputs(inputs_.begin(), inputs_.end() - 1);
if (model_infer_->Inference(inputs, &outputs_) != lite::RET_OK) {
MS_LOG(ERROR) << "Custom kernel execute failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
std::shared_ptr<kernel::Kernel> CustomCreateKernel(const std::vector<mindspore::MSTensor> &inputs,
const std::vector<mindspore::MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx) {
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr.";
return nullptr;
}
if (primitive->value_type() != schema::PrimitiveType_Custom) {
MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom";
return nullptr;
}
auto kernel = std::make_shared<CustomAscend310Kernel>(inputs, outputs, primitive, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "New custom kernel is nullptr";
return nullptr;
}
return kernel;
}
} // namespace acl
} // namespace mindspore
namespace mindspore {
namespace registry {
namespace {
const auto kFloat32 = DataType::kNumberTypeFloat32;
const auto kInt8 = DataType::kNumberTypeInt8;
const auto kUInt8 = DataType::kNumberTypeUInt8;
} // namespace
REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kFloat32, ACL, acl::CustomCreateKernel)
REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kInt8, ACL, acl::CustomCreateKernel)
REGISTER_CUSTOM_KERNEL(ASCEND310, ACL, kUInt8, ACL, acl::CustomCreateKernel)
} // namespace registry
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND310_KERNEL_CUSTOM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND310_KERNEL_CUSTOM_H_
#include <vector>
#include <memory>
#include "include/api/types.h"
#include "include/api/context.h"
#include "include/api/kernel.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/ascend310/src/model_infer.h"
using mindspore::lite::STATUS;
namespace mindspore {
namespace acl {
class CustomAscend310Kernel : public kernel::Kernel {
public:
CustomAscend310Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
const mindspore::schema::Primitive *primitive, const mindspore::Context *ctx);
~CustomAscend310Kernel() override;
STATUS Prepare() override;
STATUS ReSize() override;
STATUS Execute() override;
private:
STATUS PrepareModelInfer();
bool load_model_;
std::shared_ptr<ModelInfer> model_infer_;
};
} // namespace acl
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND310_FP32_CUSTOM_H_

View File

@ -0,0 +1,162 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/ascend310/src/model_infer.h"
#include "common/log_adapter.h"
#include "acl/acl.h"
namespace mindspore {
namespace acl {
ModelInfer::ModelInfer(const Buffer &om_data, int32_t device_id)
: init_flag_(false),
load_flag_(false),
device_type_("AscendCL"),
device_id_(device_id),
context_(nullptr),
om_data_(om_data),
model_process_(),
acl_env_(nullptr) {}
STATUS ModelInfer::Init() {
if (init_flag_) {
MS_LOG(INFO) << "Acl has been initialized, skip.";
return lite::RET_OK;
}
acl_env_ = AclEnvGuard::GetAclEnv("");
if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Acl init failed.";
return lite::RET_ERROR;
}
aclError ret = aclrtSetDevice(device_id_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl open device " << device_id_ << " failed.";
return lite::RET_ERROR;
}
MS_LOG(INFO) << "Open device " << device_id_ << " success.";
ret = aclrtCreateContext(&context_, device_id_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl create context failed.";
return lite::RET_ERROR;
}
MS_LOG(INFO) << "Create context success.";
aclrtRunMode run_mode;
ret = aclrtGetRunMode(&run_mode);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl get run mode failed.";
return lite::RET_ERROR;
}
bool is_device = (run_mode == ACL_DEVICE);
model_process_.SetIsDevice(is_device);
MS_LOG(INFO) << "Get run mode success is device input/output " << is_device;
MS_LOG(INFO) << "Init acl success, device id " << device_id_;
init_flag_ = true;
return lite::RET_OK;
}
STATUS ModelInfer::Finalize() {
if (!init_flag_) {
MS_LOG(WARNING) << "Init is not ok, no need to finalize.";
return lite::RET_OK;
}
aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed.";
return lite::RET_ERROR;
}
int ret = model_process_.UnLoad();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Unload model inner failed.";
return ret;
}
if (context_ != nullptr) {
rt_ret = aclrtDestroyContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Destroy context failed.";
}
context_ = nullptr;
}
MS_LOG(INFO) << "End to destroy context.";
rt_ret = aclrtResetDevice(device_id_);
if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Reset device " << device_id_ << " failed.";
}
MS_LOG(INFO) << "End to reset device " << device_id_;
init_flag_ = false;
return lite::RET_OK;
}
STATUS ModelInfer::Load() {
if (!load_flag_) {
int ret = LoadAclModel(om_data_);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Load acl model failed.";
return ret;
}
load_flag_ = true;
}
aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed, ret = " << rt_ret;
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS ModelInfer::LoadAclModel(const Buffer &om_data) {
MS_LOG(INFO) << "Start load acl model.";
// acl load model
uint32_t acl_model_id;
auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed, ret = " << acl_ret;
return lite::RET_ERROR;
}
// acl init model resource
model_process_.set_model_id(acl_model_id);
int ret = model_process_.PreInitModelResource();
if (ret != lite::RET_OK) {
(void)aclmdlUnload(acl_model_id);
MS_LOG(ERROR) << "Pre init model resource failed.";
return ret;
}
MS_LOG(INFO) << "Load acl model success.";
return lite::RET_OK;
}
STATUS ModelInfer::Inference(const std::vector<mindspore::MSTensor> &inputs,
std::vector<mindspore::MSTensor> *outputs) {
if (Load() != lite::RET_OK) {
MS_LOG(ERROR) << "Prepare model resource failed.";
return lite::RET_ERROR;
}
return model_process_.PredictFromHost(inputs, outputs);
}
} // namespace acl
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_INFER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_INFER_H_
#include <vector>
#include <memory>
#include <string>
#include "src/runtime/kernel/ascend310/src/model_process.h"
#include "src/runtime/kernel/ascend310/src/acl_env_guard.h"
#include "include/api/types.h"
#include "include/errorcode.h"
using mindspore::lite::STATUS;
namespace mindspore {
namespace acl {
class ModelInfer {
public:
ModelInfer(const Buffer &om_data, int32_t device_id);
~ModelInfer() = default;
STATUS Init();
STATUS Finalize();
STATUS Load();
STATUS Inference(const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs);
private:
STATUS LoadAclModel(const Buffer &om_data);
bool init_flag_;
bool load_flag_;
std::string device_type_;
int32_t device_id_;
aclrtContext context_;
Buffer om_data_;
ModelProcess model_process_;
std::shared_ptr<AclEnvGuard> acl_env_;
};
} // namespace acl
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_INFER_H_

View File

@ -0,0 +1,576 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/ascend310/src/model_process.h"
#include <utility>
#include <algorithm>
#include <map>
#include "common/log_adapter.h"
namespace mindspore {
namespace acl {
namespace {
constexpr size_t kDynamicBatchSize = 1;
constexpr size_t kDynamicImageSize = 2;
} // namespace
static DataType TransToDataType(aclDataType data_type) {
static const std::map<aclDataType, enum DataType> data_type_map = {
{ACL_FLOAT16, DataType::kNumberTypeFloat16}, {ACL_FLOAT, DataType::kNumberTypeFloat32},
{ACL_DOUBLE, DataType::kNumberTypeFloat64}, {ACL_INT8, DataType::kNumberTypeInt8},
{ACL_INT16, DataType::kNumberTypeInt16}, {ACL_INT32, DataType::kNumberTypeInt32},
{ACL_INT64, DataType::kNumberTypeInt64}, {ACL_UINT8, DataType::kNumberTypeUInt8},
{ACL_UINT16, DataType::kNumberTypeUInt16}, {ACL_UINT32, DataType::kNumberTypeUInt32},
{ACL_UINT64, DataType::kNumberTypeUInt64}, {ACL_BOOL, DataType::kNumberTypeBool},
};
auto it = data_type_map.find(data_type);
if (it == data_type_map.end()) {
return DataType::kNumberTypeEnd;
} else {
return it->second;
}
}
template <class T>
inline static void ClearIfNotNull(T *vec) {
if (vec != nullptr) {
vec->clear();
}
}
template <class T, class U = std::vector<T>>
inline static void PushbackIfNotNull(U *vec, T &&item) {
if (vec != nullptr) {
vec->emplace_back(item);
}
}
static STATUS ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names,
std::vector<std::vector<int64_t>> *shapes, std::vector<enum DataType> *data_types,
std::vector<size_t> *mem_sizes) {
ClearIfNotNull(names);
ClearIfNotNull(shapes);
ClearIfNotNull(data_types);
ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
const auto &info = acl_tensor_list[i];
PushbackIfNotNull(names, info.name);
PushbackIfNotNull(shapes, info.dims);
PushbackIfNotNull(data_types, TransToDataType(info.data_type));
PushbackIfNotNull(mem_sizes, info.buffer_size);
}
if (names->size() != acl_tensor_list.size() || shapes->size() != acl_tensor_list.size() ||
data_types->size() != acl_tensor_list.size() || mem_sizes->size() != acl_tensor_list.size()) {
MS_LOG(ERROR) << "Inner error, size do not match: names size " << names->size() << " shapes size " << shapes->size()
<< " data types size " << data_types->size() << " mem sizes size " << mem_sizes->size()
<< " acl_tensor_list size " << acl_tensor_list.size();
return lite::RET_ERROR;
}
return lite::RET_OK;
}
static std::string ShapeToString(const std::vector<int64_t> &shape) {
std::string result = "[";
for (size_t i = 0; i < shape.size(); ++i) {
result += std::to_string(shape[i]);
if (i + 1 < shape.size()) {
result += ", ";
}
}
result += "]";
return result;
}
STATUS ModelProcess::PreInitModelResource() {
model_desc_ = aclmdlCreateDesc();
aclError acl_ret = aclmdlGetDesc(model_desc_, model_id_);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Read model desc failed, ret = " << acl_ret;
return lite::RET_ERROR;
}
STATUS ret = InitInputsBuffer();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Create input buffer failed.";
return ret;
}
ret = InitOutputsBuffer();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Create output buffer failed.";
return ret;
}
return lite::RET_OK;
}
STATUS ModelProcess::InitInputsBuffer() {
aclError ret;
size_t input_size = aclmdlGetNumInputs(model_desc_);
MS_LOG(INFO) << "input_size = " << input_size;
for (size_t i = 0; i < input_size; ++i) {
auto buffer_size = aclmdlGetInputSizeByIndex(model_desc_, i);
void *data_mem_buffer = nullptr;
if (!is_run_on_device_) { // need to copy input/output to/from device
ret = aclrtMalloc(&data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device input buffer failed , input size " << buffer_size;
return lite::RET_ERROR;
}
}
aclmdlIODims dims;
ret = aclmdlGetInputDims(model_desc_, i, &dims);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Get input shape failed, ret = " << ret;
if (!is_run_on_device_) {
aclrtFree(data_mem_buffer);
}
return lite::RET_ERROR;
}
aclDataType data_type = aclmdlGetInputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
std::string input_name = aclmdlGetInputNameByIndex(model_desc_, i);
if (input_name.empty()) {
MS_LOG(WARNING) << "Get name of input " << i << " failed.";
}
MS_LOG(INFO) << "Name of input " << i << " is " << input_name;
input_infos_.emplace_back(
AclTensorInfo{data_mem_buffer, data_mem_buffer, buffer_size, data_type, shape, input_name});
}
MS_LOG(INFO) << "Create model inputs success";
return lite::RET_OK;
}
STATUS ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) {
if (data_mem_buffer == nullptr) {
MS_LOG(ERROR) << "Data mem buffer is nullptr.";
return lite::RET_ERROR;
}
aclError ret;
auto free_data_buffer = [this](void *dataMemBuffer) {
if (!is_run_on_device_) {
aclrtFree(dataMemBuffer);
} else {
aclrtFreeHost(dataMemBuffer);
}
};
if (!is_run_on_device_) {
ret = aclrtMalloc(data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size;
return lite::RET_ERROR;
}
} else {
ret = aclrtMallocHost(data_mem_buffer, buffer_size);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Malloc host buffer failed , buffer size " << buffer_size;
return lite::RET_ERROR;
}
}
auto data_buffer = aclCreateDataBuffer(*data_mem_buffer, buffer_size);
if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed";
free_data_buffer(*data_mem_buffer);
return lite::RET_ERROR;
}
ret = aclmdlAddDatasetBuffer(dataset, data_buffer);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "add data buffer failed";
free_data_buffer(*data_mem_buffer);
aclDestroyDataBuffer(data_buffer);
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS ModelProcess::InitOutputsBuffer() {
aclError ret;
outputs_ = aclmdlCreateDataset();
if (outputs_ == nullptr) {
MS_LOG(ERROR) << "Create input dataset failed";
return lite::RET_ERROR;
}
size_t output_size = aclmdlGetNumOutputs(model_desc_);
MS_LOG(INFO) << "Output_size = " << output_size;
for (size_t i = 0; i < output_size; ++i) {
auto buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i);
void *data_mem_buffer = nullptr;
if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != lite::RET_OK) {
MS_LOG(ERROR) << "Add output data buffer failed, buffer size " << buffer_size;
return lite::RET_ERROR;
}
aclmdlIODims dims;
ret = aclmdlGetOutputDims(model_desc_, i, &dims);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Get input shape failed";
if (!is_run_on_device_) {
aclrtFree(data_mem_buffer);
} else {
aclrtFreeHost(data_mem_buffer);
}
return lite::RET_OK;
}
aclFormat format = aclmdlGetOutputFormat(model_desc_, i);
if (format != aclFormat::ACL_FORMAT_NCHW) {
MS_LOG(WARNING) << "The output format of om should be nchw, but now is " << format;
}
aclDataType data_type = aclmdlGetOutputDataType(model_desc_, i);
std::vector<int64_t> shape(dims.dims, dims.dims + dims.dimCount);
std::string output_name = aclmdlGetOutputNameByIndex(model_desc_, i);
if (output_name.empty()) {
MS_LOG(WARNING) << "Get name of output " << i << " failed.";
}
MS_LOG(INFO) << "Name of input " << i << " is " << output_name;
output_infos_.emplace_back(
AclTensorInfo{data_mem_buffer, data_mem_buffer, buffer_size, data_type, shape, output_name});
}
MS_LOG(INFO) << "Create model output success.";
return lite::RET_OK;
}
void ModelProcess::DestroyInputsDataset() {
if (inputs_ == nullptr) {
return;
}
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(inputs_); i++) {
auto dataBuffer = aclmdlGetDatasetBuffer(inputs_, i);
aclDestroyDataBuffer(dataBuffer);
}
aclmdlDestroyDataset(inputs_);
inputs_ = nullptr;
}
void ModelProcess::DestroyInputsDataMem() {
if (!is_run_on_device_) {
for (const auto &item : input_infos_) {
aclrtFree(item.device_data);
}
}
input_infos_.clear();
}
void ModelProcess::DestroyInputsBuffer() {
DestroyInputsDataMem();
DestroyInputsDataset();
}
void ModelProcess::DestroyOutputsBuffer() {
for (const auto &item : output_infos_) {
if (!is_run_on_device_) {
aclrtFree(item.device_data);
} else {
aclrtFreeHost(item.device_data);
}
}
output_infos_.clear();
if (outputs_ == nullptr) {
return;
}
for (size_t i = 0; i < aclmdlGetDatasetNumBuffers(outputs_); i++) {
auto dataBuffer = aclmdlGetDatasetBuffer(outputs_, i);
aclDestroyDataBuffer(dataBuffer);
}
aclmdlDestroyDataset(outputs_);
outputs_ = nullptr;
}
STATUS ModelProcess::UnLoad() {
auto ret = aclmdlUnload(model_id_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Unload model failed, ret = " << ret;
return lite::RET_ERROR;
}
if (model_desc_ != nullptr) {
ret = aclmdlDestroyDesc(model_desc_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Unload model failed, ret = " << ret;
return lite::RET_ERROR;
}
model_desc_ = nullptr;
}
DestroyInputsBuffer();
DestroyOutputsBuffer();
MS_LOG(INFO) << "End unload model " << model_id_;
return lite::RET_OK;
}
size_t ModelProcess::GetDynamicDims(const std::vector<AclTensorInfo> &inputs) {
size_t max_num = 0;
for (auto input : inputs) {
size_t cur_num = std::count(input.dims.begin(), input.dims.end(), -1);
if (cur_num > max_num) {
max_num = cur_num;
}
}
return max_num;
}
STATUS ModelProcess::SetBatchSize(const std::vector<mindspore::MSTensor> &inputs) {
size_t index;
aclError ret;
for (size_t i = 0; i < inputs.size(); i++) {
input_infos_[i].buffer_size = inputs[i].DataSize();
}
auto *p = reinterpret_cast<const float *>(inputs[inputs.size() - 1].Data().get());
if (p == nullptr) {
MS_LOG(ERROR) << "Pointer is nullptr.";
return lite::RET_OK;
}
auto dynamicBatchSize = p[0];
ret = aclmdlGetInputIndexByName(model_desc_, ACL_DYNAMIC_TENSOR_NAME, &index);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Get index failed";
return lite::RET_ERROR;
}
ret = aclmdlSetDynamicBatchSize(model_id_, inputs_, index, dynamicBatchSize);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set dynamic batch size failed, model_id is " << model_id_;
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS ModelProcess::CheckTensorByTensorInfo(const std::vector<mindspore::MSTensor> &tensor,
const std::vector<AclTensorInfo> &tensor_info, size_t dynamic_nums) {
if (dynamic_nums == 0) {
for (size_t i = 0; i < tensor_info.size(); ++i) {
if (tensor[i].Shape() != tensor_info[i].dims) {
MS_LOG(ERROR) << "Note: input " << i << " shape not match, required " << ShapeToString(tensor_info[i].dims)
<< ", given " << ShapeToString(tensor[i].Shape());
return lite::RET_ERROR;
}
if (tensor[i].DataType() != TransToDataType(tensor_info[i].data_type)) {
MS_LOG(ERROR) << "Note: input " << i << " data type not match, required "
<< static_cast<int>(TransToDataType(tensor_info[i].data_type)) << ", given "
<< static_cast<int>(tensor[i].DataType());
return lite::RET_ERROR;
}
if (tensor[i].DataSize() != tensor_info[i].buffer_size) {
MS_LOG(ERROR) << "Input " << i << " data size not match, required size " << tensor_info[i].buffer_size
<< ", given count " << tensor[i].DataSize();
return lite::RET_ERROR;
}
}
}
return lite::RET_OK;
}
STATUS ModelProcess::ProcDynamicShape(const std::vector<mindspore::MSTensor> &inputs, size_t dynamic_nums) {
if (dynamic_nums == kDynamicBatchSize) {
if (SetBatchSize(inputs) != lite::RET_OK) {
MS_LOG(ERROR) << "Failed to convert dynamic batch size";
return lite::RET_ERROR;
}
if (ResetOutputSize() != lite::RET_OK) {
MS_LOG(ERROR) << "Reset output size failed";
return lite::RET_ERROR;
}
} else if (dynamic_nums == kDynamicImageSize) {
MS_LOG(ERROR) << "Only dynamic batch size is supported";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS ModelProcess::CheckAndInitInput(const std::vector<mindspore::MSTensor> &inputs) {
aclError ret;
inputs_ = aclmdlCreateDataset();
size_t dynamic_nums = GetDynamicDims(input_infos_);
// check inputs
if (CheckTensorByTensorInfo(inputs, input_infos_, dynamic_nums) != lite::RET_OK) {
MS_LOG(ERROR) << "Check input tensor failed.";
return lite::RET_ERROR;
}
// copy inputs
for (size_t i = 0; i < input_infos_.size(); ++i) {
auto &info = input_infos_[i];
auto input = inputs[i];
void *data = input.MutableData();
void *input_buffer = nullptr;
if (!is_run_on_device_) {
info.cur_device_data = info.device_data;
ret = aclrtMemcpy(info.cur_device_data, info.buffer_size, data, input.DataSize(), ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize();
return lite::RET_ERROR;
}
input_buffer = info.cur_device_data;
} else {
input_buffer = data;
}
auto data_buffer = aclCreateDataBuffer(input_buffer, info.buffer_size);
if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed";
return lite::RET_ERROR;
}
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Add data buffer failed";
aclDestroyDataBuffer(data_buffer);
return lite::RET_ERROR;
}
}
if (ProcDynamicShape(inputs, dynamic_nums) != lite::RET_OK) {
MS_LOG(ERROR) << "Proc input dynamic shape failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS ModelProcess::ResetOutputSize() {
aclDataType output_type;
aclError ret;
size_t output_size = aclmdlGetNumOutputs(model_desc_);
for (size_t index = 0; index < output_size; index++) {
size_t dims = 1;
struct aclmdlIODims output_dims;
ret = aclmdlGetCurOutputDims(model_desc_, index, &output_dims);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "get output dim error.";
return lite::RET_ERROR;
}
for (size_t i = 0; i < output_dims.dimCount; i++) {
dims *= output_dims.dims[i];
}
output_type = aclmdlGetOutputDataType(model_desc_, index);
output_infos_[index].buffer_size = dims * aclDataTypeSize(output_type);
}
return lite::RET_OK;
}
STATUS ModelProcess::SortTensorInfoByName(const std::vector<mindspore::MSTensor> &tensor,
std::vector<AclTensorInfo> *tensor_info) {
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Tensor info is nullptr.";
return lite::RET_ERROR;
}
if (tensor.size() != tensor_info->size()) {
MS_LOG(ERROR) << "Actual tensor count not match, required count " << tensor_info->size() << ", given count "
<< tensor.size();
return lite::RET_ERROR;
}
size_t size = tensor.size();
for (size_t i = 0; i < size; i++) {
std::string name = tensor[i].Name();
size_t j;
for (j = 0; j < size; j++) {
if (name.find((*tensor_info)[j].name) != std::string::npos) {
std::swap((*tensor_info)[i], (*tensor_info)[j]);
break;
}
}
if (j == size) {
MS_LOG(ERROR) << "Input[" << i << "] " << name << " can't be found in acl om.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
STATUS ModelProcess::PredictFromHost(const std::vector<mindspore::MSTensor> &inputs,
std::vector<mindspore::MSTensor> *outputs) {
if (SortTensorInfoByName(inputs, &input_infos_) != lite::RET_OK) {
MS_LOG(ERROR) << "Sort input tensor info failed.";
return lite::RET_ERROR;
}
STATUS ret = CheckAndInitInput(inputs);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Check or init input failed";
DestroyInputsDataset();
return ret; // forward status error
}
aclError acl_ret = aclmdlExecute(model_id_, inputs_, outputs_);
DestroyInputsDataset();
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Execute Model Failed, ret = " << acl_ret;
return lite::RET_ERROR;
}
ret = GetOutputs(outputs);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Build outputs failed";
return ret;
}
MS_LOG(INFO) << "Execute model success";
return lite::RET_OK;
}
STATUS ModelProcess::GetOutputs(std::vector<mindspore::MSTensor> *outputs) {
if (outputs == nullptr) {
MS_LOG(ERROR) << "Ms tensor output is nullptr.";
return lite::RET_ERROR;
}
if (ConstructTensor(outputs) != lite::RET_OK) {
MS_LOG(ERROR) << "Construct ms tensor failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS ModelProcess::ConstructTensor(std::vector<mindspore::MSTensor> *outputs) {
if (outputs == nullptr) {
MS_LOG(ERROR) << "Ms tensor output is nullptr.";
return lite::RET_ERROR;
}
if (outputs->size() != output_infos_.size()) {
MS_LOG(ERROR) << "Actual tensor count not match, required count " << output_infos_.size() << ", given count "
<< outputs->size();
return lite::RET_ERROR;
}
std::vector<std::string> names;
std::vector<std::vector<int64_t>> shapes;
std::vector<enum DataType> data_types;
std::vector<size_t> mem_sizes;
if (ConstructTensorDesc(output_infos_, &names, &shapes, &data_types, &mem_sizes) != lite::RET_OK) {
MS_LOG(ERROR) << "Construct tensor desc failed.";
return lite::RET_ERROR;
}
// set output info and malloc data size
for (size_t i = 0; i < output_infos_.size(); ++i) {
std::string lite_output_name = (*outputs)[i].Name();
if (lite_output_name != names[i]) {
MS_LOG(INFO) << "Lite output name: " << lite_output_name << "; Om output name: " << names[i];
}
(*outputs)[i].SetFormat(Format::NCHW);
(*outputs)[i].SetDataType(data_types[i]);
(*outputs)[i].SetShape(shapes[i]);
(*outputs)[i].MutableData();
if ((*outputs)[i].DataSize() != mem_sizes[i]) {
MS_LOG(ERROR) << "Ms tensor size " << (*outputs)[i].DataSize() << " not match acl tensor size " << mem_sizes[i];
return lite::RET_ERROR;
}
}
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
for (size_t i = 0; i < output_infos_.size(); ++i) {
if (output_infos_[i].cur_device_data == nullptr) {
// when run on device, cur_device_data is nullptr before first execute
continue;
}
auto ret = aclrtMemcpy((*outputs)[i].MutableData(), (*outputs)[i].DataSize(), output_infos_[i].cur_device_data,
output_infos_[i].buffer_size, kind);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Memcpy input " << i << " from " << (is_run_on_device_ ? "host" : "device")
<< " to host failed, memory size " << output_infos_[i].buffer_size;
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
} // namespace acl
} // namespace mindspore

View File

@ -0,0 +1,95 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_PROCESS_H_
#define MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_PROCESS_H_
#include <vector>
#include <string>
#include <map>
#include "acl/acl.h"
#include "acl/acl_mdl.h"
#include "acl/acl_rt.h"
#include "include/api/types.h"
#include "include/errorcode.h"
using mindspore::lite::STATUS;
namespace mindspore {
namespace acl {
struct AclTensorInfo {
void *cur_device_data;
void *device_data;
size_t buffer_size;
aclDataType data_type;
std::vector<int64_t> dims;
std::string name;
};
class ModelProcess {
public:
ModelProcess()
: model_id_(0xffffffff),
is_run_on_device_(false),
model_desc_(nullptr),
inputs_(nullptr),
outputs_(nullptr),
input_infos_(),
output_infos_() {}
~ModelProcess() {}
STATUS UnLoad();
STATUS PredictFromHost(const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs);
STATUS PreInitModelResource();
// override this method to avoid request/reply data copy
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
void set_model_id(uint32_t model_id) { model_id_ = model_id; }
uint32_t model_id() const { return model_id_; }
private:
STATUS CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
STATUS CheckAndInitInput(const std::vector<mindspore::MSTensor> &inputs);
STATUS SortTensorInfoByName(const std::vector<mindspore::MSTensor> &tensor, std::vector<AclTensorInfo> *tensor_info);
STATUS CheckTensorByTensorInfo(const std::vector<mindspore::MSTensor> &tensor,
const std::vector<AclTensorInfo> &tensor_info, size_t dynamic_nums);
STATUS GetOutputs(std::vector<mindspore::MSTensor> *outputs);
STATUS ConstructTensor(std::vector<mindspore::MSTensor> *outputs);
STATUS SetBatchSize(const std::vector<mindspore::MSTensor> &inputs);
STATUS InitInputsBuffer();
STATUS InitOutputsBuffer();
STATUS ResetOutputSize();
size_t GetDynamicDims(const std::vector<AclTensorInfo> &);
STATUS ProcDynamicShape(const std::vector<mindspore::MSTensor> &inputs, size_t dynamic_nums);
void DestroyInputsDataset();
void DestroyInputsDataMem();
void DestroyInputsBuffer();
void DestroyOutputsBuffer();
uint32_t model_id_;
// if run one device(AICPU), there is no need to alloc device memory and copy inputs to(/outputs from) device
bool is_run_on_device_;
aclmdlDesc *model_desc_;
aclmdlDataset *inputs_;
aclmdlDataset *outputs_;
std::vector<AclTensorInfo> input_infos_;
std::vector<AclTensorInfo> output_infos_;
};
} // namespace acl
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_RUNTIME_AGENT_ACL_MODEL_PROCESS_H_

View File

@ -6,7 +6,10 @@ set(CCSRC_SRC
${CCSRC_DIR}/backend/optimizer/common/visit.cc ${CCSRC_DIR}/backend/optimizer/common/visit.cc
${CCSRC_DIR}/backend/optimizer/common/optimizer.cc ${CCSRC_DIR}/backend/optimizer/common/optimizer.cc
) )
set(ENABLE_GLIBCXX ON) if(NOT MSLITE_ENABLE_ACL)
set(ENABLE_GLIBCXX ON)
endif()
include(${TOP_DIR}/cmake/external_libs/opencv.cmake) include(${TOP_DIR}/cmake/external_libs/opencv.cmake)
include(${TOP_DIR}/cmake/external_libs/glog.cmake) include(${TOP_DIR}/cmake/external_libs/glog.cmake)
include_directories(${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu) include_directories(${TOP_DIR}/mindspore/ccsrc/backend/kernel_compiler/cpu)
@ -136,6 +139,14 @@ add_subdirectory(registry)
add_subdirectory(preprocess) add_subdirectory(preprocess)
add_subdirectory(${CORE_DIR} mindspore_core) add_subdirectory(${CORE_DIR} mindspore_core)
if(MSLITE_ENABLE_ACL)
set(MODE_ASCEND_ACL ON)
include(${TOP_DIR}/cmake/dependency_graphengine.cmake)
add_subdirectory(acl)
link_directories(${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
endif()
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
set(API_SRC ${SRC_DIR}/cxx_api/context.cc) set(API_SRC ${SRC_DIR}/cxx_api/context.cc)
set(LITE_SRC set(LITE_SRC
@ -222,6 +233,12 @@ target_link_libraries(converter_lite PRIVATE
preprocess_mid preprocess_mid
) )
if(MSLITE_ENABLE_ACL)
target_link_libraries(converter_lite PRIVATE
lite_acl_mid
mindspore_shared_lib)
endif()
if(NOT MSVC) if(NOT MSVC)
target_link_libraries(converter_lite PRIVATE pthread) target_link_libraries(converter_lite PRIVATE pthread)
endif() endif()

View File

@ -0,0 +1,24 @@
include_directories(${TOP_DIR}/graphengine/metadef/inc/external)
include_directories(${TOP_DIR}/graphengine/inc)
include_directories(${TOP_DIR}/graphengine/inc/external)
include_directories(${TOP_DIR}/graphengine/ge)
include_directories(${TOP_DIR}/graphengine/metadef/inc)
include_directories(${TOP_DIR}/graphengine/inc/framework)
include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc)
include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
file(GLOB ACL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/deparser/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/infer/*.cc
)
add_subdirectory(${TOP_DIR}/mindspore/ccsrc/transform/graph_ir _mindspore_transform_graph_ir_obj)
add_subdirectory(${TOP_DIR}/mindspore/ccsrc/cxx_api mindspore_shared_lib)
set_property(SOURCE ${ACL_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(lite_acl_mid OBJECT ${ACL_SRC})
add_dependencies(lite_acl_mid fbs_inner_src)

View File

@ -0,0 +1,405 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/acl_pass.h"
#include <set>
#include "tools/converter/ops/ops_def.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "tools/converter/acl/deparser/spatial_node_adapter.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/optimizer_manager.h"
#include "include/registry/pass_registry.h"
#include "common/utils.h"
#include "ops/custom.h"
#include "base/core_ops.h"
#include "cxx_api/model/acl/model_converter.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
namespace mindspore {
namespace opt {
namespace {
constexpr auto kMakeTuple = "MakeTuple";
constexpr auto kOutputNames = "outputs_names";
constexpr auto kCustomPrimTypeACL = "ACL";
constexpr auto kCustomNodeName = "Custom";
} // namespace
ParameterPtr AclPass::CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om_data) {
ParameterPtr om_parameter = func_graph->add_parameter();
om_parameter->set_name("ACL_om_data");
auto type_ptr = TypeIdToType(kNumberTypeUInt8);
ShapeVector shape_vector = {static_cast<int64_t>(om_data.DataSize())};
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
om_parameter->set_abstract(abstract_tensor);
auto param_value =
std::make_shared<tensor::Tensor>(kNumberTypeUInt8, ShapeVector({static_cast<int64_t>(om_data.DataSize())}));
auto tensor_data = param_value->data_c();
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "New Tensor failed.";
return nullptr;
}
if (param_value->Size() < om_data.DataSize()) {
MS_LOG(ERROR) << "Dst buff size " << param_value->Size() << " should be greater than src buff size "
<< om_data.DataSize();
return nullptr;
}
if (memcpy_s(tensor_data, param_value->Size(), om_data.Data(), om_data.DataSize()) != EOK) {
MS_LOG(ERROR) << "Memcpy om data failed.";
return nullptr;
}
om_parameter->set_default_param(param_value);
return om_parameter;
}
// now build the whole graph, not split
STATUS AclPass::BuildGraph(const FuncGraphPtr &func_graph) {
Buffer om_data;
if (ConvertGraphToOm(func_graph, &om_data) != lite::RET_OK) {
MS_LOG(ERROR) << "Convert graph to om failed.";
return lite::RET_ERROR;
}
om_parameter_ = CreateOmParameter(func_graph, om_data);
if (om_parameter_ == nullptr) {
MS_LOG(ERROR) << "Convert graph to om failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS AclPass::RunPrimitiveDeparser(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "Deparser graph start.";
MS_ASSERT(func_graph != nullptr);
std::set<FuncGraphPtr> all_func_graphs = {};
lite::GetAllFuncGraph(func_graph, &all_func_graphs);
for (auto graph : all_func_graphs) {
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(ERROR) << "Prim is nullptr.";
return lite::RET_ERROR;
}
auto name = prim->name();
auto deparser = lite::PrimitiveDeparserRegister::GetInstance().GetPrimitiveDeparser(name);
if (deparser == nullptr) {
MS_LOG(DEBUG) << "Name: " << name << " not need to deparser.";
continue;
}
MS_LOG(INFO) << "Deparser cnode: " << name;
auto status = deparser->Deparser(cnode);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Deparser primitive failed.";
return lite::RET_ERROR;
}
}
}
return lite::RET_OK;
}
STATUS AclPass::DeparseGraph(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
if (fmk_type_ == converter::kFmkTypeMs) {
MS_LOG(INFO) << "MindIr no need to deparser graph";
return lite::RET_OK;
}
if (RunPrimitiveDeparser(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "Run deparser primitive failed.";
return lite::RET_ERROR;
}
if (lite::AdapteSpatialNode(func_graph, manager) != lite::RET_OK) {
MS_LOG(ERROR) << "Adapter spatial node failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS AclPass::PreProcGraph(const FuncGraphPtr &func_graph) {
if (fmk_type_ == converter::kFmkTypeMs) {
MS_LOG(INFO) << "MindIr no need to pre proc graph";
return lite::RET_OK;
}
// The format of nodes (cnode, parameter, val) must be nchw due to interface of convert om
if (!lite::RunOptimizerPass(func_graph, {"ToNCHWFormat", "DeleteRedundantTranspose"})) {
MS_LOG(ERROR) << "To nchw format success.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS AclPass::PostProcGraph(const FuncGraphPtr &func_graph) {
// The format must be nhwc due to ms model
if (!lite::RunOptimizerPass(func_graph, {"ToNHWCFormat"})) {
MS_LOG(ERROR) << "To NHWC Format failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
bool AclPass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "Acl pass run start.";
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Func_graph is nullptr.";
return false;
}
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "Manager is nullptr.";
return false;
}
if (PreProcGraph(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "Pre proc graph failed.";
return false;
}
if (DeparseGraph(func_graph, manager) != lite::RET_OK) {
MS_LOG(ERROR) << "Deparse graph failed.";
return false;
}
if (BuildGraph(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "Build graph failed.";
return false;
}
custom_node_ = CreateCustomNode(func_graph);
if (custom_node_ == nullptr) {
MS_LOG(ERROR) << "Create custom node failed.";
return false;
}
// prepare graph for export create
if (ModifyGraphByCustomNode(func_graph, manager, custom_node_) != lite::RET_OK) {
MS_LOG(ERROR) << "Modify func graph by custom failed.";
return false;
}
if (PostProcGraph(func_graph) != lite::RET_OK) {
MS_LOG(ERROR) << "Post proc graph failed.";
return false;
}
MS_LOG(INFO) << "Acl pass run end.";
return true;
}
STATUS AclPass::ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_data) {
if (om_data == nullptr) {
MS_LOG(ERROR) << "Om data is nullptr.";
return lite::RET_ERROR;
}
SetAclModelOptions(func_graph);
// call interface of cloud
ModelConverter model_converter;
model_converter.set_options(options_.get());
*om_data = model_converter.LoadMindIR(func_graph);
if (om_data->Data() == nullptr || om_data->DataSize() == 0) {
MS_LOG(ERROR) << "Model converter load mindir failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
void AclPass::SetAclModelOptions(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "Set acl model options start.";
auto model_context = std::make_shared<mindspore::Context>();
auto ascend310_info = std::make_shared<Ascend310DeviceInfo>();
ascend310_info->SetDeviceID(0);
model_context->MutableDeviceInfo().emplace_back(ascend310_info);
// set options
options_ = std::make_unique<AclModelOptions>(model_context);
if (options_ == nullptr) {
MS_LOG(ERROR) << "Acl option make shared failed.";
return;
}
auto inputs = func_graph->get_inputs();
std::vector<std::string> input_names;
for (auto node : inputs) {
if (node == nullptr) {
MS_LOG(ERROR) << "Node is nullptr.";
return;
}
auto para = node->cast<ParameterPtr>();
if (para == nullptr) {
MS_LOG(ERROR) << "Parameter is nullptr.";
return;
}
std::string name = para->name();
for (auto pos = name.find(':'); pos != std::string::npos; pos = name.find(':')) {
name = name.substr(0, pos) + "_" + name.substr(pos + 1);
MS_LOG(INFO) << name;
}
para->set_name(name);
input_names.push_back(name);
}
options_->RenameInput(input_names);
MS_LOG(INFO) << "Set acl model options end.";
}
STATUS AclPass::GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph, AnfNodePtrList *graph_outputs,
std::vector<std::string> *graph_output_names,
std::vector<std::vector<int64_t>> *graph_output_dims) {
CHECK_NULL_RETURN(graph_outputs);
CHECK_NULL_RETURN(graph_output_names);
CHECK_NULL_RETURN(graph_output_dims);
AnfNodePtr return_input = func_graph->output();
CHECK_NULL_RETURN(return_input);
auto input_cnode = return_input->cast<CNodePtr>();
CHECK_NULL_RETURN(input_cnode);
auto primitive = mindspore::GetValueNode<PrimitivePtr>(input_cnode->input(0));
if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr, node: " << input_cnode->fullname_with_scope();
return lite::RET_ERROR;
}
// not consider custom op
std::string primitive_type = primitive->name();
if (primitive_type == kMakeTuple) {
for (size_t j = 1; j < input_cnode->inputs().size(); j++) {
auto item = input_cnode->input(j);
MS_ASSERT(item != nullptr);
graph_outputs->emplace_back(item);
graph_output_names->emplace_back(item->fullname_with_scope());
auto item_cnode = item->cast<CNodePtr>();
if (item_cnode == nullptr) {
MS_LOG(ERROR) << "Input of MakeTuple is not a cnode for input_id: " << j;
return lite::RET_ERROR;
}
std::vector<int64_t> dims;
if (lite::acl::GetShapeVectorFromCNode(item_cnode, &dims) != lite::RET_OK) {
MS_LOG(ERROR) << "Get node shape failed.";
return lite::RET_ERROR;
}
graph_output_dims->emplace_back(dims);
}
} else {
graph_outputs->emplace_back(input_cnode);
graph_output_names->emplace_back(input_cnode->fullname_with_scope());
std::vector<int64_t> dims;
if (lite::acl::GetShapeVectorFromCNode(input_cnode, &dims) != lite::RET_OK) {
MS_LOG(ERROR) << "Get node shape failed.";
return lite::RET_ERROR;
}
graph_output_dims->emplace_back(dims);
}
return lite::RET_OK;
}
STATUS AclPass::SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type) {
AbstractBasePtrList abstract_list;
for (size_t j = 0; j < graph_outputs_.size(); j++) {
auto abstract_tensor = lite::CreateTensorAbstract(graph_outputs_dims_[j], data_type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Abstract tensor is nullptr for output " << j;
return lite::RET_ERROR;
}
abstract_list.emplace_back(abstract_tensor);
}
new_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return lite::RET_OK;
}
STATUS AclPass::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
STATUS ret = GetFuncGraphOutputInfo(func_graph, &graph_outputs_, &graph_output_names_, &graph_outputs_dims_);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Get output info of graph failed.";
return lite::RET_ERROR;
}
if (graph_outputs_.empty() || graph_outputs_.size() != graph_outputs_dims_.size()) {
MS_LOG(ERROR) << "Graph output size is error, num size: " << graph_outputs_.size()
<< " dim size: " << graph_outputs_dims_.size();
return lite::RET_ERROR;
}
custom_node->AddAttr(kOutputNames, MakeValue(graph_output_names_));
TypeId type = lite::acl::GetTypeFromNode(graph_outputs_[0]);
if (graph_outputs_.size() == 1) {
auto abstract_tensor = lite::CreateTensorAbstract(graph_outputs_dims_[0], type);
if (abstract_tensor == nullptr) {
MS_LOG(ERROR) << "Abstract_tensor is nullptr.";
return lite::RET_ERROR;
}
custom_node->set_abstract(abstract_tensor);
return lite::RET_OK;
}
if (SetMultiOutputs(custom_node, type) != lite::RET_OK) {
MS_LOG(ERROR) << "Set multi graph output failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
CNodePtr AclPass::CreateCustomNode(const FuncGraphPtr &func_graph) {
auto prim = std::make_unique<mindspore::ops::Custom>();
if (prim == nullptr) {
MS_LOG(ERROR) << "New custom op failed.";
return nullptr;
}
prim->set_type(kCustomPrimTypeACL);
auto graph_input = func_graph->get_inputs();
CNodePtr custom_node = func_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(prim.release()), graph_input);
if (custom_node == nullptr) {
MS_LOG(ERROR) << "Custom cnode failed.";
return nullptr;
}
custom_node->set_fullname_with_scope(kCustomNodeName);
custom_node->add_input(om_parameter_);
if (SetCustomOutputs(func_graph, custom_node) != lite::RET_OK) {
MS_LOG(ERROR) << "Set custom outputs failed.";
return nullptr;
}
return custom_node;
}
STATUS AclPass::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
const CNodePtr &custom_node) {
if (graph_outputs_.size() == 1) {
if (!manager->Replace(graph_outputs_[0], custom_node)) {
MS_LOG(ERROR) << "Replace node failed.";
return lite::RET_ERROR;
}
} else {
for (size_t j = 0; j < graph_outputs_.size(); ++j) {
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "New TupleGetItem failed for output " << j;
return lite::RET_ERROR;
}
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int>(j));
AnfNodePtrList inputs{tuple_get_item_prim, custom_node, get_item_value};
CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
if (get_item_cnode == nullptr) {
MS_LOG(ERROR) << "New get item cnode failed for output " << j;
return lite::RET_ERROR;
}
get_item_cnode->set_fullname_with_scope(custom_node->fullname_with_scope() + "_getitem_" + std::to_string(j));
if (!manager->Replace(graph_outputs_[j], get_item_cnode)) {
MS_LOG(ERROR) << "Replace node failed for output " << j;
return lite::RET_ERROR;
}
}
}
return lite::RET_OK;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ACL_ACL_PASS_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_ACL_ACL_PASS_H
#include <string>
#include <vector>
#include <memory>
#include "backend/optimizer/common/pass.h"
#include "include/errorcode.h"
#include "include/api/types.h"
#include "include/registry/parser_context.h"
#include "cxx_api/model/acl/acl_model_options.h"
using mindspore::converter::FmkType;
using mindspore::lite::STATUS;
namespace mindspore {
namespace opt {
class AclPass : public Pass {
public:
explicit AclPass(FmkType fmk_type) : Pass("Acl"), fmk_type_(fmk_type) {}
~AclPass() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
STATUS PreProcGraph(const FuncGraphPtr &func_graph);
STATUS PostProcGraph(const FuncGraphPtr &func_graph);
STATUS DeparseGraph(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager);
STATUS RunPrimitiveDeparser(const FuncGraphPtr &func_graph);
STATUS BuildGraph(const FuncGraphPtr &func_graph);
STATUS ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_data);
ParameterPtr CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om);
CNodePtr CreateCustomNode(const FuncGraphPtr &func_graph);
STATUS SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node);
STATUS SetMultiOutputs(const CNodePtr &new_cnode, TypeId data_type);
STATUS ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
const CNodePtr &custom_node);
void SetAclModelOptions(const FuncGraphPtr &func_graph);
STATUS GetFuncGraphOutputInfo(const FuncGraphPtr &func_graph, AnfNodePtrList *graph_outputs,
std::vector<std::string> *graph_output_names,
std::vector<std::vector<int64_t>> *graph_output_dims);
FmkType fmk_type_;
ParameterPtr om_parameter_ = nullptr;
CNodePtr custom_node_ = nullptr;
std::unique_ptr<AclModelOptions> options_;
AnfNodePtrList graph_outputs_;
std::vector<std::string> graph_output_names_;
std::vector<std::vector<int64_t>> graph_outputs_dims_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ACL_ACL_PASS_H

View File

@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/common/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "base/base_ref.h"
#include "base/core_ops.h"
#include "abstract/dshape.h"
#include "abstract/abstract_value.h"
#include "utils/utils.h"
namespace mindspore {
namespace lite {
namespace acl {
namespace {
constexpr size_t kTupleGetItemInputSize = 3;
constexpr size_t kSecondIndex = 1;
constexpr size_t kInvalidSize = SIZE_MAX;
} // namespace
static size_t GetTupleGetItemOutIndex(const mindspore::CNodePtr &tuple_get_item) {
MS_ASSERT(tuple_get_item != nullptr);
if (tuple_get_item->size() != mindspore::kTupleGetItemInputSize) {
MS_LOG(ERROR) << "The node tuple_get_item must have 2 inputs!";
return kInvalidSize;
}
auto output_index_value_node = tuple_get_item->input(mindspore::kInputNodeOutputIndexInTupleGetItem);
MS_ASSERT(output_index_value_node != nullptr);
auto value_node = output_index_value_node->cast<mindspore::ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
return IntToSize(opt::CastToInt(value_node->value()).front());
}
static bool CheckPrimitiveType(const mindspore::AnfNodePtr &node, const mindspore::PrimitivePtr &primitive_type) {
if (node == nullptr) {
return false;
}
if (node->isa<mindspore::CNode>()) {
auto cnode = node->cast<mindspore::CNodePtr>();
return IsPrimitive(cnode->input(0), primitive_type);
} else if (node->isa<mindspore::ValueNode>()) {
return IsPrimitive(node, primitive_type);
}
return false;
}
STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int64_t> *shape_vector) {
mindspore::AbstractBasePtr cnode_abstract;
if (CheckPrimitiveType(cnode, mindspore::prim::kPrimTupleGetItem)) {
auto tuple_inputs = cnode->inputs();
MS_ASSERT(tuple_inputs.size() == kTupleGetItemInputSize);
auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
MS_ASSERT(get_item_input_cnode != nullptr);
auto idx = GetTupleGetItemOutIndex(cnode);
if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple";
return lite::RET_ERROR;
}
auto abstract_tuple =
mindspore::utils::cast<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
auto abstract_list = abstract_tuple->elements();
if (abstract_list.size() <= idx) {
MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect";
return lite::RET_ERROR;
}
cnode_abstract = abstract_list[idx];
} else {
cnode_abstract = cnode->abstract();
}
if (cnode_abstract == nullptr) {
MS_LOG(ERROR) << "Abstract cnode is nullptr. " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
if (!mindspore::utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
if (!mindspore::utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
return lite::RET_ERROR;
}
auto shape_ptr = mindspore::utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
if (shape_ptr->shape().empty()) {
MS_LOG(WARNING) << "Shape is empty " << cnode->fullname_with_scope();
}
*shape_vector = shape_ptr->shape();
return lite::RET_OK;
}
TypeId GetTypeFromNode(const AnfNodePtr &node) {
TypeId type = kNumberTypeFloat32;
if (utils::isa<CNodePtr>(node)) {
auto cnode = node->cast<CNodePtr>();
if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
MS_LOG(WARNING) << "Abstract_tensor or abstract_tensor->element() is nullptr.";
return type;
}
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
type = type_ptr->type_id();
}
MS_LOG(INFO) << "node type id is " << type;
}
return type;
}
} // namespace acl
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ACL_COMMON_UTILS_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_ACL_COMMON_UTILS_H
#include <vector>
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace lite {
namespace acl {
STATUS GetShapeVectorFromCNode(const mindspore::CNodePtr &cnode, std::vector<int64_t> *shape_vector);
TypeId GetTypeFromNode(const AnfNodePtr &node);
} // namespace acl
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ACL_ACL_PASS_H

View File

@ -0,0 +1,67 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/activation_deparser.h"
#include <map>
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "ops/elu.h"
#include "ops/gelu.h"
#include "ops/leaky_relu.h"
#include "ops/relu.h"
#include "ops/relu6.h"
#include "ops/sigmoid.h"
#include "ops/tanh.h"
namespace mindspore {
namespace lite {
STATUS ActivationDeparser::Deparser(const CNodePtr &cnode) {
static std::map<ActivationType, PrimitivePtr> activation_type_map = {
{mindspore::ELU, std::make_shared<ops::Elu>()},
{mindspore::GELU, std::make_shared<ops::GeLU>()},
{mindspore::RELU, std::make_shared<ops::ReLU>()},
{mindspore::RELU6, std::make_shared<ops::ReLU6>()},
{mindspore::SIGMOID, std::make_shared<ops::Sigmoid>()},
{mindspore::TANH, std::make_shared<ops::Tanh>()},
{mindspore::LEAKY_RELU, std::make_shared<ops::LeakyRelu>()}};
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
auto activate_prim = dynamic_cast<ops::Activation *>(src_prim.get());
if (activate_prim == nullptr) {
MS_LOG(ERROR) << "Dynamic cast activation failed.";
return lite::RET_ERROR;
}
PrimitivePtr dst_prim = nullptr;
ActivationType type = activate_prim->get_activation_type();
if (activation_type_map.find(type) != activation_type_map.end()) {
dst_prim = activation_type_map[type];
} else {
MS_LOG(ERROR) << "Type " << static_cast<int>(type) << " is unsupported.";
return lite::RET_ERROR;
}
MS_ASSERT(dst_prim != nullptr);
dst_prim->SetAttrs(src_prim->attrs());
value_node->set_value(dst_prim);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameActivation, ActivationDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_ACTIVATION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_ACTIVATION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/activation.h"
using mindspore::ops::kNameActivation;
namespace mindspore {
namespace lite {
class ActivationDeparser : public PrimitiveDeparser {
public:
ActivationDeparser() : PrimitiveDeparser(kNameActivation) {}
~ActivationDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_ACTIVATION_DEPARSER_H

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/add_fusion_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
namespace mindspore {
namespace lite {
STATUS AddFusionDeparser::Deparser(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Add>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "AddFusion deparser failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameAddFusion, AddFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_ADDFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_ADDFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/add_fusion.h"
using mindspore::ops::kNameAddFusion;
namespace mindspore {
namespace lite {
class AddFusionDeparser : public PrimitiveDeparser {
public:
AddFusionDeparser() : PrimitiveDeparser(kNameAddFusion) {}
~AddFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_ADDFUSION_DEPARSER_H

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/avgpool_fusion_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "tools/converter/acl/deparser/tbe_op_def.h"
#include "include/registry/parser_context.h"
namespace mindspore {
namespace lite {
STATUS AvgPoolFusionDeparser::Deparser(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
return lite::RET_ERROR;
}
auto attr_val = src_prim->GetAttr(ops::kFmkType);
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
PrimitivePtr dst_prim = nullptr;
if (fmk_type == converter::kFmkTypeCaffe) {
dst_prim = std::make_shared<acl::Pooling>();
} else if (fmk_type == converter::kFmkTypeOnnx) {
ValuePtr val_ptr = src_prim->GetAttr(ops::kKernelSize);
if (val_ptr == nullptr) {
dst_prim = std::make_shared<acl::GlobalAveragePool>();
} else {
dst_prim = std::make_shared<acl::AvgPoolV2>();
}
} else {
dst_prim = std::make_shared<ops::AvgPool>();
}
if (dst_prim == nullptr) {
MS_LOG(ERROR) << "Get primitive by fmk type failed.";
return lite::RET_ERROR;
}
dst_prim->SetAttrs(src_prim->attrs());
if (AdjustPoolAttr(fmk_type, kNameAvgPoolFusion, dst_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Adjust pool attr failed.";
return lite::RET_ERROR;
}
value_node->set_value(dst_prim);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameAvgPoolFusion, AvgPoolFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_AVGPOOLFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_AVGPOOLFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/avg_pool_fusion.h"
using mindspore::ops::kNameAvgPoolFusion;
namespace mindspore {
namespace lite {
class AvgPoolFusionDeparser : public PrimitiveDeparser {
public:
AvgPoolFusionDeparser() : PrimitiveDeparser(kNameAvgPoolFusion) {}
~AvgPoolFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_AVGPOOLFUSION_DEPARSER_H

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/batchnorm_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "tools/converter/acl/deparser/tbe_op_def.h"
#include "include/registry/parser_context.h"
namespace mindspore {
namespace lite {
STATUS BatchNormDeparser::Deparser(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
return lite::RET_ERROR;
}
auto attr_val = src_prim->GetAttr(ops::kFmkType);
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
if (fmk_type == converter::kFmkTypeCaffe) {
auto dst_prim = std::make_shared<acl::BNInference>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "BatchNorm deparser failed.";
return RET_ERROR;
}
}
return RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameBatchNorm, BatchNormDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_BATCHNORM_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_BATCHNORM_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/batch_norm.h"
using mindspore::ops::kNameBatchNorm;
namespace mindspore {
namespace lite {
class BatchNormDeparser : public PrimitiveDeparser {
public:
BatchNormDeparser() : PrimitiveDeparser(kNameBatchNorm) {}
~BatchNormDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_BATCHNORM_DEPARSER_H

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/cast_deparser.h"
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kNameCastInputNum = 3;
} // namespace
STATUS CastDeparser::Deparser(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "Cnode is nullptr.";
return lite::RET_ERROR;
}
if (cnode->size() != kNameCastInputNum) {
MS_LOG(ERROR) << "Input size of gather must be three.";
return lite::RET_ERROR;
}
// convert last parameter to const value node
auto to_input = cnode->input(kNameCastInputNum - 1);
if (!utils::isa<ParameterPtr>(to_input)) {
MS_LOG(ERROR) << "The to node is not parameter.";
return lite::RET_ERROR;
}
ParameterPtr to_param = to_input->cast<ParameterPtr>();
auto data = opt::GetIntParameterData(to_param);
int dst_type = data.empty() ? kNumberTypeInt32 : data.front();
TypePtr type_ptr = TypeIdToType(TypeId(dst_type));
if (type_ptr == nullptr) {
MS_LOG(ERROR) << "New type ptr failed.";
return lite::RET_ERROR;
}
ValueNodePtr value_node = NewValueNode(type_ptr);
if (value_node == nullptr) {
MS_LOG(ERROR) << "New value node failed.";
return lite::RET_ERROR;
}
cnode->set_input(kNameCastInputNum - 1, value_node);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameCast, CastDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_CAST_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_CAST_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/cast.h"
using mindspore::ops::kNameCast;
namespace mindspore {
namespace lite {
class CastDeparser : public PrimitiveDeparser {
public:
CastDeparser() : PrimitiveDeparser(kNameCast) {}
~CastDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_CAST_DEPARSER_H

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/concat_deparser.h"
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
namespace mindspore {
namespace lite {
namespace {
constexpr auto kNameInputNums = "inputNums";
}
STATUS ConcatDeparser::Deparser(const CNodePtr &cnode) {
if (AddAttrForDynInputPrimitive(cnode) != RET_OK) {
MS_LOG(ERROR) << "Concat deparser failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS ConcatDeparser::AddAttrForDynInputPrimitive(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim == nullptr) {
MS_LOG(ERROR) << "Value node is invalid.";
return lite::RET_ERROR;
}
// add attr input num for dynamic input op
int64_t num = static_cast<int64_t>(cnode->size());
if (num > 1) {
prim->AddAttr(kNameInputNums, MakeValue(num - 1));
}
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameConcat, ConcatDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_CONCAT_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_CONCAT_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/concat.h"
using mindspore::ops::kNameConcat;
namespace mindspore {
namespace lite {
class ConcatDeparser : public PrimitiveDeparser {
public:
ConcatDeparser() : PrimitiveDeparser(kNameConcat) {}
~ConcatDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
private:
STATUS AddAttrForDynInputPrimitive(const CNodePtr &cnode);
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_CONCAT_DEPARSER_H

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/conv2d_fusion_deparser.h"
#include "memory"
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
namespace mindspore {
namespace lite {
STATUS Conv2DFusionDeparser::Deparser(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
auto dst_prim = std::make_shared<ops::Conv2D>();
MS_ASSERT(dst_prim != nullptr);
dst_prim->SetAttrs(src_prim->attrs());
auto status = AttrAdjust(dst_prim, ops::kStride);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "adjust stride failed.";
return status;
}
status = AttrAdjust(dst_prim, ops::kDilation);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "adjust dilation failed.";
return status;
}
value_node->set_value(dst_prim);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameConv2DFusion, Conv2DFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_CONV2DFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_CONV2DFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/conv2d_fusion.h"
using mindspore::ops::kNameConv2DFusion;
namespace mindspore {
namespace lite {
class Conv2DFusionDeparser : public PrimitiveDeparser {
public:
Conv2DFusionDeparser() : PrimitiveDeparser(kNameConv2DFusion) {}
~Conv2DFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_CONV2DFUSION_DEPARSER_H

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/conv2d_transpose_fusion_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "include/registry/parser_context.h"
#include "tools/converter/acl/deparser/tbe_op_def.h"
namespace mindspore {
namespace lite {
STATUS Conv2dTransposeFusionDeparser::Deparser(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
auto attr_val = src_prim->GetAttr(ops::kFmkType);
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
PrimitivePtr dst_prim = nullptr;
if (fmk_type == converter::kFmkTypeCaffe) {
dst_prim = std::make_shared<acl::Deconvolution>();
} else {
dst_prim = std::make_shared<ops::Conv2DTranspose>();
}
MS_ASSERT(dst_prim != nullptr);
dst_prim->SetAttrs(src_prim->attrs());
auto status = AttrAdjust(dst_prim, ops::kDilation);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "adjust failed.";
return status;
}
value_node->set_value(dst_prim);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameConv2dTransposeFusion, Conv2dTransposeFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_CONV2DTRANSPOSEFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_CONV2DTRANSPOSEFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
using mindspore::ops::kNameConv2dTransposeFusion;
namespace mindspore {
namespace lite {
class Conv2dTransposeFusionDeparser : public PrimitiveDeparser {
public:
Conv2dTransposeFusionDeparser() : PrimitiveDeparser(kNameConv2dTransposeFusion) {}
~Conv2dTransposeFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_CONV2DTRANSPOSEFUSION_DEPARSER_H

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/eltwise_deparser.h"
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace lite {
STATUS EltWiseDeparser::Deparser(const CNodePtr &cnode) {
if (AddAttrForDynInputPrimitive(cnode) != RET_OK) {
MS_LOG(ERROR) << "EltWise deparser failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS EltWiseDeparser::AddAttrForDynInputPrimitive(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim == nullptr) {
MS_LOG(ERROR) << "Value node is invalid.";
return lite::RET_ERROR;
}
// add attr input num for dynamic input op
int64_t num = static_cast<int64_t>(cnode->size());
if (num > 1) {
prim->AddAttr(ops::kN, MakeValue(num - 1));
}
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameEltwise, EltWiseDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_ELTWISE_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_ELTWISE_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/eltwise.h"
using mindspore::ops::kNameEltwise;
namespace mindspore {
namespace lite {
class EltWiseDeparser : public PrimitiveDeparser {
public:
EltWiseDeparser() : PrimitiveDeparser(kNameEltwise) {}
~EltWiseDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
private:
STATUS AddAttrForDynInputPrimitive(const CNodePtr &cnode);
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_ELTWISE_DEPARSER_H

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/fused_batchnorm_deparser.h"
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace lite {
STATUS FusedBatchNormDeparser::Deparser(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
src_prim->AddAttr(ops::kIsTraining, MakeValue(false));
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameFusedBatchNorm, FusedBatchNormDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_FUSEDBATCHNORM_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_FUSEDBATCHNORM_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fused_batch_norm.h"
using mindspore::ops::kNameFusedBatchNorm;
namespace mindspore {
namespace lite {
class FusedBatchNormDeparser : public PrimitiveDeparser {
public:
FusedBatchNormDeparser() : PrimitiveDeparser(kNameFusedBatchNorm) {}
~FusedBatchNormDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_FUSEDBATCHNORM_DEPARSER_H

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/gather_fusion_deparser.h"
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kNameGatherInputNum = 4;
}
STATUS GatherDeparser::Deparser(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "Cnode is nullptr.";
return lite::RET_ERROR;
}
if (cnode->size() != kNameGatherInputNum) {
MS_LOG(ERROR) << "Input size of gather must be four.";
return lite::RET_ERROR;
}
// convert last parameter to const value node
auto axis_input = cnode->input(kNameGatherInputNum - 1);
if (!utils::isa<ParameterPtr>(axis_input)) {
MS_LOG(ERROR) << "The axis node is not parameter.";
return lite::RET_ERROR;
}
ParameterPtr axis_param = axis_input->cast<ParameterPtr>();
auto data = opt::GetIntParameterData(axis_param);
int64_t axis = data.empty() ? 0 : static_cast<int64_t>(data.front());
ValueNodePtr value_node = NewValueNode<int64_t>(axis);
if (value_node == nullptr) {
MS_LOG(ERROR) << "New value node failed.";
return lite::RET_ERROR;
}
cnode->set_input(kNameGatherInputNum - 1, value_node);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameGather, GatherDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_GATHER_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_GATHER_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/gather.h"
using mindspore::ops::kNameGather;
namespace mindspore {
namespace lite {
class GatherDeparser : public PrimitiveDeparser {
public:
GatherDeparser() : PrimitiveDeparser(kNameGather) {}
~GatherDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_GATHER_DEPARSER_H

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/maxpool_fusion_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "tools/converter/acl/deparser/tbe_op_def.h"
#include "include/registry/parser_context.h"
namespace mindspore {
namespace lite {
STATUS MaxPoolFusionDeparser::Deparser(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
return lite::RET_ERROR;
}
auto attr_val = src_prim->GetAttr(ops::kFmkType);
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
PrimitivePtr dst_prim = nullptr;
if (fmk_type == converter::kFmkTypeCaffe) {
dst_prim = std::make_shared<acl::Pooling>();
} else if (fmk_type == converter::kFmkTypeOnnx) {
dst_prim = std::make_shared<acl::MaxPoolV3>();
} else {
dst_prim = std::make_shared<ops::MaxPool>();
}
if (dst_prim == nullptr) {
MS_LOG(ERROR) << "Get primitive by fmk type failed.";
return lite::RET_ERROR;
}
dst_prim->SetAttrs(src_prim->attrs());
if (AdjustPoolAttr(fmk_type, kNameMaxPoolFusion, dst_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Adjust pool attr failed.";
return lite::RET_ERROR;
}
value_node->set_value(dst_prim);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameMaxPoolFusion, MaxPoolFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_MAXPOOLFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_MAXPOOLFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/max_pool_fusion.h"
using mindspore::ops::kNameMaxPoolFusion;
namespace mindspore {
namespace lite {
class MaxPoolFusionDeparser : public PrimitiveDeparser {
public:
MaxPoolFusionDeparser() : PrimitiveDeparser(kNameMaxPoolFusion) {}
~MaxPoolFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_MAXPOOLFUSION_DEPARSER_H

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/mul_fusion_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
namespace mindspore {
namespace lite {
STATUS MulFusionDeparser::Deparser(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Mul>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "MulFusion deparser failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameMulFusion, MulFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_MULFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_MULFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/mul_fusion.h"
using mindspore::ops::kNameMulFusion;
namespace mindspore {
namespace lite {
class MulFusionDeparser : public PrimitiveDeparser {
public:
MulFusionDeparser() : PrimitiveDeparser(kNameMulFusion) {}
~MulFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_MULFUSION_DEPARSER_H

View File

@ -0,0 +1,83 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/pad_fusion_deparser.h"
#include <memory>
#include <map>
#include <string>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "tools/converter/acl/deparser/tbe_op_def.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kNamePadInputNum = 3;
constexpr auto kNamePadContiguous = "pad_contiguous";
} // namespace
STATUS PadFusionDeparser::Deparser(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
auto dst_prim = std::make_shared<acl::PadV3>();
MS_ASSERT(dst_prim != nullptr);
dst_prim->SetAttrs(src_prim->attrs());
AdjustPadAttr(dst_prim);
if (cnode->size() != kNamePadInputNum) {
MS_LOG(INFO) << "No need to add attr to input, input num: " << cnode->size();
value_node->set_value(dst_prim);
return lite::RET_OK;
}
auto func_graph = cnode->func_graph();
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Func graph is nullptr.";
return lite::RET_ERROR;
}
int status = AddAttrToInput(func_graph, cnode, dst_prim, ops::kConstantValue, 2);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Add constant value to input failed.";
return lite::RET_ERROR;
}
value_node->set_value(dst_prim);
return lite::RET_OK;
}
void PadFusionDeparser::AdjustPadAttr(const PrimitivePtr &dst_prim) {
static std::map<int64_t, std::string> kPadModeToStrMap = {
{PaddingMode::CONSTANT, "constant"},
{PaddingMode::REFLECT, "reflect"},
{PaddingMode::SYMMETRIC, "edge"},
};
auto pad_mode_value = dst_prim->GetAttr(ops::kPaddingMode);
if (pad_mode_value != nullptr) {
auto pad_mode = GetValue<int64_t>(pad_mode_value);
if (kPadModeToStrMap.find(pad_mode) != kPadModeToStrMap.end()) {
dst_prim->AddAttr(ops::kMode, MakeValue(kPadModeToStrMap[pad_mode]));
dst_prim->DelAttr(ops::kPaddingMode);
}
}
dst_prim->AddAttr(kNamePadContiguous, MakeValue(true));
}
REGISTER_PRIMITIVE_DEPARSER(kNamePadFusion, PadFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_PADFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_PADFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/pad_fusion.h"
using mindspore::ops::kNamePadFusion;
namespace mindspore {
namespace lite {
class PadFusionDeparser : public PrimitiveDeparser {
public:
PadFusionDeparser() : PrimitiveDeparser(kNamePadFusion) {}
~PadFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
private:
void AdjustPadAttr(const PrimitivePtr &dst_prim);
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_PADFUSION_DEPARSER_H

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/prelu_fusion_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
namespace mindspore {
namespace lite {
STATUS PReluFusionDeparser::Deparser(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::PReLU>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "PReluFusion deparser failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNamePReLUFusion, PReluFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_PRELUFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_PRELUFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/prelu_fusion.h"
using mindspore::ops::kNamePReLUFusion;
namespace mindspore {
namespace lite {
class PReluFusionDeparser : public PrimitiveDeparser {
public:
PReluFusionDeparser() : PrimitiveDeparser(kNamePReLUFusion) {}
~PReluFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_PRELUFUSION_DEPARSER_H

View File

@ -0,0 +1,198 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include <map>
#include <vector>
#include "tools/converter/acl/common/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ir/graph_utils.h"
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#include "include/registry/parser_context.h"
#include "ops/op_utils.h"
#include "ops/fusion/avg_pool_fusion.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
namespace mindspore {
namespace lite {
namespace {
constexpr auto kCommonAttrValueNum = 2;
constexpr auto kNamePaddingMode = "padding_mode";
constexpr auto kNameCeilMode = "ceil_mode";
} // namespace
STATUS PrimitiveDeparser::Deparser(const CNodePtr &cnode) { return lite::RET_OK; }
STATUS PrimitiveDeparser::GetValueNodeAndPrimFromCnode(const CNodePtr &cnode, ValueNodePtr *value_node,
PrimitivePtr *prim_ptr) {
CHECK_NULL_RETURN(cnode);
CHECK_NULL_RETURN(value_node);
CHECK_NULL_RETURN(prim_ptr);
*value_node = cnode->input(0)->cast<ValueNodePtr>();
if (*value_node == nullptr) {
MS_LOG(ERROR) << "Value node[" << cnode->fullname_with_scope() << "] is nullptr.";
return lite::RET_ERROR;
}
*prim_ptr = GetValueNode<PrimitivePtr>(*value_node);
if (*prim_ptr == nullptr) {
MS_LOG(ERROR) << "Value node[" << cnode->fullname_with_scope() << "] cast to primitive failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
STATUS PrimitiveDeparser::AttrAdjust(const PrimitivePtr &prim, const std::string &name) {
auto value_ptr = prim->GetAttr(name);
if (value_ptr == nullptr) {
MS_LOG(WARNING) << prim->name() << " has no attr " << name;
return lite::RET_OK;
}
if (utils::isa<ValueSequeuePtr>(value_ptr)) {
auto val_seq_ptr = value_ptr->cast<ValueSequeuePtr>();
CHECK_NULL_RETURN(val_seq_ptr);
ValuePtr first_val = nullptr;
if (!val_seq_ptr->value().empty()) {
first_val = val_seq_ptr->value().front();
}
CHECK_NULL_RETURN(first_val);
CHECK_NULL_RETURN(first_val->type());
if (first_val->type()->number_type() != kNumberTypeInt64) {
MS_LOG(ERROR) << "Value number type of name: " << prim->name() << " ,please check the attr name: " << name;
return lite::RET_ERROR;
}
} else {
CHECK_NULL_RETURN(value_ptr->type());
if (value_ptr->type()->number_type() != kNumberTypeInt64) {
MS_LOG(ERROR) << "Value number type of name: " << prim->name() << " ,please check the attr name: " << name;
return lite::RET_ERROR;
}
}
auto origin_value = opt::CastToInt(value_ptr);
if (origin_value.size() != kCommonAttrValueNum) {
MS_LOG(ERROR) << name << " Value num must be two.";
return lite::RET_ERROR;
}
std::vector<int64_t> new_value;
new_value.push_back(1);
new_value.push_back(1);
new_value.push_back(static_cast<int64_t>(origin_value[0]));
new_value.push_back(static_cast<int64_t>(origin_value[1]));
prim->AddAttr(name, MakeValue(new_value));
return lite::RET_OK;
}
void PrimitiveDeparser::AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim) {
int64_t mode = src_prim_name == ops::kNameAvgPoolFusion ? 1 : 0;
dst_prim->AddAttr(ops::kMode, MakeValue(mode));
auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
auto run_mode = GetValue<int64_t>(run_mode_val);
int64_t run_mode_ge = run_mode == RoundMode::FLOOR ? 1 : 0;
dst_prim->set_attr(ops::kRoundMode, MakeValue(run_mode_ge));
}
void PrimitiveDeparser::AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim) {
static std::map<int64_t, std::string> kPadModToStrMap = {
{PadMode::PAD, "CALCULATED"},
{PadMode::SAME, "SAME"},
{PadMode::VALID, "VALID"},
};
auto pad_mode_val = dst_prim->GetAttr(ops::kPadMode);
auto pad_mode = GetValue<int64_t>(pad_mode_val);
std::string padding_mode = "CALCULATED";
if (kPadModToStrMap.find(pad_mode) != kPadModToStrMap.end()) {
padding_mode = kPadModToStrMap[pad_mode];
}
dst_prim->AddAttr(kNamePaddingMode, MakeValue(padding_mode));
auto run_mode_val = dst_prim->GetAttr(ops::kRoundMode);
int64_t run_mode = GetValue<int64_t>(run_mode_val);
bool ceil_mode = run_mode == RoundMode::CEIL;
dst_prim->AddAttr(kNameCeilMode, MakeValue(ceil_mode));
}
STATUS PrimitiveDeparser::AdjustPoolAttr(int fmk_type, const std::string &src_prim_name, const PrimitivePtr &dst_prim) {
if (fmk_type == converter::kFmkTypeCaffe) {
AdjustCaffePoolAttr(src_prim_name, dst_prim);
return lite::RET_OK;
} else if (fmk_type == converter::kFmkTypeOnnx) {
AdjustOnnxPoolAttr(dst_prim);
}
// adjust common attr
auto status = AttrAdjust(dst_prim, ops::kKernelSize);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Adjust kernel size failed.";
return status;
}
status = AttrAdjust(dst_prim, ops::kStrides);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "adjust strides failed.";
return status;
}
return lite::RET_OK;
}
STATUS PrimitiveDeparser::MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &dst_prim) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
if (dst_prim == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr.";
return lite::RET_ERROR;
}
dst_prim->SetAttrs(src_prim->attrs());
value_node->set_value(dst_prim);
return lite::RET_OK;
}
STATUS PrimitiveDeparser::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const PrimitivePtr &dst_prim, const std::string &attr_name, int flag) {
auto attr_val = dst_prim->GetAttr(attr_name);
if (attr_val == nullptr) {
MS_LOG(INFO) << "There is no attr: " << attr_name;
return lite::RET_OK;
}
auto inputs = cnode->inputs();
switch (flag) {
case (1): {
auto value_data = opt::CastToVec2DInt(attr_val);
auto param_node =
opt::BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case (2): {
auto value_data = GetValue<float>(attr_val);
auto param_node =
opt::BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
default:
MS_LOG(ERROR) << "Invalid flag for attr: " << flag;
return lite::RET_ERROR;
}
cnode->set_inputs(inputs);
return lite::RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_DEPARSER_H
#include <string>
#include <memory>
#include "base/base.h"
#include "include/errorcode.h"
#include "ir/anf.h"
namespace mindspore {
namespace lite {
class PrimitiveDeparser {
public:
explicit PrimitiveDeparser(const std::string &name) : name_(name) {}
virtual ~PrimitiveDeparser() = default;
virtual STATUS Deparser(const CNodePtr &cnode);
protected:
STATUS AttrAdjust(const PrimitivePtr &prim, const std::string &name);
STATUS MoveAttrMap(const CNodePtr &cnode, const PrimitivePtr &dst_prim);
STATUS GetValueNodeAndPrimFromCnode(const CNodePtr &cnode, ValueNodePtr *value_node, PrimitivePtr *prim_ptr);
STATUS AdjustPoolAttr(int fmk_type, const std::string &src_prim_name, const PrimitivePtr &dst_prim);
STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const PrimitivePtr &dst_prim,
const std::string &attr_name, int flag);
private:
void AdjustCaffePoolAttr(const std::string &src_prim_name, const PrimitivePtr &dst_prim);
void AdjustOnnxPoolAttr(const PrimitivePtr &dst_prim);
std::string name_;
};
using PrimitiveDeparserPtr = std::shared_ptr<PrimitiveDeparser>;
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_DEPARSER_H

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace lite {
PrimitiveDeparserRegister &PrimitiveDeparserRegister::GetInstance() {
static PrimitiveDeparserRegister instance;
return instance;
}
void PrimitiveDeparserRegister::InsertPrimitiveDeparser(const std::string &name, const PrimitiveDeparserPtr &deparser) {
deparser_[name] = deparser;
}
PrimitiveDeparserPtr PrimitiveDeparserRegister::GetPrimitiveDeparser(const std::string &name) {
if (deparser_.find(name) != deparser_.end()) {
return deparser_[name];
} else {
MS_LOG(DEBUG) << "Unsupported primitive name : " << name;
return nullptr;
}
}
RegisterPrimitiveDeparser::RegisterPrimitiveDeparser(const std::string &name, const PrimitiveDeparserPtr &deparser) {
PrimitiveDeparserRegister::GetInstance().InsertPrimitiveDeparser(name, deparser);
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_DEPARSER_REGISTER_H
#define ACL_DEPARSER_PRIMITIVE_DEPARSER_REGISTER_H
#include <map>
#include <memory>
#include <string>
#include "ir/anf.h"
#include "tools/converter/acl/deparser/primitive_deparser.h"
namespace mindspore {
namespace lite {
class PrimitiveDeparserRegister {
public:
static PrimitiveDeparserRegister &GetInstance();
void InsertPrimitiveDeparser(const std::string &name, const PrimitiveDeparserPtr &deparser);
PrimitiveDeparserPtr GetPrimitiveDeparser(const std::string &name);
private:
PrimitiveDeparserRegister() = default;
~PrimitiveDeparserRegister() = default;
std::map<std::string, PrimitiveDeparserPtr> deparser_;
};
class RegisterPrimitiveDeparser {
public:
RegisterPrimitiveDeparser(const std::string &name, const PrimitiveDeparserPtr &deparser);
~RegisterPrimitiveDeparser() = default;
};
#define REGISTER_PRIMITIVE_DEPARSER(name, deparser) \
static RegisterPrimitiveDeparser g_##name##PrimDeparser(name, std::make_shared<deparser>());
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_DEPARSER_REGISTER_H

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/scale_fusion_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
namespace mindspore {
namespace lite {
STATUS ScaleFusionDeparser::Deparser(const CNodePtr &cnode) {
auto dst_prim = std::make_shared<ops::Scale>();
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
MS_LOG(ERROR) << "ScaleFusion deparser failed.";
return RET_ERROR;
}
return RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameScaleFusion, ScaleFusionDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_SCALEFUSION_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_SCALEFUSION_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/fusion/scale_fusion.h"
using mindspore::ops::kNameScaleFusion;
namespace mindspore {
namespace lite {
class ScaleFusionDeparser : public PrimitiveDeparser {
public:
ScaleFusionDeparser() : PrimitiveDeparser(kNameScaleFusion) {}
~ScaleFusionDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_SCALEFUSION_DEPARSER_H

View File

@ -0,0 +1,152 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/spatial_node_adapter.h"
#include <vector>
#include <set>
#include <memory>
#include <string>
#include "tools/converter/acl/common/utils.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/common/tensor_util.h"
#include "include/errorcode.h"
#include "base/base.h"
#include "base/core_ops.h"
#include "utils/log_adapter.h"
#include "ops/concat.h"
#include "ops/batch_norm.h"
#include "ops/fused_batch_norm.h"
#include "ops/stack.h"
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kCnodeInputMinNum = 2;
constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kNamewiEltwise = "Eltwise";
const std::set<std::string> kCNodeWithMultiOutputs = {ops::kNameBatchNorm, ops::kNameFusedBatchNorm};
const std::set<std::string> kCNodeWithDynamicInput = {kNamewiEltwise, ops::kNameConcat, ops::kNameStack};
} // namespace
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode) {
CNodePtr get_item_cnode = nullptr;
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "New TupleGetItem failed";
return nullptr;
}
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
auto get_item_value = NewValueNode(MakeValue<int64_t>(0));
AnfNodePtrList inputs{tuple_get_item_prim, input_cnode, get_item_value};
get_item_cnode = func_graph->NewCNode(inputs);
if (get_item_cnode == nullptr) {
MS_LOG(ERROR) << "New get item cnode failed.";
return nullptr;
}
std::vector<int64_t> shape;
if (acl::GetShapeVectorFromCNode(input_cnode, &shape) != lite::RET_OK) {
MS_LOG(ERROR) << "Get shape failed.";
return nullptr;
}
TypeId type = acl::GetTypeFromNode(input_cnode);
auto get_item_abstract = CreateTensorAbstract(shape, type);
if (get_item_abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstract failed.";
return nullptr;
}
get_item_cnode->set_abstract(get_item_abstract);
get_item_cnode->set_fullname_with_scope(input_cnode->fullname_with_scope() + "_getitem");
return get_item_cnode;
}
static STATUS AdapteNodeWithMultiOutputs(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const FuncGraphManagerPtr &manager) {
std::string cnode_func_name = GetCNodeFuncName(cnode);
if (cnode_func_name == prim::kTupleGetItem || cnode_func_name == kNameReturn) {
return lite::RET_OK;
}
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto input = cnode->input(i);
if (!utils::isa<CNode>(input)) {
continue;
}
auto input_cnode = input->cast<CNodePtr>();
std::string input_func_name = GetCNodeFuncName(input_cnode);
if (kCNodeWithMultiOutputs.find(input_func_name) != kCNodeWithMultiOutputs.end()) {
MS_LOG(INFO) << "Adapter cnode with multioutputs: " << cnode_func_name;
CNodePtr get_item_cnode = CreateTupleGetItemNode(func_graph, input_cnode);
if (get_item_cnode == nullptr) {
MS_LOG(ERROR) << "Create tuple item for " << cnode_func_name << " failed.";
return lite::RET_ERROR;
}
if (!manager->Replace(input_cnode, get_item_cnode)) {
MS_LOG(ERROR) << "Replace " << cnode_func_name << " failed.";
return lite::RET_ERROR;
}
}
}
return lite::RET_OK;
}
static STATUS AdapteNodeWithDynamicInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
std::string cnode_func_name = GetCNodeFuncName(cnode);
if (kCNodeWithDynamicInput.find(cnode_func_name) == kCNodeWithDynamicInput.end()) {
return lite::RET_OK;
}
MS_LOG(INFO) << "Adapter cnode with dynamic input: " << cnode_func_name;
auto make_tuple_val_node = NewValueNode(prim::kPrimMakeTuple);
if (make_tuple_val_node == nullptr) {
MS_LOG(ERROR) << "New make tuple val node failed.";
return lite::RET_ERROR;
}
AnfNodePtrList new_inputs = {make_tuple_val_node};
auto cnode_inputs = cnode->inputs();
if (cnode_inputs.size() >= kCnodeInputMinNum) {
new_inputs.insert(new_inputs.end(), cnode_inputs.begin() + 1, cnode_inputs.end());
}
auto make_tuple_cnode = func_graph->NewCNode(new_inputs);
if (make_tuple_cnode == nullptr) {
MS_LOG(ERROR) << "New make tuple cnode failed.";
return lite::RET_ERROR;
}
const std::vector<AnfNodePtr> replace_node = {cnode_inputs[0], make_tuple_cnode};
cnode->set_inputs(replace_node);
return lite::RET_OK;
}
STATUS AdapteSpatialNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
auto cnodes = func_graph->GetOrderedCnodes();
for (const auto &cnode : cnodes) {
if (cnode == nullptr) {
MS_LOG(ERROR) << "Cnode is nullptr.";
return lite::RET_ERROR;
}
if (AdapteNodeWithMultiOutputs(func_graph, cnode, manager) != lite::RET_OK) {
MS_LOG(ERROR) << "Adapter node with multioutput failed.";
return lite::RET_ERROR;
}
if (AdapteNodeWithDynamicInput(func_graph, cnode) != lite::RET_OK) {
MS_LOG(ERROR) << "Adapter node with dynamic input failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,29 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_SPATIAL_NODE_ADAPTER_PASS_H
#define ACL_DEPARSER_SPATIAL_NODE_ADAPTER_PASS_H
#include "ir/func_graph.h"
#include "include/errorcode.h"
namespace mindspore {
namespace lite {
STATUS AdapteSpatialNode(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager);
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_SPATIAL_NODE_ADAPTER_PASS_H

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/stack_deparser.h"
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
namespace mindspore {
namespace lite {
namespace {
constexpr auto kNameNum = "num";
}
STATUS StackDeparser::Deparser(const CNodePtr &cnode) {
if (AddAttrForDynInputPrimitive(cnode) != RET_OK) {
MS_LOG(ERROR) << "Stack deparser failed.";
return RET_ERROR;
}
return RET_OK;
}
STATUS StackDeparser::AddAttrForDynInputPrimitive(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
MS_ASSERT(value_node != nullptr);
auto prim = GetValueNode<PrimitivePtr>(value_node);
if (prim == nullptr) {
MS_LOG(ERROR) << "Value node is invalid.";
return lite::RET_ERROR;
}
// add attr input num for dynamic input op
int64_t num = static_cast<int64_t>(cnode->size());
if (num > 1) {
prim->AddAttr(kNameNum, MakeValue(num - 1));
}
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameStack, StackDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_STACK_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_STACK_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/stack.h"
using mindspore::ops::kNameStack;
namespace mindspore {
namespace lite {
class StackDeparser : public PrimitiveDeparser {
public:
StackDeparser() : PrimitiveDeparser(kNameStack) {}
~StackDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
private:
STATUS AddAttrForDynInputPrimitive(const CNodePtr &cnode);
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_STACK_DEPARSER_H

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/acl/deparser/stridedslice_deparser.h"
#include <memory>
#include "tools/converter/acl/deparser/primitive_deparser_register.h"
#include "tools/converter/acl/deparser/tbe_op_def.h"
#include "include/registry/parser_context.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace lite {
STATUS StridedSliceDeparser::Deparser(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
return lite::RET_ERROR;
}
auto attr_val = src_prim->GetAttr(ops::kFmkType);
int fmk_type = attr_val != nullptr ? GetValue<int>(attr_val) : converter::kFmkTypeTf;
if (fmk_type == converter::kFmkTypeOnnx) {
auto dst_prim = std::make_shared<acl::StridedSliceV2>();
MS_ASSERT(dst_prim != nullptr);
dst_prim->SetAttrs(src_prim->attrs());
value_node->set_value(dst_prim);
}
return lite::RET_OK;
}
REGISTER_PRIMITIVE_DEPARSER(kNameStridedSlice, StridedSliceDeparser)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ACL_DEPARSER_PRIMITIVE_STRIDEDSLICE_DEPARSER_H
#define ACL_DEPARSER_PRIMITIVE_STRIDEDSLICE_DEPARSER_H
#include "tools/converter/acl/deparser/primitive_deparser.h"
#include "ops/strided_slice.h"
using mindspore::ops::kNameStridedSlice;
namespace mindspore {
namespace lite {
class StridedSliceDeparser : public PrimitiveDeparser {
public:
StridedSliceDeparser() : PrimitiveDeparser(kNameStridedSlice) {}
~StridedSliceDeparser() override = default;
STATUS Deparser(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // ACL_DEPARSER_PRIMITIVE_STRIDEDSLICE_DEPARSER_H

Some files were not shown because too many files have changed in this diff Show More