support cloud infer package

zhengyuanhua 2022-07-11 11:17:12 +08:00
parent 1cbaa6e2d8
commit d0ff3a0137
43 changed files with 1623 additions and 324 deletions

View File

@@ -21,6 +21,7 @@ set(TURBO_DIR ${RUNTIME_PKG_NAME}/runtime/third_party/libjpeg-turbo)
set(GLOG_DIR ${RUNTIME_PKG_NAME}/runtime/third_party/glog)
set(SECUREC_DIR ${RUNTIME_PKG_NAME}/runtime/third_party/securec)
set(MINDSPORE_LITE_LIB_NAME libmindspore-lite)
set(MINDSPORE_LITE_EXTENDRT_LIB_NAME libmindspore-extendrt)
set(MINDSPORE_CORE_LIB_NAME libmindspore_core)
set(BENCHMARK_NAME benchmark)
set(MSLITE_NNIE_LIB_NAME libmslite_nnie)
@@ -403,10 +404,22 @@ if(PLATFORM_ARM64)
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/schema/ops_types_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/${MINDSPORE_LITE_EXTENDRT_LIB_NAME}.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
else()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(ENABLE_MODEL_OBF)
install(FILES ${TOP_DIR}/mindspore/lite/tools/obfuscator/lib/android-aarch64/libmsdeobfuscator-lite.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
@@ -530,46 +543,9 @@ if(PLATFORM_ARM64)
install(TARGETS mindspore_core DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "cloud" AND MSLITE_ENABLE_RUNTIME_CONVERT)
file(GLOB DATA_ENGINE_LIB_LIST ${LITE_ACL_DIR}/_c_dataengine/*.so)
file(GLOB DATA_RECORD_LIB_LIST ${LITE_ACL_DIR}/_c_mindrecord/*.so)
install(FILES ${DATA_ENGINE_LIB_LIST}
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${DATA_RECORD_LIB_LIST}
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${jpeg_turbo_LIBPATH}/libjpeg.so.62.3.0
DESTINATION ${RUNTIME_LIB_DIR} RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${jpeg_turbo_LIBPATH}/libturbojpeg.so.0.2.0
DESTINATION ${RUNTIME_LIB_DIR} RENAME libturbojpeg.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${tinyxml2_LIBPATH}/libtinyxml2.so.8.0.0
DESTINATION ${RUNTIME_LIB_DIR} RENAME libtinyxml2.so.8 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${icu4c_LIBPATH}/libicuuc.so.69.1
DESTINATION ${RUNTIME_LIB_DIR} RENAME libicuuc.so.69 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${icu4c_LIBPATH}/libicudata.so.69.1
DESTINATION ${RUNTIME_LIB_DIR} RENAME libicudata.so.69 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${icu4c_LIBPATH}/libicui18n.so.69.1
DESTINATION ${RUNTIME_LIB_DIR} RENAME libicui18n.so.69 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_grpc++.so.1.36.1 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_grpc++.so.1 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_grpc.so.15.0.0 DESTINATION
${RUNTIME_LIB_DIR} RENAME libmindspore_grpc.so.15 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_gpr.so.15.0.0 DESTINATION
${RUNTIME_LIB_DIR} RENAME libmindspore_gpr.so.15 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_upb.so.15.0.0 DESTINATION
${RUNTIME_LIB_DIR} RENAME libmindspore_upb.so.15 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_address_sorting.so.15.0.0 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_address_sorting.so.15 COMPONENT ${RUNTIME_COMPONENT_NAME})
## Public header files for minddata
install(
FILES ${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/config.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/execute.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/text.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/transforms.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/vision_lite.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/vision_ascend.h
DESTINATION ${RUNTIME_INC_DIR}/dataset COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
install(FILES ${LITE_ACL_DIR}/libascend_pass_plugin.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
@@ -669,10 +645,22 @@ elseif(PLATFORM_ARM32)
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/schema/ops_types_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/${MINDSPORE_LITE_EXTENDRT_LIB_NAME}.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
else()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(ENABLE_MODEL_OBF)
install(FILES ${TOP_DIR}/mindspore/lite/tools/obfuscator/lib/android-aarch32/libmsdeobfuscator-lite.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
@@ -835,14 +823,21 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/${MINDSPORE_LITE_EXTENDRT_LIB_NAME}.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
else()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(ENABLE_MODEL_OBF)
install(FILES ${TOP_DIR}/mindspore/lite/tools/obfuscator/bin/linux-x64/msobfuscator
@@ -918,46 +913,9 @@ else()
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "cloud" AND MSLITE_ENABLE_RUNTIME_CONVERT)
file(GLOB DATA_ENGINE_LIB_LIST ${LITE_ACL_DIR}/_c_dataengine/*.so)
file(GLOB DATA_RECORD_LIB_LIST ${LITE_ACL_DIR}/_c_mindrecord/*.so)
install(FILES ${DATA_ENGINE_LIB_LIST}
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${DATA_RECORD_LIB_LIST}
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${jpeg_turbo_LIBPATH}/libjpeg.so.62.3.0
DESTINATION ${RUNTIME_LIB_DIR} RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${jpeg_turbo_LIBPATH}/libturbojpeg.so.0.2.0
DESTINATION ${RUNTIME_LIB_DIR} RENAME libturbojpeg.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${tinyxml2_LIBPATH}/libtinyxml2.so.8.0.0
DESTINATION ${RUNTIME_LIB_DIR} RENAME libtinyxml2.so.8 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${icu4c_LIBPATH}/libicuuc.so.69.1
DESTINATION ${RUNTIME_LIB_DIR} RENAME libicuuc.so.69 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${icu4c_LIBPATH}/libicudata.so.69.1
DESTINATION ${RUNTIME_LIB_DIR} RENAME libicudata.so.69 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${icu4c_LIBPATH}/libicui18n.so.69.1
DESTINATION ${RUNTIME_LIB_DIR} RENAME libicui18n.so.69 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_grpc++.so.1.36.1 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_grpc++.so.1 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_grpc.so.15.0.0 DESTINATION
${RUNTIME_LIB_DIR} RENAME libmindspore_grpc.so.15 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_gpr.so.15.0.0 DESTINATION
${RUNTIME_LIB_DIR} RENAME libmindspore_gpr.so.15 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_upb.so.15.0.0 DESTINATION
${RUNTIME_LIB_DIR} RENAME libmindspore_upb.so.15 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${grpc_LIBPATH}/libmindspore_address_sorting.so.15.0.0 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_address_sorting.so.15 COMPONENT ${RUNTIME_COMPONENT_NAME})
## Public header files for minddata
install(
FILES ${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/config.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/execute.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/text.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/transforms.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/vision_lite.h
${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/dataset/vision_ascend.h
DESTINATION ${RUNTIME_INC_DIR}/dataset COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
install(FILES ${LITE_ACL_DIR}/libascend_pass_plugin.so DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
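For illustration: with MSLITE_ENABLE_CLOUD_FUSION_INFERENCE on, the package ships libmindspore-extendrt.so (plus libmindspore_glog and mindspore_core) in place of libmindspore-lite, and adds libascend_kernel_plugin.so only under MSLITE_ENABLE_ACL. A minimal smoke-test sketch, assuming the RUNTIME_LIB_DIR layout above; the paths are illustrative, not part of this diff:

#include <dlfcn.h>
#include <cstdio>

int main() {
  // The cloud package installs libmindspore-extendrt.so instead of libmindspore-lite.
  void *rt = dlopen("runtime/lib/libmindspore-extendrt.so", RTLD_NOW | RTLD_GLOBAL);
  if (rt == nullptr) {
    std::printf("extendrt not loadable: %s\n", dlerror());
    return 1;
  }
  // libascend_kernel_plugin.so is packaged only when MSLITE_ENABLE_ACL is set,
  // so a failure here is expected on non-Ascend packages.
  void *plugin = dlopen("runtime/lib/libascend_kernel_plugin.so", RTLD_NOW);
  std::printf("ascend plugin: %s\n", plugin != nullptr ? "ok" : dlerror());
  return 0;
}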

View File

@@ -170,6 +170,8 @@ if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
if((${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND PLATFORM_X86_64)
OR((PLATFORM_ARM64 OR PLATFORM_ARM32) AND ANDROID_NDK_TOOLCHAIN_INCLUDED))
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
elseif(MSLITE_ENABLE_ACL)
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
else()
set(MSLITE_ENABLE_MODEL_ENCRYPTION OFF)
endif()
@@ -204,6 +206,7 @@ if(DEFINED ENV{MSLITE_ENABLE_CLOUD_FUSION_INFERENCE})
endif()
if(MSLITE_ENABLE_ACL AND MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
add_compile_definitions(ENABLE_CLOUD_FUSION_INFERENCE)
set(PLATFORM_ARM64 off)
set(PLATFORM_ARM32 off)
set(MSLITE_ENABLE_FP16 off)

View File

@@ -77,6 +77,9 @@ class MS_API Converter {
void SetNoFusion(bool no_fusion);
bool GetNoFusion();
void SetDevice(const std::string &device);
std::string GetDevice();
Status Convert();
void *Convert(size_t *data_size);
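The new SetDevice/GetDevice pair lets callers pin the converter to a target device before Convert() runs. A minimal usage sketch, assuming the existing Converter constructor taking a framework type, model file, and output file; the file names are examples:

#include "include/converter.h"

int main() {
  mindspore::Converter converter(mindspore::converter::kFmkTypeOnnx, "model.onnx", "model_out");
  converter.SetDevice("Ascend310");  // new API from this commit
  auto status = converter.Convert();
  return status == mindspore::kSuccess ? 0 : 1;
}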

View File

@@ -679,7 +679,7 @@ if(ENABLE_MODEL_OBF)
target_link_libraries(mindspore-lite_static ${OBF_LIB_DIR}/libmsdeobfuscator-lite.so)
endif()
if(MSLITE_ENABLE_ACL)
if(MSLITE_ENABLE_ACL AND (NOT MSLITE_ENABLE_CLOUD_FUSION_INFERENCE))
target_link_libraries(mindspore-lite ascend_kernel_mid)
target_link_libraries(mindspore-lite_static ascend_kernel_mid)
endif()

View File

@@ -13,8 +13,6 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
add_compile_definitions(USE_GLOG)
string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE "-fno-exceptions" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
string(REPLACE "-fno-exceptions" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
add_compile_definitions(ENABLE_CLOUD_FUSION_INFERENCE)
remove_definitions(-DBUILD_LITE_INFERENCE)
set(MINDIR_MODEL_SRC
@@ -30,10 +28,16 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
${MINDIR_KERNEL_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/mindir_model/inner_kernel.cc)
set(MSLITE_KERNEL_PLUGIN
${MSLITE_KERNEL_PLUGIN}
${CMAKE_CURRENT_SOURCE_DIR}/kernel/ascend/plugin/ascend_kernel_plugin.cc)
set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC}
# ${MINDIR_MODEL_SRC}
# ${MINDIR_KERNEL_SRC}
${MSLITE_KERNEL_PLUGIN}
${CMAKE_CURRENT_SOURCE_DIR}/kernel/cpu/less_test_kernel_mod.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel/cpu/transpose_kernel_mod.cc
${CMAKE_CURRENT_SOURCE_DIR}/infer_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/single_op_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/infer_device_address.cc
@@ -92,6 +96,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
# ${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc
# ${CCSRC_DIR}/backend/common/optimizer/visit.cc
# ${CCSRC_DIR}/backend/common/optimizer/common_backend_optimization.cc
${CCSRC_DIR}/runtime/device/auto_mem_offload.cc
${CCSRC_DIR}/runtime/device/ms_device_shape_transfer.cc
${CCSRC_DIR}/runtime/device/kernel_info.cc
${CCSRC_DIR}/runtime/device/convert_tensor_utils.cc
@@ -168,8 +173,6 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
if(MSLITE_ENABLE_ACL)
include_directories(${TOP_DIR}/graphengine/inc/external)
add_subdirectory(kernel/ascend)
link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
target_link_libraries(mindspore-extendrt ascend_kernel_mid)
endif()
if(SUPPORT_CUDA)

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,52 +13,47 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/context.h"
#include <any>
#include <map>
#include <type_traits>
#include "extendrt/factory.h"
#include "utils/log_adapter.h"
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
constexpr auto kModelOptionGPUDeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode";
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
constexpr auto kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
constexpr auto kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
constexpr auto kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
constexpr auto kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type";
constexpr auto kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
constexpr auto kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
constexpr auto KModelOptionAscend310FusionSwitchCfgPath = "mindspore.option.ascend310.fusion_switch_config_file_path";
constexpr auto kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
constexpr auto kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";
#include "src/runtime/cxx_api/context.h"
#include <string>
#include <memory>
#include "include/api/types.h"
#include "include/api/data_type.h"
#include "include/lite_types.h"
#include "src/runtime/inner_allocator.h"
#include "src/common/log_adapter.h"
#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h"
namespace mindspore {
class Allocator {};
struct Context::Data {
std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
int32_t thread_num;
bool enable_parallel_ = false;
std::vector<int32_t> affinity_core_list_;
int affinity_mode_ = 2;
};
struct DeviceInfoContext::Data {
std::map<std::string, std::any> params;
};
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
constexpr auto kModelOptionGPUEnableGLTexture = "mindspore.option.gpu.enable_gl_texture_";
constexpr auto kModelOptionGPUGLContext = "mindspore.option.gpu.gl_context_";
constexpr auto kModelOptionGPUGLDisplay = "mindspore.option.gpu.gl_display_";
constexpr auto kModelOptionGPUDeviceID = "mindspore.option.gpu.device_id";
constexpr auto kModelOptionGPURankID = "mindspore.option.gpu.rank_id";
constexpr auto kModelOptionGPUGroupSize = "mindspore.option.gpu.group_size";
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr auto kModelOptionProvider = "mindspore.option.provider";
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
constexpr auto kModelOptionAscendDeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscendInsertOpCfgPath = "mindspore.option.ascend.insert_op_config_file_path";
constexpr auto kModelOptionAscendInputFormat = "mindspore.option.ascend.input_format";
constexpr auto kModelOptionAscendInputShapeMap = "mindspore.option.ascend.input_shape_map";
constexpr auto kModelOptionAscendInputShape = "mindspore.option.ascend.input_shape";
constexpr auto kModelOptionAscendOutputType = "mindspore.option.ascend.output_type";
constexpr auto kModelOptionAscendPrecisionMode = "mindspore.option.ascend.precision_mode";
constexpr auto kModelOptionAscendOpSelectImplMode = "mindspore.option.ascend.op_select_impl_mode";
constexpr auto KModelOptionAscendFusionSwitchCfgPath = "mindspore.option.ascend.fusion_switch_config_file_path";
constexpr auto kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dynamic_batch_size";
constexpr auto kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size";
constexpr auto kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize";
Context::Context() : data_(std::make_shared<Data>()) {}
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, const std::string &key) {
static const U empty_result{};
static U empty_result;
if (data == nullptr) {
return empty_result;
}
@@ -66,230 +61,529 @@ static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, c
if (iter == data->params.end()) {
return empty_result;
}
#ifndef SUPPORT_NNIE
const std::any &value = iter->second;
if (value.type() != typeid(U)) {
return empty_result;
}
return std::any_cast<const U &>(value);
#else
const std::experimental::any &value = iter->second;
return std::experimental::any_cast<const U &>(value);
#endif
}
void Context::SetThreadNum(int32_t thread_num) {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->thread_num = thread_num;
}
void Context::SetInterOpParallelNum(int32_t parallel_num) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->inter_op_parallel_num_ = parallel_num;
}
int32_t Context::GetInterOpParallelNum() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return 0;
}
return data_->inter_op_parallel_num_;
}
int32_t Context::GetThreadNum() const {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return 0;
}
return data_->thread_num;
}
void Context::SetEnableParallel(bool is_parallel) {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->enable_parallel_ = is_parallel;
}
bool Context::GetEnableParallel() const {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return false;
}
return data_->enable_parallel_;
}
void Context::SetThreadAffinity(int mode) {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
if (mode < lite::NO_BIND || mode > lite::MID_CPU) {
MS_LOG(WARNING) << "Invalid thread affinity mode: " << mode << ", change to NO_BIND mode.";
data_->affinity_mode_ = lite::NO_BIND;
return;
}
data_->affinity_mode_ = mode;
return;
}
int Context::GetThreadAffinityMode() const {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return -1;
}
return data_->affinity_mode_;
}
void Context::SetThreadAffinity(const std::vector<int> &core_list) {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->affinity_core_list_ = core_list;
return;
}
std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return {};
}
return data_->affinity_core_list_;
}
void Context::SetDelegate(const std::shared_ptr<Delegate> &delegate) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->delegate = delegate;
}
std::shared_ptr<Delegate> Context::GetDelegate() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return nullptr;
}
return data_->delegate;
}
void Context::SetMultiModalHW(bool float_mode) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->float_mode = float_mode;
}
bool Context::GetMultiModalHW() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return false;
}
return data_->float_mode;
}
std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
MS_EXCEPTION_IF_NULL(data_);
static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return empty;
}
return data_->device_info_list;
}
DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared<Data>()) {}
std::vector<char> DeviceInfoContext::GetProviderChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionProvider);
return StringToChar(ref);
}
void DeviceInfoContext::SetProvider(const std::vector<char> &provider) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionProvider] = CharToString(provider);
}
std::vector<char> DeviceInfoContext::GetProviderDeviceChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionProviderDevice);
return StringToChar(ref);
}
void DeviceInfoContext::SetProviderDevice(const std::vector<char> &device) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionProviderDevice] = CharToString(device);
}
void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->allocator = allocator;
}
std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return nullptr;
}
return data_->allocator;
}
void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionCpuEnableFP16] = is_fp16;
}
bool CPUDeviceInfo::GetEnableFP16() const {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return false;
}
return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
}
void GPUDeviceInfo::SetEnableFP16(bool is_fp16) {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionGPUEnableFP16] = is_fp16;
}
bool GPUDeviceInfo::GetEnableFP16() const {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return false;
}
return GetValue<bool>(data_, kModelOptionGPUEnableFP16);
}
void GPUDeviceInfo::SetEnableGLTexture(bool is_enable_gl_texture) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionGPUEnableGLTexture] = is_enable_gl_texture;
}
bool GPUDeviceInfo::GetEnableGLTexture() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return false;
}
return GetValue<bool>(data_, kModelOptionGPUEnableGLTexture);
}
void GPUDeviceInfo::SetGLContext(void *gl_context) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionGPUGLContext] = gl_context;
}
void *GPUDeviceInfo::GetGLContext() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return nullptr;
}
return GetValue<void *>(data_, kModelOptionGPUGLContext);
}
void GPUDeviceInfo::SetGLDisplay(void *gl_display) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionGPUGLDisplay] = gl_display;
}
void *GPUDeviceInfo::GetGLDisplay() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return nullptr;
}
return GetValue<void *>(data_, kModelOptionGPUGLDisplay);
}
void KirinNPUDeviceInfo::SetFrequency(int frequency) {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionKirinNpuFrequency] = frequency;
}
int KirinNPUDeviceInfo::GetFrequency() const {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return 0;
}
return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
}
void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionGPUDeviceID] = device_id;
}
uint32_t GPUDeviceInfo::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(data_);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return 0;
}
return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
}
int GPUDeviceInfo::GetRankID() const {
MS_LOG(ERROR) << "Unsupported Feature.";
return 0;
data_->params[kModelOptionGPURankID] = lite::GetRankID();
return GetValue<int>(data_, kModelOptionGPURankID);
}
int GPUDeviceInfo::GetGroupSize() const {
MS_LOG(ERROR) << "Unsupported Feature.";
return 0;
data_->params[kModelOptionGPUGroupSize] = lite::GetGPUGroupSize();
return GetValue<int>(data_, kModelOptionGPUGroupSize);
}
void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionGPUPrecisionMode] = CharToString(precision_mode);
MS_LOG(ERROR) << "Unsupported Feature.";
}
std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionGPUPrecisionMode);
return StringToChar(ref);
MS_LOG(ERROR) << "Unsupported Feature.";
std::vector<char> ret;
return ret;
}
void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310DeviceID] = device_id;
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendDeviceID] = device_id;
}
uint32_t AscendDeviceInfo::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return 0;
}
return GetValue<uint32_t>(data_, kModelOptionAscendDeviceID);
}
void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendInsertOpCfgPath] = CharToString(cfg_path);
}
std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInsertOpCfgPath);
return StringToChar(ref);
}
void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendInputFormat] = CharToString(format);
}
std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInputFormat);
return StringToChar(ref);
}
void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendInputShape] = CharToString(shape);
}
std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInputShape);
return StringToChar(ref);
}
void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
MS_EXCEPTION_IF_NULL(data_);
std::string batchs = "";
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
std::string batchs;
for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
if (i != 0) {
batchs.push_back(',');
}
batchs += std::to_string(dynamic_batch_size[i]);
}
data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
data_->params[kModelOptionAscendDynamicBatchSize] = batchs;
}
std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendDynamicBatchSize);
return StringToChar(ref);
}
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &) { return; }
void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendDynamicImageSize] = CharToString(dynamic_image_size);
}
std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const { return std::vector<char>(); }
std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendDynamicImageSize);
return StringToChar(ref);
}
void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendPrecisionMode] = CharToString(precision_mode);
}
std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendPrecisionMode);
return StringToChar(ref);
}
void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendOpSelectImplMode] = CharToString(op_select_impl_mode);
}
std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendOpSelectImplMode);
return StringToChar(ref);
}
void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[KModelOptionAscendFusionSwitchCfgPath] = CharToString(cfg_path);
}
std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, KModelOptionAscendFusionSwitchCfgPath);
return StringToChar(ref);
}
void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310InputShapeMap] = shape;
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendInputShapeMap] = shape;
}
std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::map<int, std::vector<int>>();
}
return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscendInputShapeMap);
}
void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310OutputType] = output_type;
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendOutputType] = output_type;
}
enum DataType AscendDeviceInfo::GetOutputType() const {
MS_EXCEPTION_IF_NULL(data_);
return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return DataType::kTypeUnknown;
}
return GetValue<enum DataType>(data_, kModelOptionAscendOutputType);
}
void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return;
}
data_->params[kModelOptionAscendBufferOptimize] = CharToString(buffer_optimize_mode);
}
std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
if (data_ == nullptr) {
MS_LOG(ERROR) << "Invalid context.";
return std::vector<char>();
}
const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendBufferOptimize);
return StringToChar(ref);
}
} // namespace mindspore
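The rewrite above replaces MS_EXCEPTION_IF_NULL with logged null-checks and moves the Ascend option keys from "mindspore.option.ascend310.*" to the device-generic "mindspore.option.ascend.*". A minimal configuration sketch against this API; the precision-mode string is an example value:

#include <memory>
#include "include/api/context.h"

int main() {
  auto context = std::make_shared<mindspore::Context>();
  auto ascend_info = std::make_shared<mindspore::AscendDeviceInfo>();
  ascend_info->SetDeviceID(0);
  ascend_info->SetPrecisionMode("force_fp16");  // stored under mindspore.option.ascend.precision_mode
  context->MutableDeviceInfo().push_back(ascend_info);
  return 0;
}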

View File

@@ -22,31 +22,31 @@
#include <string>
#include <fstream>
#include "utils/file_utils.h"
#include "include/api/status.h"
namespace mindspore {
inline Status DLSoPath(std::string *so_path) {
if (so_path == nullptr) {
inline Status DLSoPath(const std::string &benchmark_so, const std::string &target_so, std::string *target_so_path) {
if (target_so_path == nullptr) {
return Status(kMEFailed, "Input so_path can not be nullptr.");
}
Dl_info dl_info;
dladdr(reinterpret_cast<void *>(DLSoPath), &dl_info);
std::string libmindspore_so = dl_info.dli_fname;
std::string cur_so_path = dl_info.dli_fname;
auto pos = libmindspore_so.find("libmindspore.so");
auto pos = cur_so_path.find(benchmark_so);
if (pos == std::string::npos) {
return Status(kMEFailed, "Could not find libmindspore.so, check path.");
return Status(kMEFailed, "Could not find benchmark so " + benchmark_so + " check path.");
}
std::string parent_dir = libmindspore_so.substr(0, pos) + "../";
std::string c_dataengine_so;
std::string parent_dir = cur_so_path.substr(0, pos);
std::string found_target_so;
DIR *dir = opendir(parent_dir.c_str());
if (dir != nullptr) {
// access all the files and directories within directory
dirent *ent = readdir(dir);
while (ent != nullptr) {
if (std::string(ent->d_name).find("_c_dataengine") != std::string::npos) {
c_dataengine_so = std::string(ent->d_name);
if (std::string(ent->d_name).find(target_so) != std::string::npos) {
found_target_so = std::string(ent->d_name);
break;
}
ent = readdir(dir);
@@ -55,14 +55,17 @@ inline Status DLSoPath(std::string *so_path) {
} else {
return Status(kMEFailed, "Could not open directory: " + parent_dir);
}
std::string unreal_path = parent_dir + c_dataengine_so;
if (found_target_so.empty()) {
MS_LOG(WARNING) << target_so << " does not exist in dir " << parent_dir;
return kSuccess;
}
std::string unreal_path = parent_dir + found_target_so;
auto realpath = FileUtils::GetRealPath(unreal_path.c_str());
if (!realpath.has_value()) {
return Status(kMEFailed, "Get c_dataengine_so real path failed, path: " + unreal_path);
return Status(kMEFailed, "Get target so " + target_so + " real path failed, path: " + unreal_path);
}
*so_path = realpath.value();
*target_so_path = realpath.value();
return kSuccess;
}
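DLSoPath is generalized here: the anchor so and target so are now caller-supplied instead of being hard-coded to libmindspore.so and _c_dataengine, and a missing target yields a warning plus kSuccess with an empty path. Callers therefore check both the status and the path, as in this sketch; the so names are taken from call sites in this commit:

#include <string>
#include "extendrt/cxx_api/dlutils.h"

mindspore::Status LoadAscendPlugin() {
  std::string plugin_path;
  auto ret = mindspore::DLSoPath("libmindspore-extendrt.so", "libascend_kernel_plugin.so", &plugin_path);
  if (ret != mindspore::kSuccess || plugin_path.empty()) {
    // Anchor so not found, or the plugin simply is not packaged.
    return ret;
  }
  // ... dlopen(plugin_path) as ascend_kernel_plugin.cc does below.
  return mindspore::kSuccess;
}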

View File

@@ -60,6 +60,13 @@ Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
return kSuccess;
}
// TODO: for now this overload only adapts the benchmark tool
Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode,
const std::vector<char> &cropto_lib_path) {
return Build(model_path, model_type, model_context);
}
Status Model::Build(GraphCell graph, const std::shared_ptr<Context> &model_context,
const std::shared_ptr<TrainCfg> &train_cfg) {
MS_LOG(ERROR) << "Unsupported Feature.";
@@ -152,4 +159,16 @@ MSTensor Model::GetOutputByTensorName(const std::vector<char> &name) {
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_name) {
return std::vector<MSTensor>();
}
Status Model::BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
std::map<std::string, unsigned int> *outputGLTexture) {
return kSuccess;
}
Status Model::LoadConfig(const std::vector<char> &config_path) { return kSuccess; }
Status Model::UpdateConfig(const std::vector<char> &section,
const std::pair<std::vector<char>, std::vector<char>> &config) {
return kSuccess;
}
} // namespace mindspore
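Per the TODO above, the new Build overload accepts decryption arguments only for interface compatibility and forwards to the plain path-based Build. A usage sketch; the model path and context are illustrative:

#include <memory>
#include "include/api/model.h"

int main() {
  auto context = std::make_shared<mindspore::Context>();
  mindspore::Model model;
  // dec_key / dec_mode are accepted but ignored by the adapter overload for now.
  auto build_ret = model.Build("net.mindir", mindspore::kMindIR, context);
  return build_ret == mindspore::kSuccess ? 0 : 1;
}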

View File

@@ -86,8 +86,33 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
}
MSTensor ModelImpl::GetInputByTensorName(const std::string &name) { return MSTensor(); }
std::vector<std::string> ModelImpl::GetOutputTensorNames() { return std::vector<std::string>(); }
MSTensor ModelImpl::GetOutputByTensorName(const std::string &name) { return MSTensor(); }
std::vector<std::string> ModelImpl::GetOutputTensorNames() {
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
std::vector<std::string> empty;
return empty;
}
return session_->GetOutputNames();
}
MSTensor ModelImpl::GetOutputByTensorName(const std::string &name) {
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return MSTensor(nullptr);
}
auto tensor_ptr = session_->GetOutputByTensorName(name);
if (tensor_ptr == nullptr) {
MS_LOG(ERROR) << "Model does not contains tensor " << name << " .";
return MSTensor(nullptr);
}
auto ms_outputs = TensorPtrToMSTensor({tensor_ptr}, {name});
if (ms_outputs.empty()) {
MS_LOG(ERROR) << "Tensor to ms tensor failed." << name << " .";
return MSTensor(nullptr);
}
return ms_outputs[0];
}
Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(session_);
@@ -144,7 +169,7 @@ Status ModelImpl::Preprocess(const std::vector<std::vector<MSTensor>> &inputs, s
#if !defined(_WIN32) && !defined(_WIN64)
// Config preprocessor, temporary way to let mindspore.so depend on _c_dataengine
std::string dataengine_so_path;
Status dlret = DLSoPath(&dataengine_so_path);
Status dlret = DLSoPath("libmindspore.so", "_c_dataengine", &dataengine_so_path);
CHECK_FAIL_AND_RELEASE(dlret, nullptr, "Parse dataengine_so failed: " + dlret.GetErrDescription());
// Run preprocess

View File

@@ -195,7 +195,7 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type,
std::vector<std::string> preprocessor = mindir_loader.LoadPreprocess(file_path);
if (!preprocessor.empty()) {
std::string dataengine_so_path;
Status dlret = DLSoPath(&dataengine_so_path);
Status dlret = DLSoPath("libmindspore.so", "_c_dataengine", &dataengine_so_path);
CHECK_FAIL_AND_RELEASE(dlret, nullptr, "Parse dataengine_so failed: " + dlret.GetErrDescription());
void *handle = nullptr;
@@ -276,7 +276,7 @@ Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelTyp
#if !defined(_WIN32) && !defined(_WIN64)
// Dataset so loading
std::string dataengine_so_path;
Status dlret = DLSoPath(&dataengine_so_path);
Status dlret = DLSoPath("libmindspore.so", "_c_dataengine", &dataengine_so_path);
CHECK_FAIL_AND_RELEASE(dlret, nullptr, "Parse dataengine_so failed: " + dlret.GetErrDescription());
void *handle = nullptr;

View File

@@ -5,22 +5,17 @@ include_directories(${TOP_DIR}/mindspore/lite/src)
find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
file(GLOB_RECURSE ASCEND_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"custom_ascend_kernel.cc"
"src/*.cc"
"api/*.cc"
"model/*.cc"
)
add_library(ascend_kernel_mid OBJECT ${ASCEND_SRC})
add_library(ascend_kernel_plugin SHARED ${ASCEND_SRC})
add_dependencies(ascend_kernel_plugin fbs_inner_src)
add_dependencies(ascend_kernel_mid fbs_inner_src)
if("${MSLITE_REGISTRY_DEVICE}" STREQUAL "SD3403" AND PLATFORM_ARM64)
find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_retr libacl_retr.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
target_link_libraries(ascend_kernel_mid ${ge_graph} ${acl} ${acl_retr} ${acl_cblas} ${acl_runtime})
else()
target_link_libraries(ascend_kernel_mid ${ge_graph} ${ge_compiler}
${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} ${libplatform}
${libcompress} ${libopskernel} ${libaicore_utils} ${libaicpu_engine_common} ${acl})
endif()
find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_retr libacl_retr.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
target_link_libraries(ascend_kernel_plugin ${ge_graph} ${acl} ${acl_retr} ${acl_cblas} ${acl_runtime})

View File

@@ -0,0 +1,38 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/kernel/ascend/api/ascend_kernel_api.h"
constexpr auto kNameCustomAscend = "CustomAscend";
std::map<std::string, CreatorFunc> *CreateCustomAscendKernel() {
CreatorFunc creator_func = []() { return std::make_shared<mindspore::kernel::acl::CustomAscendKernelMod>(); };
std::map<std::string, CreatorFunc> *func_map = new (std::nothrow) std::map<std::string, CreatorFunc>();
if (func_map == nullptr) {
MS_LOG(ERROR) << "New custom ascend kernel failed.";
return {};
}
(*func_map)[kNameCustomAscend] = creator_func;
return func_map;
}
void DestroyCustomAscendKernel(std::map<std::string, CreatorFunc> *creator_func) {
if (creator_func == nullptr) {
MS_LOG(ERROR) << "Param creator func is nullptr.";
return;
}
delete creator_func;
}
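These two extern "C" entry points form the plugin's C-compatible surface: the caller owns the map returned by CreateCustomAscendKernel and must release it through DestroyCustomAscendKernel. A direct-call sketch; the runtime normally reaches these via dlsym, as in ascend_kernel_plugin.cc below:

#include <map>
#include <string>
#include "extendrt/kernel/ascend/api/ascend_kernel_api.h"

void ProbeCustomAscendKernels() {
  std::map<std::string, CreatorFunc> *kernels = CreateCustomAscendKernel();
  if (kernels == nullptr) {
    return;
  }
  for (const auto &entry : *kernels) {
    auto kernel_mod = entry.second();  // entry.first is "CustomAscend"
  }
  DestroyCustomAscendKernel(kernels);  // caller releases the map
}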

View File

@@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_API_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_API_H_
#include <map>
#include <memory>
#include <string>
#include "extendrt/kernel/ascend/src/custom_ascend_kernel.h"
#ifdef __cplusplus
extern "C" {
#endif
using CreatorFunc = std::function<std::shared_ptr<mindspore::kernel::KernelMod>()>;
std::map<std::string, CreatorFunc> *CreateCustomAscendKernel();
void DestroyCustomAscendKernel(std::map<std::string, CreatorFunc> *creator_func);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_API_H_

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -0,0 +1,102 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h"
#include <map>
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#include "plugin/factory/ms_factory.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include <dlfcn.h>
#include "extendrt/cxx_api/dlutils.h"
#endif
namespace mindspore::kernel {
AscendKernelPlugin &AscendKernelPlugin::GetInstance() {
static AscendKernelPlugin instance;
return instance;
}
AscendKernelPlugin::AscendKernelPlugin() : handle_(nullptr), create_kernel_map_(nullptr), is_registered_(false) {}
void AscendKernelPlugin::Register() {
#if !defined(_WIN32) && !defined(_WIN64)
if (is_registered_) {
return;
}
std::string ascend_kernel_plugin_path;
auto ret = DLSoPath("libmindspore-extendrt.so", "libascend_kernel_plugin.so", &ascend_kernel_plugin_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Get real path of libascend_kernel_plugin.so failed.";
return;
}
if (ret == kSuccess && ascend_kernel_plugin_path.empty()) {
return;
}
MS_LOG(INFO) << "Find ascend kernel plugin so success, path = " << ascend_kernel_plugin_path;
void *function = nullptr;
ret = DLSoOpen(ascend_kernel_plugin_path, "CreateCustomAscendKernel", &handle_, &function);
if (ret != kSuccess) {
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << ascend_kernel_plugin_path;
return;
}
auto create_kernel_func = reinterpret_cast<std::map<std::string, KernelModFunc> *(*)(void)>(function);
if (create_kernel_func == nullptr) {
MS_LOG(ERROR) << "Cast CreateCustomAscendKernel failed.";
return;
}
create_kernel_map_ = create_kernel_func();
if (create_kernel_map_ == nullptr) {
MS_LOG(ERROR) << "Create custom ascend kernel failed.";
return;
}
// register
for (auto &kernel : *create_kernel_map_) {
static KernelRegistrar<kernel::KernelMod> ascend_kernel_reg(kernel.first, kernel.second);
}
is_registered_ = true;
MS_LOG(INFO) << "Register ascend kernel plugin success.";
#endif
}
void AscendKernelPlugin::DestroyAscendKernelMap() {
#if !defined(_WIN32) && !defined(_WIN64)
if (handle_ == nullptr) {
MS_LOG(DEBUG) << "Handle is nullptr.";
return;
}
auto destroy_map_func =
reinterpret_cast<void (*)(std::map<std::string, KernelModFunc> *)>(dlsym(handle_, "DestroyCustomAscendKernel"));
if (destroy_map_func == nullptr) {
MS_LOG(ERROR) << "Undefined symbol DestroyCustomAscendKernel in ['libascend_kernel_plugin.so'].";
return;
}
destroy_map_func(create_kernel_map_);
#endif
}
AscendKernelPlugin::~AscendKernelPlugin() {
#if !defined(_WIN32) && !defined(_WIN64)
MS_LOG(DEBUG) << "~AscendKernelPlugin() begin.";
DestroyAscendKernelMap();
if (handle_ != nullptr) {
(void)dlclose(handle_);
handle_ = nullptr;
}
MS_LOG(DEBUG) << "~AscendKernelPlugin() end.";
#endif
}
} // namespace mindspore::kernel
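Registration is a one-shot operation guarded by is_registered_ and compiled out on Windows, so a session simply asks the singleton to register before resolving kernels by name. A minimal sketch; the call-site placement is an assumption, not shown in this diff:

#include "extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h"

void InitAscendKernels() {
  // Safe to call repeatedly; is_registered_ makes this a no-op after the first time.
  mindspore::kernel::AscendKernelPlugin::GetInstance().Register();
}

Note that the static KernelRegistrar inside the registration loop above is constructed only on the first iteration, which is sufficient while the creator map holds the single CustomAscend entry.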

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_PLUGIN_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_PLUGIN_H_
#include <map>
#include <string>
#include <memory>
#include "kernel/kernel.h"
namespace mindspore::kernel {
using KernelModFunc = std::function<std::shared_ptr<kernel::KernelMod>()>;
class AscendKernelPlugin {
public:
static AscendKernelPlugin &GetInstance();
void Register();
void DestroyAscendKernelMap();
private:
AscendKernelPlugin();
~AscendKernelPlugin();
void *handle_;
std::map<std::string, KernelModFunc> *create_kernel_map_;
bool is_registered_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_PLUGIN_H_

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,7 +14,8 @@
* limitations under the License.
*/
#include "extendrt/kernel/ascend/custom_ascend_kernel.h"
#include "extendrt/kernel/ascend/src/custom_ascend_kernel.h"
#include <utility>
#include "include/registry/register_kernel.h"
#include "include/api/types.h"
@@ -198,45 +199,6 @@ bool CustomAscendKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
return true;
}
// std::shared_ptr<kernel::Kernel> CustomCreateKernel(const std::vector<mindspore::MSTensor> &inputs,
// const std::vector<mindspore::MSTensor> &outputs,
// const schema::Primitive *primitive, const mindspore::Context *ctx)
// {
// if (primitive == nullptr) {
// MS_LOG(ERROR) << "Primitive is nullptr.";
// return nullptr;
// }
// if (primitive->value_type() != schema::PrimitiveType_Custom) {
// MS_LOG(ERROR) << "Primitive type is not PrimitiveType_Custom";
// return nullptr;
// }
//
// auto kernel = std::make_shared<CustomAscendKernel>(inputs, outputs, primitive, ctx);
// if (kernel == nullptr) {
// MS_LOG(ERROR) << "New custom kernel is nullptr";
// return nullptr;
// }
// return kernel;
// }
MS_KERNEL_FACTORY_REG(KernelMod, CustomAscend, CustomAscendKernelMod);
} // namespace acl
} // namespace mindspore::kernel
namespace mindspore {
namespace registry {
namespace {
const auto kFloat32 = DataType::kNumberTypeFloat32;
const auto kFloat16 = DataType::kNumberTypeFloat16;
const auto kInt32 = DataType::kNumberTypeInt32;
const auto kInt8 = DataType::kNumberTypeInt8;
const auto kUInt8 = DataType::kNumberTypeUInt8;
const auto kBool = DataType::kNumberTypeBool;
} // namespace
// REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kFloat32, ACL, kernel::acl::CustomCreateKernel)
// REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kFloat16, ACL, kernel::acl::CustomCreateKernel)
// REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kInt32, ACL, kernel::acl::CustomCreateKernel)
// REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kInt8, ACL, kernel::acl::CustomCreateKernel)
// REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kUInt8, ACL, kernel::acl::CustomCreateKernel)
// REGISTER_CUSTOM_KERNEL(ASCEND, ACL, kBool, ACL, kernel::acl::CustomCreateKernel)
} // namespace registry
} // namespace mindspore
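MS_KERNEL_FACTORY_REG replaces the commented-out registry path above: it registers CustomAscendKernelMod in the KernelMod factory under the name CustomAscend, so the session can create it by node name. A rough, self-contained model of this self-registration idiom (not the actual macro expansion in MindSpore) is:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct KernelModBase {
  virtual ~KernelModBase() = default;
  virtual const char *Name() const = 0;
};

// Singleton factory: creators are stored by name at static-init time, so a
// kernel becomes available as soon as its translation unit is linked in.
class KernelFactory {
 public:
  static KernelFactory &Instance() {
    static KernelFactory inst;
    return inst;
  }
  void Register(const std::string &name, std::function<std::shared_ptr<KernelModBase>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::shared_ptr<KernelModBase> Create(const std::string &name) const {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::shared_ptr<KernelModBase>()>> creators_;
};

#define KERNEL_FACTORY_REG(NAME, CLASS)                                                  \
  static const bool g_##CLASS##_reg = [] {                                               \
    KernelFactory::Instance().Register(#NAME, [] { return std::make_shared<CLASS>(); }); \
    return true;                                                                         \
  }()

struct DemoKernelMod : KernelModBase {
  const char *Name() const override { return "Demo"; }
};
KERNEL_FACTORY_REG(Demo, DemoKernelMod);

int main() {
  auto mod = KernelFactory::Instance().Create("Demo");
  std::cout << (mod ? mod->Name() : "not found") << std::endl;  // prints "Demo"
  return 0;
}

The point of the idiom is that linking a kernel's object file is enough to make it creatable by name; no central registration list has to be edited.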

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@@ -0,0 +1,424 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "extendrt/kernel/cpu/transpose_kernel_mod.h"
#include <vector>
#include <memory>
#include "plugin/factory/ms_factory.h"
#include "include/api/status.h"
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
namespace mindspore::kernel {
namespace {
constexpr size_t kTransposeInputsNum = 2;
constexpr size_t kTransposeOutputsNum = 1;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
constexpr size_t kIndex5 = 5;
constexpr size_t kIndex6 = 6;
constexpr size_t kIndex7 = 7;
// kMaxTransposeSerialSize = 64 * 3 * 512 * 512
constexpr size_t kMaxTransposeSerialSize = 50331648;
} // namespace
bool TransposeKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_);
launch_func_(this, inputs, outputs);
return true;
}
int TransposeKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
return kSuccess;
}
bool TransposeKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_);
input_shape_ = inputs[kIndex0]->GetShapeVector();
output_shape_ = outputs[kIndex0]->GetShapeVector();
auto address_ptr = inputs[kIndex1]->GetData();
int *addr = static_cast<int *>(address_ptr->addr);
std::vector<int64_t> perm;
for (size_t i = 0; i < (address_ptr->size) / sizeof(int); ++i) {
perm.emplace_back(static_cast<int64_t>(addr[i]));
}
for (auto p : perm) {
p = (p >= 0) ? p : (perm.size() + p);
if (p < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the perm value must be in [-" << perm.size() << ", "
<< (perm.size() - 1) << "], but got " << perm;
}
axes_.emplace_back(p);
}
dtype_ = inputs[kIndex0]->GetDtype();
if (axes_.size() > MAX_TRANSPOSE_DIM_SIZE) {
MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got "
<< axes_.size() << "D.";
}
for (size_t i = 0; i < axes_.size(); ++i) {
transpose_param_.perm_[i] = SizeToInt(axes_[i]);
}
size_t num_axes = input_shape_.size();
transpose_param_.perm_size_ = axes_.size();
transpose_param_.num_axes_ = SizeToInt(num_axes);
transpose_param_.strides_[num_axes - 1] = 1;
transpose_param_.out_strides_[num_axes - 1] = 1;
for (size_t i = num_axes - 1; i >= 1; i--) {
transpose_param_.strides_[i - 1] = input_shape_[i] * transpose_param_.strides_[i];
transpose_param_.out_strides_[i - 1] = output_shape_[i] * transpose_param_.out_strides_[i];
}
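// Worked example of the stride setup above (the loop assumes num_axes >= 1):
// for input_shape_ = {2, 3, 4} and perm = {2, 0, 1}, output_shape_ = {4, 2, 3},
// giving strides_ = {12, 4, 1} and out_strides_ = {6, 3, 1}. Output element
// (i, j, k) then reads input offset i * strides_[perm[0]] + j * strides_[perm[1]]
// + k * strides_[perm[2]] = i * 1 + j * 12 + k * 4, as TransposeDim3 below does.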
launch_map_[kNumberTypeBool] = &TransposeKernelMod::LaunchKernel<bool>;
launch_map_[kNumberTypeInt8] = &TransposeKernelMod::LaunchKernel<int8_t>;
launch_map_[kNumberTypeInt16] = &TransposeKernelMod::LaunchKernel<int16_t>;
launch_map_[kNumberTypeInt32] = &TransposeKernelMod::LaunchKernel<int32_t>;
launch_map_[kNumberTypeInt64] = &TransposeKernelMod::LaunchKernel<int64_t>;
launch_map_[kNumberTypeUInt8] = &TransposeKernelMod::LaunchKernel<uint8_t>;
launch_map_[kNumberTypeUInt16] = &TransposeKernelMod::LaunchKernel<uint16_t>;
launch_map_[kNumberTypeUInt32] = &TransposeKernelMod::LaunchKernel<uint32_t>;
launch_map_[kNumberTypeUInt64] = &TransposeKernelMod::LaunchKernel<uint64_t>;
launch_map_[kNumberTypeFloat16] = &TransposeKernelMod::LaunchKernel<float16>;
launch_map_[kNumberTypeFloat32] = &TransposeKernelMod::LaunchKernel<float>;
launch_map_[kNumberTypeFloat64] = &TransposeKernelMod::LaunchKernel<double>;
auto iter = launch_map_.find(dtype_);
if (iter != launch_map_.end()) {
launch_func_ = iter->second;
} else {
MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
}
return true;
}
template <typename T>
void TransposeKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
transpose_param_.data_num_ = SizeToInt(inputs[0]->size / sizeof(T));
// Use std::vector instead of a variable-length array (a non-standard extension),
// and keep the NNACL error code as int so failures are not narrowed to bool.
std::vector<int> output_shape(output_shape_.size());
for (size_t i = 0; i < output_shape_.size(); ++i) {
output_shape[i] = static_cast<int>(output_shape_[i]);
}
int res = DoTranspose(input_addr, output_addr, output_shape.data(), &transpose_param_);
if (res != NNACL_OK) {
MS_LOG(EXCEPTION) << "Transpose run failed, error code: " << res;
}
}
template <typename T>
int TransposeKernelMod::DoTranspose(const T *in_data, T *out_data, const int *output_shape,
const TransposeParameter *transpose_param) {
NNACL_CHECK_NULL_RETURN_ERR(in_data);
NNACL_CHECK_NULL_RETURN_ERR(out_data);
NNACL_CHECK_NULL_RETURN_ERR(output_shape);
NNACL_CHECK_NULL_RETURN_ERR(transpose_param);
const int *perm = transpose_param->perm_;
const int *strides = transpose_param->strides_;
const int *out_strides = transpose_param->out_strides_;
int data_size = transpose_param->data_num_ * sizeof(T);
int num_axes = transpose_param->num_axes_;
bool needTranspose = false;
for (size_t i = 1; i < (unsigned int)num_axes; ++i) {
if (perm[i] - perm[i - 1] != 1) {
needTranspose = true;
break;
}
}
if (!needTranspose) {
(void)memcpy(out_data, in_data, data_size);
return NNACL_OK;
}
for (size_t i = 0; i < (unsigned int)num_axes; ++i) {
if (perm[i] < 0) {
return NNACL_PARAM_INVALID;
}
}
if (num_axes == kIndex2) {
TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape);
} else if (num_axes == kIndex3) {
TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape);
} else if (num_axes == kIndex4) {
TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape);
} else if (num_axes == kIndex5) {
TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape);
} else if (num_axes == kIndex6) {
TransposeDim6(in_data, out_data, strides, out_strides, perm, output_shape);
} else if (num_axes == kIndex7) {
TransposeDim7(in_data, out_data, strides, out_strides, perm, output_shape);
} else {
return NNACL_ERR;
}
return NNACL_OK;
}
template <typename T>
void TransposeKernelMod::TransposeDim2(const T *in_data, T *out_data, const int *strides, const int *out_strides,
const int *perm, const int *output_shape) {
const int stride0 = strides[perm[kIndex0]];
const int stride1 = strides[perm[kIndex1]];
const int output0 = output_shape[kIndex0];
const int output1 = output_shape[kIndex1];
for (size_t i = 0; i < (unsigned int)output0; ++i) {
size_t out_stride0_i = i * output1;
size_t stride0_i = i * 1 * stride0;
for (size_t j = 0; j < (unsigned int)output1; ++j) {
out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1];
}
}
}
template <typename T>
void TransposeKernelMod::TransposeDim3(const T *in_data, T *out_data, const int *strides, const int *out_strides,
const int *perm, const int *output_shape) {
const int stride0 = strides[perm[kIndex0]];
const int stride1 = strides[perm[kIndex1]];
const int stride2 = strides[perm[kIndex2]];
const int out_stride0 = out_strides[kIndex0];
const int out_stride1 = out_strides[kIndex1];
const int output0 = output_shape[kIndex0];
const int output1 = output_shape[kIndex1];
const int output2 = output_shape[kIndex2];
for (size_t i = 0; i < (unsigned int)output0; ++i) {
size_t out_stride0_i = i * out_stride0;
size_t stride0_i = i * stride0;
for (size_t j = 0; j < (unsigned int)output1; ++j) {
size_t out_stride1_j = j * out_stride1;
size_t stride1_j = j * stride1;
for (size_t k = 0; k < (unsigned int)output2; ++k) {
out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2];
}
}
}
}
template <typename T>
void TransposeKernelMod::TransposeDim4(const T *in_data, T *out_data, const int *strides, const int *out_strides,
const int *perm, const int *output_shape) {
const int stride0 = strides[perm[kIndex0]];
const int stride1 = strides[perm[kIndex1]];
const int stride2 = strides[perm[kIndex2]];
const int stride3 = strides[perm[kIndex3]];
const int out_stride0 = out_strides[kIndex0];
const int out_stride1 = out_strides[kIndex1];
const int out_stride2 = out_strides[kIndex2];
const int output0 = output_shape[kIndex0];
const int output1 = output_shape[kIndex1];
const int output2 = output_shape[kIndex2];
const int output3 = output_shape[kIndex3];
for (size_t i = 0; i < (unsigned int)output0; ++i) {
size_t out_stride0_i = i * out_stride0;
size_t stride0_i = i * stride0;
for (size_t j = 0; j < (unsigned int)output1; ++j) {
size_t out_stride1_j = j * out_stride1;
size_t stride1_j = j * stride1;
for (size_t k = 0; k < (unsigned int)output2; ++k) {
size_t out_stride2_k = k * out_stride2;
size_t stride2_k = k * stride2;
for (size_t m = 0; m < (unsigned int)output3; ++m) {
out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] =
in_data[stride0_i + stride1_j + stride2_k + m * stride3];
}
}
}
}
}
template <typename T>
void TransposeKernelMod::TransposeDim5(const T *in_data, T *out_data, const int *strides, const int *out_strides,
const int *perm, const int *output_shape) {
const int stride0 = strides[perm[kIndex0]];
const int stride1 = strides[perm[kIndex1]];
const int stride2 = strides[perm[kIndex2]];
const int stride3 = strides[perm[kIndex3]];
const int stride4 = strides[perm[kIndex4]];
const int out_stride0 = out_strides[kIndex0];
const int out_stride1 = out_strides[kIndex1];
const int out_stride2 = out_strides[kIndex2];
const int out_stride3 = out_strides[kIndex3];
const int output0 = output_shape[kIndex0];
const int output1 = output_shape[kIndex1];
const int output2 = output_shape[kIndex2];
const int output3 = output_shape[kIndex3];
const int output4 = output_shape[kIndex4];
for (size_t i = 0; i < (unsigned int)output0; ++i) {
size_t out_stride0_i = i * out_stride0;
size_t stride0_i = i * stride0;
for (size_t j = 0; j < (unsigned int)output1; ++j) {
size_t out_stride1_j = j * out_stride1;
size_t stride1_j = j * stride1;
for (size_t k = 0; k < (unsigned int)output2; ++k) {
size_t out_stride2_k = k * out_stride2;
size_t stride2_k = k * stride2;
for (size_t m = 0; m < (unsigned int)output3; ++m) {
size_t out_stride3_m = m * out_stride3;
size_t stride3_m = m * stride3;
for (size_t n = 0; n < (unsigned int)output4; ++n) {
out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] =
in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4];
}
}
}
}
}
}
template <typename T>
void TransposeKernelMod::TransposeDim6(const T *in_data, T *out_data, const int *strides, const int *out_strides,
const int *perm, const int *output_shape) {
const int stride0 = strides[perm[kIndex0]];
const int stride1 = strides[perm[kIndex1]];
const int stride2 = strides[perm[kIndex2]];
const int stride3 = strides[perm[kIndex3]];
const int stride4 = strides[perm[kIndex4]];
const int stride5 = strides[perm[kIndex5]];
const int out_stride0 = out_strides[kIndex0];
const int out_stride1 = out_strides[kIndex1];
const int out_stride2 = out_strides[kIndex2];
const int out_stride3 = out_strides[kIndex3];
const int out_stride4 = out_strides[kIndex4];
const int output0 = output_shape[kIndex0];
const int output1 = output_shape[kIndex1];
const int output2 = output_shape[kIndex2];
const int output3 = output_shape[kIndex3];
const int output4 = output_shape[kIndex4];
const int output5 = output_shape[kIndex5];
for (size_t i = 0; i < (unsigned int)output0; ++i) {
size_t out_stride0_i = i * out_stride0;
size_t stride0_i = i * stride0;
for (size_t j = 0; j < (unsigned int)output1; ++j) {
size_t out_stride1_j = j * out_stride1;
size_t stride1_j = j * stride1;
for (size_t k = 0; k < (unsigned int)output2; ++k) {
size_t out_stride2_k = k * out_stride2;
size_t stride2_k = k * stride2;
for (size_t m = 0; m < (unsigned int)output3; ++m) {
size_t out_stride3_m = m * out_stride3;
size_t stride3_m = m * stride3;
for (size_t n = 0; n < (unsigned int)output4; ++n) {
size_t out_stride4_n = n * out_stride4;
size_t stride4_n = n * stride4;
for (size_t g = 0; g < (unsigned int)output5; ++g) {
out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] =
in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5];
}
}
}
}
}
}
}
template <typename T>
void TransposeKernelMod::TransposeDim7(const T *in_data, T *out_data, const int *strides, const int *out_strides,
const int *perm, const int *output_shape) {
const int stride0 = strides[perm[kIndex0]];
const int stride1 = strides[perm[kIndex1]];
const int stride2 = strides[perm[kIndex2]];
const int stride3 = strides[perm[kIndex3]];
const int stride4 = strides[perm[kIndex4]];
const int stride5 = strides[perm[kIndex5]];
const int stride6 = strides[perm[kIndex6]];
const int out_stride0 = out_strides[kIndex0];
const int out_stride1 = out_strides[kIndex1];
const int out_stride2 = out_strides[kIndex2];
const int out_stride3 = out_strides[kIndex3];
const int out_stride4 = out_strides[kIndex4];
const int out_stride5 = out_strides[kIndex5];
const int output0 = output_shape[kIndex0];
const int output1 = output_shape[kIndex1];
const int output2 = output_shape[kIndex2];
const int output3 = output_shape[kIndex3];
const int output4 = output_shape[kIndex4];
const int output5 = output_shape[kIndex5];
const int output6 = output_shape[kIndex6];
for (size_t i = 0; i < (unsigned int)output0; ++i) {
size_t out_stride0_i = i * out_stride0;
size_t stride0_i = i * stride0;
for (size_t j = 0; j < (unsigned int)output1; ++j) {
size_t out_stride1_j = j * out_stride1;
size_t stride1_j = j * stride1;
for (size_t k = 0; k < (unsigned int)output2; ++k) {
size_t out_stride2_k = k * out_stride2;
size_t stride2_k = k * stride2;
for (size_t m = 0; m < (unsigned int)output3; ++m) {
size_t out_stride3_m = m * out_stride3;
size_t stride3_m = m * stride3;
for (size_t n = 0; n < (unsigned int)output4; ++n) {
size_t out_stride4_n = n * out_stride4;
size_t stride4_n = n * stride4;
for (size_t g = 0; g < (unsigned int)output5; ++g) {
size_t out_stride5_g = g * out_stride5;
size_t stride5_g = g * stride5;
for (size_t s = 0; s < (unsigned int)output6; ++s) {
out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + out_stride5_g +
s] =
in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + stride5_g + s * stride6];
}
}
}
}
}
}
}
}
template <typename T>
void TransposeKernelMod::TransposeDims(const T *in_data, T *out_data, const int *output_shape,
const TransposeParameter *transpose_param, int task_id, int thread_num) {
NNACL_CHECK_NULL_RETURN_VOID(in_data);
NNACL_CHECK_NULL_RETURN_VOID(out_data);
NNACL_CHECK_NULL_RETURN_VOID(output_shape);
NNACL_CHECK_NULL_RETURN_VOID(transpose_param);
NNACL_CHECK_ZERO_RETURN(thread_num);
const int *perm = transpose_param->perm_;
const int *strides = transpose_param->strides_;
const int *out_strides = transpose_param->out_strides_;
int num_axes = transpose_param->num_axes_;
size_t data_size = (*out_strides) * output_shape[0];
size_t offset_size = UP_DIV(data_size, thread_num);
size_t task_offset = offset_size * task_id;
// Compare in size_t to avoid unsigned wrap-around when task_offset exceeds data_size.
if (task_offset >= data_size) {
return;
}
size_t count = MSMIN(offset_size, data_size - task_offset);
for (size_t idx = task_offset; idx < task_offset + count; ++idx) {
int pos = static_cast<int>(idx);
int output_idx = 0;
int input_idx = 0;
for (int i = 0; i < num_axes; ++i) {
NNACL_CHECK_ZERO_RETURN(*(out_strides + i));
int position = pos / *(out_strides + i);
int out_stride = i < num_axes - 1 ? out_strides[i] : 1;
output_idx += (position * out_stride);
input_idx += (position * strides[perm[i]]);
pos -= position * (*(out_strides + i));
}
out_data[output_idx] = in_data[input_idx];
}
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(KernelMod, Transpose,
[]() { return std::make_shared<TransposeKernelMod>(prim::kPrimTranspose->name()); });
} // namespace mindspore::kernel
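For readers checking the index arithmetic, the following self-contained sketch (independent of MindSpore; every name is local to the example) reproduces the flat-index decomposition used by TransposeDims and verifies it on a 2x3 matrix with perm = {1, 0}:

#include <cstdio>
#include <vector>

// Walk the flat output index, decompose it along out_strides, and accumulate
// the matching input offset through perm. Just the math, not MindSpore code.
int main() {
  const std::vector<int> perm = {1, 0};         // transpose a 2x3 matrix
  const std::vector<int> out_shape = {3, 2};
  const std::vector<int> strides = {3, 1};      // row-major strides of input
  const std::vector<int> out_strides = {2, 1};  // row-major strides of output
  const int in_data[6] = {0, 1, 2, 3, 4, 5};
  int out_data[6] = {0};
  const int total = out_strides[0] * out_shape[0];  // 6 elements
  for (int idx = 0; idx < total; ++idx) {
    int pos = idx;
    int input_idx = 0;
    for (size_t i = 0; i < perm.size(); ++i) {
      int position = pos / out_strides[i];
      input_idx += position * strides[perm[i]];
      pos -= position * out_strides[i];
    }
    out_data[idx] = in_data[input_idx];
  }
  for (int v : out_data) {
    printf("%d ", v);  // expected: 0 3 1 4 2 5
  }
  printf("\n");
  return 0;
}

TransposeDim2 through TransposeDim7 are effectively this loop unrolled per rank, trading the division and subtraction of the decomposition for precomputed nested loops.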

View File

@@ -0,0 +1,84 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CPU_TRANSPOSE_KERNEL_MOD_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CPU_TRANSPOSE_KERNEL_MOD_H_
#include <vector>
#include <string>
#include <map>
#include <unordered_map>
#include "plugin/device/cpu/kernel/cpu_kernel_mod.h"
#include "plugin/device/cpu/kernel/nnacl/transpose.h"
#include "kernel/common_utils.h"
namespace mindspore::kernel {
class TransposeKernelMod : public CpuKernelMod {
public:
TransposeKernelMod() = default;
~TransposeKernelMod() override = default;
explicit TransposeKernelMod(const std::string &name) { kernel_name_ = name; }
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
virtual bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs);
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
int DoTranspose(const T *in_data, T *out_data, const int *output_shape, const TransposeParameter *transpose_param);
template <typename T>
void TransposeDim2(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm,
const int *output_shape);
template <typename T>
void TransposeDim3(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm,
const int *output_shape);
template <typename T>
void TransposeDim4(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm,
const int *output_shape);
template <typename T>
void TransposeDim5(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm,
const int *output_shape);
template <typename T>
void TransposeDim6(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm,
const int *output_shape);
template <typename T>
void TransposeDim7(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm,
const int *output_shape);
template <typename T>
void TransposeDims(const T *in_data, T *out_data, const int *output_shape, const TransposeParameter *transpose_param,
int task_id, int thread_num);
TransposeParameter transpose_param_;
std::vector<int64_t> input_shape_;
std::vector<int64_t> output_shape_;
std::vector<size_t> axes_;
TypeId dtype_{kTypeUnknown};
using TypeKernel =
std::function<void(TransposeKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
std::unordered_map<TypeId, TypeKernel> launch_map_;
TypeKernel launch_func_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CPU_TRANSPOSE_KERNEL_MOD_H_
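launch_map_ and launch_func_ implement a common type-dispatch idiom: each template instantiation of LaunchKernel is wrapped in a std::function keyed by runtime TypeId, so Launch performs one lookup instead of a switch. A stripped-down, self-contained sketch of the same pattern (all names here are illustrative):

#include <functional>
#include <unordered_map>

enum TypeId { kInt32, kFloat32 };

class Kernel {
 public:
  // Bind template instantiations to runtime type tags once, at Init time.
  void Init(TypeId dtype) {
    launch_map_[kInt32] = &Kernel::LaunchKernel<int>;
    launch_map_[kFloat32] = &Kernel::LaunchKernel<float>;
    launch_func_ = launch_map_.at(dtype);
  }
  // Dispatch is a single stored-callable invocation, as in Launch above.
  void Launch(const void *in, void *out, int n) { launch_func_(this, in, out, n); }

 private:
  template <typename T>
  void LaunchKernel(const void *in, void *out, int n) {
    const T *src = static_cast<const T *>(in);
    T *dst = static_cast<T *>(out);
    for (int i = 0; i < n; ++i) dst[i] = src[i];  // placeholder element-wise work
  }
  using Fn = std::function<void(Kernel *, const void *, void *, int)>;
  std::unordered_map<TypeId, Fn> launch_map_;
  Fn launch_func_;
};

MindSpore's own launch_map_ is populated in Init (see the .cc above); the sketch only isolates the dispatch idiom.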

View File

@@ -26,6 +26,7 @@
#include "kernel/common_utils.h"
#include "plugin/device/cpu/kernel/cpu_kernel_mod.h"
#include "src/extendrt/utils/kernel_build_utils.h"
#include "src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h"
namespace mindspore {
const size_t tensor_max_size = 0x1000000;
@@ -33,6 +34,7 @@ const size_t tensor_max_size = 0x1000000;
Status SingleOpInferSession::Init(const std::shared_ptr<Context> context) {
MS_LOG(INFO) << "SingleOpInferSession::Init";
session_basic_ = std::make_shared<session::SessionBasic>();
kernel::AscendKernelPlugin::GetInstance().Register();
return kSuccess;
}
@@ -53,11 +55,11 @@ Status SingleOpInferSession::CompileGraph(FuncGraphPtr graph) {
mindspore::infer::SetKernelInfo(kernel_node);
std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
std::shared_ptr<kernel::KernelMod> kernel_mod = kernel::Factory<kernel::KernelMod>::Instance().Create(kernel_name);
MS_LOG(INFO) << "SingleOpInferSession::Kernels " << kernel_name;
auto args = kernel::AbstractArgsFromCNode(kernel_node);
if (kernel_mod == nullptr) {
MS_LOG(EXCEPTION) << "Kernel mod is nullptr, kernel name: " << kernel_name;
}
MS_LOG(INFO) << "SingleOpInferSession::Kernels " << kernel_name;
auto args = kernel::AbstractArgsFromCNode(kernel_node);
mindspore::infer::CopyInputWeights(kernel_node, args.inputs);
auto ret = kernel_mod->Init(args.op, args.inputs, args.outputs);
MS_LOG(INFO) << "SingleOpInferSession::Kernels ret " << ret;
@@ -177,7 +179,17 @@ std::vector<tensor::TensorPtr> SingleOpInferSession::GetOutputs() { return outpu
std::vector<tensor::TensorPtr> SingleOpInferSession::GetInputs() { return inputs_; }
std::vector<std::string> SingleOpInferSession::GetOutputNames() { return output_names_; }
std::vector<std::string> SingleOpInferSession::GetInputNames() { return input_names_; }
tensor::TensorPtr SingleOpInferSession::GetOutputByTensorName(const std::string &tensorName) { return nullptr; }
tensor::TensorPtr SingleOpInferSession::GetOutputByTensorName(const std::string &tensorName) {
for (size_t idx = 0; idx < output_names_.size(); ++idx) {
if (output_names_[idx] == tensorName && idx < outputs_.size()) {
return outputs_[idx];
}
}
MS_LOG(ERROR) << "Can't find tensor name " << tensorName;
return nullptr;
}
tensor::TensorPtr SingleOpInferSession::GetInputByTensorName(const std::string &name) { return nullptr; }
void SingleOpInferSession::AssignKernelGraphAddress(KernelGraphPtr kernel_graph) {
@@ -336,5 +348,6 @@ void SingleOpInferSession::CopyOutputs(std::vector<tensor::TensorPtr> *outputs)
auto tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(type_id, shape, data, data_size);
outputs->push_back(tensor_ptr);
}
outputs_ = *outputs;
}
} // namespace mindspore
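With outputs_ now cached at the end of CopyOutputs, the new GetOutputByTensorName becomes usable after a run. A hedged usage sketch, relying only on the accessors visible in this diff (the surrounding session setup is assumed):

// Assumes 'session' has already compiled and executed a graph.
for (const auto &name : session.GetOutputNames()) {
  auto tensor = session.GetOutputByTensorName(name);
  MS_LOG(INFO) << "output '" << name << "' " << (tensor == nullptr ? "missing" : "found");
}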

View File

@@ -35,8 +35,9 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using mindspore::kernel::KernelBuildInfo;
namespace {
constexpr auto kParamDynamic = "dynamic";
constexpr auto kCustomAscendInputNum = 3;
constexpr auto kInputNum = 3;
constexpr auto kNameCustomAscend = "CustomAscend";
constexpr auto kNameTranspose = "Transpose";
constexpr auto kCustomTypeAscend = "acl_build";
bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
@@ -71,8 +72,8 @@ void GetInputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *input_type
if (IsInputNotCNode(kernel_node, input_index)) {
input_no_cnode_indexes->emplace_back(input_index);
dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
// } else {
// dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
} else {
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
}
input_types->emplace_back(dtype);
}
@@ -546,11 +547,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
void CopyInputWeights(const CNodePtr &kernel_node, const std::vector<kernel::KernelTensorPtr> &inputs) {
std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == kNameCustomAscend) {
if (kernel_name == kNameCustomAscend || kernel_name == kNameTranspose) {
auto node_input_size = kernel_node->inputs().size();
if (node_input_size < kCustomAscendInputNum) {
MS_LOG(ERROR) << "Input num of custom ascend kernel should larger than " << (kCustomAscendInputNum - 1)
<< ", real num is " << node_input_size;
if (node_input_size < kInputNum) {
MS_LOG(ERROR) << "Input num of custom ascend kernel should larger than " << (kInputNum - 1) << ", real num is "
<< node_input_size;
return;
}
if (node_input_size != inputs.size() + 1) {
@@ -558,17 +559,17 @@ void CopyInputWeights(const CNodePtr &kernel_node, const std::vector<kernel::Ker
<< " is not equal to kernel tensor size[" << (inputs.size() + 1) << "].";
return;
}
auto om_input = kernel_node->input(node_input_size - 1);
if (!om_input->isa<Parameter>()) {
auto weight_input = kernel_node->input(node_input_size - 1);
if (!weight_input->isa<Parameter>()) {
MS_LOG(ERROR) << "Om input is not parameter.";
return;
}
ParameterPtr om_param = om_input->cast<ParameterPtr>();
if (om_param == nullptr || !om_param->has_default()) {
MS_LOG(ERROR) << "Om param is invalid, val= " << om_param;
ParameterPtr weight_param = weight_input->cast<ParameterPtr>();
if (weight_param == nullptr || !weight_param->has_default()) {
MS_LOG(ERROR) << "Om param is invalid, val= " << weight_param;
return;
}
auto tensor = std::static_pointer_cast<tensor::Tensor>(om_param->default_param());
auto tensor = std::static_pointer_cast<tensor::Tensor>(weight_param->default_param());
if (tensor == nullptr) {
MS_LOG(ERROR) << "Tensor is nullptr.";
return;

View File

@@ -64,6 +64,7 @@ char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size
param->aclModelOptionCfgParam.buffer_optimize = ascend_info->GetBufferOptimizeMode();
param->aclModelOptionCfgParam.insert_op_config_file_path = ascend_info->GetInsertOpConfigPath();
param->aclModelOptionCfgParam.dynamic_image_size = ascend_info->GetDynamicImageSize();
param->device = "Ascend";
} else {
continue;
}

View File

@@ -1,7 +1,13 @@
cmake_minimum_required(VERSION 3.12)
project(Lite_benchmark)
set(BENCHMARK_LINK_LIB mindspore-lite)
if(NOT MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
set(BENCHMARK_LINK_LIB mindspore-lite)
else()
add_definitions(-DUSE_GLOG)
set(BENCHMARK_LINK_LIB mindspore-extendrt)
endif()
set(PROVIDERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../providers)
if(TARGET_HIMIX)
add_subdirectory(${PROVIDERS_DIR}/nnie nnie)
@@ -64,12 +70,16 @@ if(MSLITE_GPU_BACKEND STREQUAL opencl)
endif()
endif()
if(NOT MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
set(C_SRC ${CMAKE_CURRENT_SOURCE_DIR}/benchmark_c_api.cc)
endif()
add_executable(benchmark
${CMAKE_CURRENT_SOURCE_DIR}/main.cc
${CMAKE_CURRENT_SOURCE_DIR}/run_benchmark.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_base.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_unified_api.cc
${CMAKE_CURRENT_SOURCE_DIR}/benchmark_c_api.cc
${C_SRC}
${COMMON_SRC})
add_dependencies(benchmark fbs_src)

View File

@@ -51,7 +51,9 @@ int RunBenchmark(int argc, const char **argv) {
if (api_type == nullptr || std::string(api_type) == "NEW") {
benchmark = new (std::nothrow) BenchmarkUnifiedApi(&flags);
} else if (std::string(api_type) == "C") {
#ifndef ENABLE_CLOUD_FUSION_INFERENCE
benchmark = new (std::nothrow) tools::BenchmarkCApi(&flags);
#endif
} else {
BENCHMARK_LOG_ERROR("Invalid MSLITE_API_TYPE, (OLD/NEW/C, default:OLD)");
return RET_ERROR;

View File

@@ -61,8 +61,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/parser/conv1d_inout_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/inputs_adjust.cc
${CMAKE_CURRENT_SOURCE_DIR}/parser/unify_format.cc
${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/plugin/acl_pass_plugin.cc
${CMAKE_CURRENT_SOURCE_DIR}/import/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/acl_pass.cc
${SRC_DIR}/common/quant_utils.cc
${SRC_DIR}/common/dynamic_library_loader.cc
${SRC_DIR}/train/train_populate_parameter.cc
@@ -228,6 +228,10 @@ if(MSLITE_ENABLE_CONTROLFLOW)
set(LITE_SRC ${LITE_SRC} ${CONTROL_FLOW_KERNEL_SRC})
endif()
if(NOT MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
set(LITE_SRC ${LITE_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/adapter/acl/acl_pass.cc)
endif()
if(MSLITE_ENABLE_MINDRT)
add_compile_definitions(ENABLE_MINDRT)
include_directories(${CORE_DIR}/mindrt)
@@ -372,10 +376,9 @@ target_link_libraries(mindspore_converter
mindir_proto_mid
_mindspore_transform_express_ir_obj)
if(MSLITE_ENABLE_ACL)
if(MSLITE_ENABLE_ACL AND NOT MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
target_link_libraries(mindspore_converter
lite_acl_mid
ascend_kernel_mid)
lite_acl_mid)
endif()
if(NOT MSVC)

View File

@@ -9,13 +9,21 @@ include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc/aicpu)
include_directories(${TOP_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
file(GLOB ACL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/api/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/src/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/mapper/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/infer/*.cc
${TOP_DIR}/mindspore/lite/src/runtime/kernel/ascend/src/acl_env_guard.cc
)
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
set(ACL_SRC ${ACL_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/acl_pass.cc)
endif()
set(ENABLE_ACL on)
set(MODE_ASCEND_ACL off)
add_subdirectory(${TOP_DIR}/mindspore/ccsrc/transform/graph_ir _mindspore_transform_graph_ir_obj)
@@ -23,7 +31,15 @@ add_subdirectory(${TOP_DIR}/mindspore/ccsrc/cxx_api mindspore_shared_lib)
set_property(SOURCE ${ACL_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(lite_acl_mid OBJECT ${ACL_SRC})
target_link_libraries(lite_acl_mid mindspore_shared_lib)
add_dependencies(lite_acl_mid mindspore_shared_lib)
add_dependencies(lite_acl_mid fbs_inner_src)
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
add_library(ascend_pass_plugin SHARED ${ACL_SRC})
target_link_libraries(ascend_pass_plugin mindspore_shared_lib)
add_dependencies(ascend_pass_plugin mindspore_shared_lib)
add_dependencies(ascend_pass_plugin fbs_inner_src)
else()
add_library(lite_acl_mid OBJECT ${ACL_SRC})
target_link_libraries(lite_acl_mid mindspore_shared_lib)
add_dependencies(lite_acl_mid mindspore_shared_lib)
add_dependencies(lite_acl_mid fbs_inner_src)
endif()

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/api/acl_pass_api.h"
mindspore::opt::Pass *CreateAclPass(const std::shared_ptr<mindspore::ConverterPara> &param) {
auto acl_pass_ptr = new (std::nothrow) mindspore::opt::AclPass(param);
if (acl_pass_ptr == nullptr) {
MS_LOG(ERROR) << "New acl pass failed.";
return nullptr;
}
return acl_pass_ptr;
}
void DestroyAclPass(mindspore::opt::Pass *acl_pass) {
if (acl_pass == nullptr) {
MS_LOG(ERROR) << "Param acl pass is nullptr.";
return;
}
delete acl_pass;
}

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_API_ACL_PASS_API_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_API_ACL_PASS_API_H_
#include <memory>
#include "tools/converter/adapter/acl/acl_pass.h"
#ifdef __cplusplus
extern "C" {
#endif
mindspore::opt::Pass *CreateAclPass(const std::shared_ptr<mindspore::ConverterPara> &param);
void DestroyAclPass(mindspore::opt::Pass *acl_pass);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_API_ACL_PASS_API_H_

View File

@@ -0,0 +1,109 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/plugin/acl_pass_plugin.h"
#include "utils/ms_utils.h"
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include <dlfcn.h>
#include "extendrt/cxx_api/dlutils.h"
#endif
namespace mindspore {
namespace opt {
AclPassPlugin &AclPassPlugin::GetInstance() {
static AclPassPlugin instance;
return instance;
}
AclPassPlugin::AclPassPlugin() : handle_(nullptr), pass_ptr_(nullptr) {}
bool AclPassPlugin::HasPluginSo() {
#if !defined(_WIN32) && !defined(_WIN64)
std::string ascend_pass_plugin_path;
auto ret = DLSoPath("libmindspore_converter.so", "libascend_pass_plugin.so", &ascend_pass_plugin_path);
if (ret != kSuccess || ascend_pass_plugin_path.empty()) {
MS_LOG(ERROR) << "Get real path of libascend_pass_plugin.so failed.";
return false;
}
real_path_ = ascend_pass_plugin_path;
MS_LOG(INFO) << "Found ascend pass plugin, path = " << real_path_;
return true;
#endif
return false;
}
Pass *AclPassPlugin::CreateAclPass(const std::shared_ptr<ConverterPara> &param) {
#if !defined(_WIN32) && !defined(_WIN64)
if (pass_ptr_ != nullptr) {
MS_LOG(INFO) << "Acl pass has been created.";
return pass_ptr_;
}
void *function = nullptr;
auto ret = DLSoOpen(real_path_, "CreateAclPass", &handle_, &function);
if (ret != kSuccess) {
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << real_path_;
return nullptr;
}
auto create_func = reinterpret_cast<mindspore::opt::Pass *(*)(const std::shared_ptr<ConverterPara> &)>(function);
if (create_func == nullptr) {
MS_LOG(ERROR) << "Cast symbol CreateAclPass failed.";
return nullptr;
}
pass_ptr_ = create_func(param);
if (pass_ptr_ == nullptr) {
MS_LOG(ERROR) << "New acl pass failed.";
return nullptr;
}
#endif
return pass_ptr_;
}
void AclPassPlugin::DestroyAclPass(Pass *acl_pass) {
#if !defined(_WIN32) && !defined(_WIN64)
if (handle_ == nullptr) {
MS_LOG(ERROR) << "Handle is nullptr .";
return;
}
if (acl_pass != pass_ptr_) {
MS_LOG(ERROR) << "Out pass ptr is not same as inner pass ptr.";
return;
}
auto destroy_func = reinterpret_cast<void (*)(mindspore::opt::Pass *)>(dlsym(handle_, "DestroyAclPass"));
if (destroy_func == nullptr) {
MS_LOG(ERROR) << "Undefined symbol DestroyAclPass in ['libascend_pass_plugin.so']";
return;
}
destroy_func(acl_pass);
pass_ptr_ = nullptr;
#endif
}
AclPassPlugin::~AclPassPlugin() {
#if !defined(_WIN32) && !defined(_WIN64)
MS_LOG(DEBUG) << "~AclPassPlugin() begin.";
if (handle_ != nullptr) {
(void)dlclose(handle_);
handle_ = nullptr;
}
MS_LOG(DEBUG) << "~AclPassPlugin() end.";
#endif
}
} // namespace opt
} // namespace mindspore
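CreateAclPass and DestroyAclPass form a create/destroy pair that every caller must balance on each exit path (as the converter pass below does twice). One way to make that automatic is a small RAII wrapper; a sketch, assuming only the AclPassPlugin interface shown above:

#include <memory>

// Route deletion back through the plugin so early returns cannot leak the pass.
struct AclPassDeleter {
  void operator()(mindspore::opt::Pass *pass) const {
    if (pass != nullptr) {
      mindspore::opt::AclPassPlugin::GetInstance().DestroyAclPass(pass);
    }
  }
};
using AclPassPtr = std::unique_ptr<mindspore::opt::Pass, AclPassDeleter>;

AclPassPtr MakeAclPass(const std::shared_ptr<mindspore::ConverterPara> &param) {
  return AclPassPtr(mindspore::opt::AclPassPlugin::GetInstance().CreateAclPass(param));
}

With such a wrapper, AnfTransform::RunConvertPass below would not need the explicit DestroyAclPass calls on its error and success paths.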

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_ACL_PASS_PLUGIN_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_ACL_PASS_PLUGIN_H_
#include <memory>
#include <string>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore {
namespace opt {
class AclPassPlugin {
public:
static AclPassPlugin &GetInstance();
bool HasPluginSo();
Pass *CreateAclPass(const std::shared_ptr<ConverterPara> &param);
void DestroyAclPass(Pass *acl_pass);
private:
AclPassPlugin();
~AclPassPlugin();
void *handle_;
Pass *pass_ptr_;
std::string real_path_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_ACL_PASS_PLUGIN_H_

View File

@@ -89,12 +89,15 @@
#include "tools/optimizer/fusion/transpose_fusion.h"
#include "tools/optimizer/format/to_nchw_format.h"
#include "tools/optimizer/format/to_nhwc_format.h"
#ifndef ENABLE_CLOUD_FUSION_INFERENCE
#include "tools/converter/adapter/acl/acl_pass.h"
#endif
#include "src/common/log_util.h"
#include "tools/optimizer/fusion/groupnorm_fusion.h"
#include "tools/optimizer/fusion/mul_reduce_fusion.h"
#include "tools/converter/import/cast_op_adjust.h"
#include "tools/converter/quantizer/quant_helper/remove_unused_quant_param.h"
#include "tools/converter/adapter/acl/plugin/acl_pass_plugin.h"
using std::string;
namespace mindspore::lite {
@@ -329,13 +332,27 @@ int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_
}
int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
#ifndef ENABLE_CLOUD_FUSION_INFERENCE
auto acl_pass = std::make_shared<opt::AclPass>(param);
CHECK_NULL_RETURN(acl_pass);
if (!acl_pass->Run(old_graph)) {
MS_LOG(ERROR) << "Acl pass failed.";
return RET_ERROR;
}
#endif
if (opt::AclPassPlugin::GetInstance().HasPluginSo()) {
auto acl_pass_ptr = opt::AclPassPlugin::GetInstance().CreateAclPass(param);
if (acl_pass_ptr == nullptr) {
MS_LOG(ERROR) << "Acl pass ptr is nullptr.";
return RET_ERROR;
}
if (!acl_pass_ptr->Run(old_graph)) {
MS_LOG(ERROR) << "Acl pass failed.";
opt::AclPassPlugin::GetInstance().DestroyAclPass(acl_pass_ptr);
return RET_ERROR;
}
opt::AclPassPlugin::GetInstance().DestroyAclPass(acl_pass_ptr);
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
CHECK_NULL_RETURN(optimizer);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);

View File

@@ -84,6 +84,7 @@ Flags::Flags() {
"false");
AddFlag(&Flags::exportMindIR, "exportMindIR", "MINDIR | MINDIR_LITE", "MINDIR_LITE");
AddFlag(&Flags::noFusionStr, "NoFusion", "Avoid fusion optimization true|false", "false");
AddFlag(&Flags::device, "device", "Set the target device Ascend", "");
}
int Flags::InitInputOutputDataType() {

View File

@@ -79,6 +79,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string encryptionStr = "false";
bool encryption = false;
#endif
std::string device;
};
} // namespace converter
} // namespace mindspore

View File

@@ -67,6 +67,7 @@ int main(int argc, const char **argv) {
converter.SetInfer(flags.infer);
converter.SetTrainModel(flags.trainModel);
converter.SetNoFusion(flags.disableFusion);
converter.SetDevice(flags.device);
auto status = converter.Convert();
if (status != mindspore::kSuccess) {

View File

@@ -263,6 +263,20 @@ bool Converter::GetNoFusion() {
}
}
void Converter::SetDevice(const std::string &device) {
if (data_ != nullptr) {
data_->device = device;
}
}
std::string Converter::GetDevice() {
if (data_ != nullptr) {
return data_->device;
} else {
return "";
}
}
Status Converter::Convert() {
if (data_ != nullptr) {
auto ret = lite::RunConverter(data_, nullptr, nullptr, false);
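A hedged usage sketch of the new device plumbing from the caller's side, combining the setter added here with the flow in main.cc above; how 'converter' is constructed and configured is assumed, not shown in this commit:

// 'converter' stands for a fully configured mindspore::Converter (model type,
// input/output paths set elsewhere); SetDevice forwards into ConverterPara::device.
converter.SetDevice("Ascend");
auto status = converter.Convert();
if (status != mindspore::kSuccess) {
  // handle conversion failure
}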

View File

@@ -68,6 +68,7 @@ struct ConverterPara {
lite::acl::AclModelOptionCfg aclModelOptionCfgParam;
lite::micro::MicroParam microParam;
ParallelSplitConfig parallel_split_config;
std::string device;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_