!17277 [MSLITE] Support nnie.

Merge pull request !17277 from wangshaocong/nnie_to_master
commit 28072324e9 by i-robot, 2021-06-09 11:02:47 +08:00 (committed by Gitee)
43 changed files with 330 additions and 65 deletions

View File

@@ -568,6 +568,24 @@ build_lite()
     cd ${BASEPATH}/mindspore/lite/build
     write_commit_file
+    if [[ "${local_lite_platform}" == "arm32" ]]; then
+      if [[ "${TOOLCHAIN_FILE}" && "${TOOLCHAIN_NAME}" ]]; then
+        RUN_TESTCASES="off"
+        COMPILE_MINDDATA_LITE="off"
+        CMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE}
+        CMAKE_TOOLCHAIN_NAME=${TOOLCHAIN_NAME}
+      else
+        CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
+        ANDROID_NATIVE_API_LEVEL="19"
+        CMAKE_ANDROID_NDK=${ANDROID_NDK}
+        CMAKE_ANDROID_ABI="armeabi-v7a"
+        CMAKE_ANDROID_TOOLCHAIN_NAME="clang"
+        CMAKE_ANDROID_STL=${MSLITE_ANDROID_STL}
+        CMAKE_BUILD_TYPE=${LITE_BUILD_TYPE}
+        ENABLE_FP16="on"
+      fi
+    fi
     if [[ "${local_lite_platform}" == "arm64" ]]; then
       if [ "$(uname)" == "Darwin" ]; then
         cmake -DCMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake -DARCHS="arm64" -DENABLE_BITCODE=0 \
@@ -593,10 +611,10 @@ build_lite()
     else
       checkndk
       echo "default link libc++_static.a, export MSLITE_ANDROID_STL=c++_shared to link libc++_shared.so"
-      cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \
-          -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \
-          -DANDROID_STL=${MSLITE_ANDROID_STL} -DCMAKE_BUILD_TYPE=${LITE_BUILD_TYPE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-          -DPLATFORM_ARM32="on" -DENABLE_NEON="on" -DENABLE_FP16="on" -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
+      cmake -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} -DTOOLCHAIN_NAME=${CMAKE_TOOLCHAIN_NAME} -DANDROID_NATIVE_API_LEVEL=${ANDROID_NATIVE_API_LEVEL} \
+          -DANDROID_NDK=${CMAKE_ANDROID_NDK} -DANDROID_ABI=${CMAKE_ANDROID_ABI} -DANDROID_TOOLCHAIN_NAME=${CMAKE_ANDROID_TOOLCHAIN_NAME} \
+          -DANDROID_STL=${CMAKE_ANDROID_STL} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
+          -DPLATFORM_ARM32="on" -DENABLE_NEON="on" -DENABLE_FP16=${ENABLE_FP16} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
           -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
           -DENABLE_ASAN=${ENABLE_ASAN} -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${BASEPATH}/mindspore/lite"
     fi

View File

@@ -18,7 +18,8 @@ mindspore_add_pkg(flatbuffers
         EXE flatc
         URL ${REQ_URL}
         MD5 ${MD5}
-        CMAKE_OPTION -DFLATBUFFERS_BUILD_TESTS=OFF -DCMAKE_INSTALL_LIBDIR=lib)
+        CMAKE_OPTION -DCMAKE_C_COMPILER=${FLATC_GCC_COMPILER} -DCMAKE_CXX_COMPILER=${FLATC_GXX_COMPILER}
+        -DFLATBUFFERS_BUILD_TESTS=OFF -DCMAKE_INSTALL_LIBDIR=lib)
 include_directories(${flatbuffers_INC})
 add_library(mindspore::flatbuffers ALIAS flatbuffers::flatbuffers)

View File

@@ -15,6 +15,7 @@ set(MIND_DATA_INC_DIR ${RUNTIME_PKG_NAME}/inference/include/dataset)
 set(TURBO_DIR ${RUNTIME_PKG_NAME}/inference/third_party/libjpeg-turbo)
 set(SECUREC_DIR ${RUNTIME_PKG_NAME}/inference/third_party/securec)
 set(MINDSPORE_LITE_LIB_NAME libmindspore-lite)
+set(MINDSPORE_CORE_LIB_NAME libmindspore_core)
 set(BENCHMARK_NAME benchmark)
 set(BENCHMARK_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/benchmark)
@@ -253,6 +254,10 @@ elseif(PLATFORM_ARM32)
                 ${RUNTIME_COMPONENT_NAME})
         endif()
     endif()
+    if(MSLITE_ENABLE_NNIE AND TARGET_HIMIX200)
+        install(FILES ${TOP_DIR}/mindspore/lite/tools/providers/NNIE/3516D/libmslite_nnie.so
+                DESTINATION ${RUNTIME_PKG_NAME}/providers/3516D COMPONENT ${RUNTIME_COMPONENT_NAME})
+    endif()
 elseif(WIN32)
     get_filename_component(CXX_DIR ${CMAKE_CXX_COMPILER} PATH)
     file(GLOB LIB_LIST ${CXX_DIR}/libstdc++-6.dll ${CXX_DIR}/libwinpthread-1.dll
@@ -393,6 +398,39 @@ else()
         install(FILES ${TOP_DIR}/mindspore/lite/build/tools/cropper/cropper_mapping_npu.cfg
                 DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
     endif()
+    if(NOT SUPPORT_TRAIN AND MSLITE_ENABLE_NNIE)
+        install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
+                COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
+        install(DIRECTORY ${TOP_DIR}/mindspore/core/abstract/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract
+                COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
+        install(DIRECTORY ${TOP_DIR}/mindspore/core/base/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/base
+                COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
+        install(DIRECTORY ${TOP_DIR}/mindspore/core/ir/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/ir
+                COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
+        install(DIRECTORY ${TOP_DIR}/mindspore/core/ops/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/ops
+                COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
+        install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
+                COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
+        install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/mindspore_core/${MINDSPORE_CORE_LIB_NAME}.a
+                DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
+        install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
+                COMPONENT ${RUNTIME_COMPONENT_NAME})
+        file(GLOB PROTOBUF_LIB_PATH ${TOP_DIR}/mindspore/lite/build/.mslib/protobuf_*/lib/libprotobuf.a)
+        install(FILES ${PROTOBUF_LIB_PATH} DESTINATION ${CONVERTER_ROOT_DIR}/lib
+                COMPONENT ${RUNTIME_COMPONENT_NAME})
+        install(FILES ${TOP_DIR}/mindspore/lite/tools/providers/NNIE/3516D/libmslite_nnie_converter.so
+                DESTINATION ${CONVERTER_ROOT_DIR}/providers/3516D/ COMPONENT ${RUNTIME_COMPONENT_NAME})
+        install(FILES ${TOP_DIR}/mindspore/lite/tools/providers/NNIE/3516D/libmslite_nnie_data_process.so
+                DESTINATION ${CONVERTER_ROOT_DIR}/providers/3516D COMPONENT ${RUNTIME_COMPONENT_NAME})
+        install(FILES ${TOP_DIR}/mindspore/lite/tools/providers/NNIE/3516D/libnnie_mapper.so
+                DESTINATION ${CONVERTER_ROOT_DIR}/providers/3516D COMPONENT ${RUNTIME_COMPONENT_NAME})
+        install(DIRECTORY ${TOP_DIR}/mindspore/lite/tools/providers/NNIE/3516D/opencv-4.2.0/lib/
+                DESTINATION ${CONVERTER_ROOT_DIR}/providers/3516D/third_party/opencv-4.2.0
+                COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.so*")
+        install(DIRECTORY ${TOP_DIR}/mindspore/lite/tools/providers/NNIE/3516D/protobuf-3.9.0/lib/
+                DESTINATION ${CONVERTER_ROOT_DIR}/providers/3516D/third_party/protobuf-3.9.0
+                COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.so*")
+    endif()
 endif()
 if(CMAKE_SYSTEM_NAME MATCHES "Windows")

View File

@@ -31,7 +31,9 @@ typedef struct ArgElement {
     int32_t i_data_;
     float f_data_;
 #ifdef ENABLE_ARM
+#ifndef SUPPORT_NNIE
     float16_t f16_data_;
+#endif
 #endif
   } data_;
 } ArgElement;

View File

@@ -97,8 +97,13 @@ int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size)
   uint32x4_t mask = vmovq_n_u32(((uint32_t)(1u << 31) - 1));
   uint32x4_t zeros = vdupq_n_u32(0);
   for (; index <= size - 4; index += C4NUM) {
+#ifndef SUPPORT_NNIE
     uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(in0 + index)), mask);
     uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(in1 + index)), mask);
+#else
+    uint32x4_t vin0 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in0 + index)), mask);
+    uint32x4_t vin1 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in1 + index)), mask);
+#endif
     float32x4_t vout = vbslq_f32(vceqq_u32(vandq_u32(vin0, vin1), zeros), vfalse, vtrue);
     vst1q_f32(out + index, vout);
   }
@@ -133,8 +138,13 @@ int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) {
   uint32x4_t mask = vmovq_n_u32(((uint32_t)(1u << 31) - 1));
   uint32x4_t zeros = vdupq_n_u32(0);
   for (; index <= size - 4; index += C4NUM) {
+#ifndef SUPPORT_NNIE
     uint32x4_t vin0 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(in0 + index)), mask);
     uint32x4_t vin1 = vandq_u32(vreinterpretq_s32_f32(vld1q_f32(in1 + index)), mask);
+#else
+    uint32x4_t vin0 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in0 + index)), mask);
+    uint32x4_t vin1 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in1 + index)), mask);
+#endif
     float32x4_t vout = vbslq_f32(vceqq_u32(vorrq_u32(vin0, vin1), zeros), vfalse, vtrue);
     vst1q_f32(out + index, vout);
   }
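Note on the intrinsic swap in these two hunks, with a minimal standalone sketch: vandq_u32 takes uint32x4_t operands, and vreinterpretq_u32_f32 is the reinterpret that yields exactly that type. The _s32_f32 variant returns int32x4_t and only compiled because many toolchains allow lax vector conversions; a stricter compiler (presumably the himix200 GCC) rejects it, hence the SUPPORT_NNIE branch.

#include <arm_neon.h>

// Clears each lane's sign bit; the unsigned reinterpret matches the
// uint32x4_t operand type that vandq_u32 expects. Bits are unchanged.
uint32x4_t AbsBits(float32x4_t v) {
  uint32x4_t mask = vdupq_n_u32(0x7FFFFFFFu);
  return vandq_u32(vreinterpretq_u32_f32(v), mask);
}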

View File

@@ -219,6 +219,7 @@ void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t
 #endif
 #ifdef ENABLE_ARM32
+#ifndef SUPPORT_NNIE
 void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, size_t dst_step) {
   asm volatile(
     "mov r11, %[src_ptr]\n"
@@ -277,6 +278,66 @@ void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, s
     : "r8", "r10", "r11", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
   return;
 }
+#else
+void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, size_t dst_step) {
+  asm volatile(
+    "mov r7, %[src_ptr]\n"
+    "mov r8, %[dst_ptr]\n"
+    "mov r10, r8\n"
+    "vld1.32 {q0}, [r7], %[src_step]\n"
+    "vld1.32 {q1}, [r8], %[dst_step]\n"
+    "vld1.32 {q2}, [r7], %[src_step]\n"
+    "vld1.32 {q3}, [r8], %[dst_step]\n"
+    "vadd.f32 q0, q0, q1\n"
+    "vld1.32 {q8}, [r7], %[src_step]\n"
+    "vadd.f32 q2, q2, q3\n"
+    "vst1.32 {q0}, [r10], %[dst_step]\n"
+    "vst1.32 {q2}, [r10], %[dst_step]\n"
+    "vld1.32 {q9}, [r8], %[dst_step]\n"
+    "vld1.32 {q10}, [r7], %[src_step]\n"
+    "vadd.f32 q8, q8, q9\n"
+    "vld1.32 {q11}, [r8], %[dst_step]\n"
+    "vadd.f32 q10, q10, q11\n"
+    "vld1.32 {q0}, [r7], %[src_step]\n"
+    "vst1.32 {q8}, [r10], %[dst_step]\n"
+    "vst1.32 {q10}, [r10], %[dst_step]\n"
+    "vld1.32 {q1}, [r8], %[dst_step]\n"
+    "vld1.32 {q2}, [r7], %[src_step]\n"
+    "vld1.32 {q3}, [r8], %[dst_step]\n"
+    "vadd.f32 q0, q0, q1\n"
+    "vadd.f32 q2, q2, q3\n"
+    "vst1.32 {q0}, [r10], %[dst_step]\n"
+    "vst1.32 {q2}, [r10], %[dst_step]\n"
+    "vld1.32 {q8}, [r7], %[src_step]\n"
+    "vld1.32 {q9}, [r8], %[dst_step]\n"
+    "vld1.32 {q10}, [r7], %[src_step]\n"
+    "vld1.32 {q11}, [r8], %[dst_step]\n"
+    "vadd.f32 q8, q8, q9\n"
+    "vadd.f32 q10, q10, q11\n"
+    "vst1.32 {q8}, [r10], %[dst_step]\n"
+    "vst1.32 {q10}, [r10], %[dst_step]\n"
+    :
+    : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step)
+    : "r8", "r10", "r7", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
+  return;
+}
+#endif
 #endif
 void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) {
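The NNIE variants of these assembly routines are instruction-for-instruction identical to the originals; only the scratch register changes (r11 becomes r7, and the clobber lists follow). A plausible reason, stated as an assumption: r11 is the ARM frame-pointer register, which the himix200 GCC will not let inline asm clobber. A minimal sketch of the pattern:

// Copies four floats through q0, using r7 rather than the frame-pointer
// register r11 as the scratch address register (hypothetical helper).
static inline void Copy4Floats(const float *src, float *dst) {
  asm volatile(
    "mov r7, %[src]\n"
    "vld1.32 {q0}, [r7]\n"
    "mov r7, %[dst]\n"
    "vst1.32 {q0}, [r7]\n"
    :
    : [ src ] "r"(src), [ dst ] "r"(dst)
    : "r7", "q0", "memory");
}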

View File

@@ -420,6 +420,7 @@ void RowMajor2Col8Major_arm64(const float *src_c, float *dst_c, size_t col) {
 }
 #endif
 #ifdef ENABLE_ARM32
+#ifndef SUPPORT_NNIE
 void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
   size_t stride = col * sizeof(float);
   asm volatile(
@@ -462,6 +463,50 @@ void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
     : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
   return;
 }
+#else
+void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
+  size_t stride = col * sizeof(float);
+  asm volatile(
+    "mov r10, %[src_c]\n"
+    "mov r7, %[dst_c]\n"
+    "vld1.32 {q0}, [r10], %[stride]\n"
+    "vld1.32 {q2}, [r10], %[stride]\n"
+    "vld1.32 {q4}, [r10], %[stride]\n"
+    "vld1.32 {q6}, [r10], %[stride]\n"
+    "vtrn.32 d0, d4\n"
+    "vtrn.32 d1, d5\n"
+    "vtrn.32 d8, d12\n"
+    "vtrn.32 d9, d13\n"
+    "vld1.32 {q1}, [r10], %[stride]\n"
+    "vld1.32 {q3}, [r10], %[stride]\n"
+    "vld1.32 {q5}, [r10], %[stride]\n"
+    "vld1.32 {q7}, [r10], %[stride]\n"
+    "vswp d1, d8\n"
+    "vswp d5, d12\n"
+    "vtrn.32 d2, d6\n"
+    "vtrn.32 d3, d7\n"
+    "vtrn.32 d10, d14\n"
+    "vtrn.32 d11, d15\n"
+    "vswp d3, d10\n"
+    "vswp d7, d14\n"
+    "vst1.32 {q0, q1}, [r7]!\n"
+    "vst1.32 {q2, q3}, [r7]!\n"
+    "vst1.32 {q4, q5}, [r7]!\n"
+    "vst1.32 {q6, q7}, [r7]!\n"
+    :
+    : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
+    : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
+  return;
+}
+#endif
 #endif
 void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
   size_t row8 = row / C8NUM * C8NUM;

View File

@@ -353,6 +353,7 @@ static void RowMajor2Col8MajorStride(const float *src_ptr, float *dst_ptr, size_
     "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
     "v30", "v31");
 #elif ENABLE_ARM32
+#ifndef SUPPORT_NNIE
   /* 8x4 row-major to col-major */
   size_t stride = col * sizeof(float);
   asm volatile(
@@ -393,6 +394,48 @@ static void RowMajor2Col8MajorStride(const float *src_ptr, float *dst_ptr, size_
     :
     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
     : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
+#else
+  /* 8x4 row-major to col-major */
+  size_t stride = col * sizeof(float);
+  asm volatile(
+    "mov r10, %[src_c]\n"
+    "mov r7, %[dst_c]\n"
+    "vld1.32 {q0}, [r10], %[stride]\n"
+    "vld1.32 {q2}, [r10], %[stride]\n"
+    "vld1.32 {q4}, [r10], %[stride]\n"
+    "vld1.32 {q6}, [r10], %[stride]\n"
+    "vtrn.32 d0, d4\n"
+    "vtrn.32 d1, d5\n"
+    "vtrn.32 d8, d12\n"
+    "vtrn.32 d9, d13\n"
+    "vld1.32 {q1}, [r10], %[stride]\n"
+    "vld1.32 {q3}, [r10], %[stride]\n"
+    "vld1.32 {q5}, [r10], %[stride]\n"
+    "vld1.32 {q7}, [r10], %[stride]\n"
+    "vswp d1, d8\n"
+    "vswp d5, d12\n"
+    "vtrn.32 d2, d6\n"
+    "vtrn.32 d3, d7\n"
+    "vtrn.32 d10, d14\n"
+    "vtrn.32 d11, d15\n"
+    "vswp d3, d10\n"
+    "vswp d7, d14\n"
+    "vst1.32 {q0, q1}, [r7]!\n"
+    "vst1.32 {q2, q3}, [r7]!\n"
+    "vst1.32 {q4, q5}, [r7]!\n"
+    "vst1.32 {q6, q7}, [r7]!\n"
+    :
+    : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
+    : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
+#endif
 #else
   for (int tr = 0; tr < 8; tr++) {
     for (int tc = 0; tc < 4; tc++) {

View File

@@ -146,7 +146,8 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p
 #ifdef ENABLE_ARM
         uint32x4_t index = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3};
         float32x4_t in = vld1q_f32(inPtr + val_idx);
-        max_idx = MaxIndex(in, &max_val, index, max_idx);
+        max_idx = vreinterpretq_u32_s32(
+          MaxIndex(in, &max_val, vreinterpretq_s32_u32(index), vreinterpretq_s32_u32(max_idx)));
 #else
         float val[4] = {inPtr[val_idx], inPtr[val_idx + 1], inPtr[val_idx + 2], inPtr[val_idx + 3]};
         for (int i = 0; i < 4; i++) {
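MaxIndex is evidently declared with signed int32x4_t parameters, so the unsigned index vectors are now reinterpreted on the way in and the result reinterpreted back on the way out. A reinterpret only relabels the lane type; the bit pattern is untouched, as this minimal sketch shows:

#include <arm_neon.h>

// u -> s -> u is a bitwise no-op; only the static vector type changes.
uint32x4_t RoundTrip(uint32x4_t u) {
  int32x4_t s = vreinterpretq_s32_u32(u);
  return vreinterpretq_u32_s32(s);
}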

View File

@@ -137,10 +137,10 @@ void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int co
     "vld1.8 {q2}, [r0], r2 \n"
     "vld1.8 {q3}, [r0], r2 \n"
-    "vst1.32 q0, [r1], r3 \n"
-    "vst1.32 q1, [r1], r3 \n"
-    "vst1.32 q2, [r1], r3 \n"
-    "vst1.32 q3, [r1], r3 \n"
+    "vst1.32 {d0, d1}, [r1], r3 \n"
+    "vst1.32 {d2, d3}, [r1], r3 \n"
+    "vst1.32 {d4, d5}, [r1], r3 \n"
+    "vst1.32 {d6, d7}, [r1], r3 \n"
     :
     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset)
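The store fix above is pure assembler syntax: vst1.32 takes a d-register list, and {d0, d1} is the two-d-register spelling of q0. The bare "vst1.32 q0, ..." form is an extension that, presumably, the himix200 assembler rejects. A standalone equivalent:

// Moves 16 bytes through d0/d1 (the two halves of q0); hypothetical helper.
static inline void Copy16Bytes(const int8_t *src, int8_t *dst) {
  asm volatile(
    "vld1.8 {d0, d1}, [%[s]]\n"
    "vst1.32 {d0, d1}, [%[d]]\n"
    :
    : [ s ] "r"(src), [ d ] "r"(dst)
    : "d0", "d1", "memory");
}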

View File

@@ -17,6 +17,7 @@
 #include "base/base.h"
 #include <atomic>
 #include <mutex>
+#include <string>
 #include <unordered_map>
 namespace mindspore {
@@ -32,7 +33,7 @@ uint32_t Base::GetTypeId(const char *const type_name) {
   if (it != t->map.end()) {
     return it->second;
   }
-  uint32_t tid = ++(t->type_counter);
+  uint32_t tid = std::hash<std::string>()(type_name);
   t->map[type_name] = tid;
   return tid;
 }
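Type ids are now derived from the type name rather than a per-process counter. The likely motivation (an assumption, the diff itself does not say): NNIE support loads provider code as separate shared objects, and an incrementing counter can hand the same type name different ids on the two sides of that boundary, while hashing the name is deterministic everywhere, at the cost of a theoretical collision. Sketch of the new scheme:

#include <cstdint>
#include <functional>
#include <string>

// Deterministic across processes and shared objects; a counter is not.
uint32_t TypeIdOf(const char *type_name) {
  return static_cast<uint32_t>(std::hash<std::string>()(type_name));
}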

View File

@@ -589,6 +589,8 @@ inline const PrimitivePtr kPrimIf = std::make_shared<Primitive>("If");
 inline const PrimitivePtr kPrimAvgPoolFusion = std::make_shared<Primitive>("AvgPoolFusion");
 inline const PrimitivePtr kPrimMaxPoolFusion = std::make_shared<Primitive>("MaxPoolFusion");
 inline const PrimitivePtr kPrimActivation = std::make_shared<Primitive>("Activation");
+inline const PrimitivePtr kPrimPReLUFusion = std::make_shared<Primitive>("PReLUFusion");
+inline const PrimitivePtr kPrimCustom = std::make_shared<Primitive>("Custom");
 inline const PrimitivePtr kPrimTopKFusion = std::make_shared<Primitive>("TopKFusion");
 inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFusion");
 inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion");

View File

@@ -42,6 +42,7 @@ void InterThreadPool::ThreadAsyncRun(Worker *worker) {
 }
 void InterThreadPool::ActorThreadRun() {
+#ifndef SUPPORT_NNIE
   ActorReference actor;
   {
     std::unique_lock<std::mutex> _l(actor_mutex_);
@@ -54,6 +55,7 @@ void InterThreadPool::ActorThreadRun() {
   }
   actor->Run();
   finish_cond_var_.notify_one();
+#endif
 }
 void InterThreadPool::EnqueReadyActor(const ActorReference &actor) {

View File

@@ -100,7 +100,7 @@ void ThreadPool::KernelThreadRun(Worker *worker) {
 int ThreadPool::ParallelLaunch(const Func &func, Content content, int task_num) {
   // distribute task to the KernelThread and the free ActorThread,
   // if the task num is greater than the KernelThread num
-  Task task = Task(func, content);
+  Task task = {func, content};
   Worker *curr = CurrentWorker();
   if (inter_thread_num_ == thread_num_ || curr == nullptr) {
     SyncRunTask(&task, task_num);
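The old form builds a temporary Task and copies it into task, which fails if Task is non-copyable, for instance because it holds a std::atomic member (an assumption about Task's definition). Brace initialization constructs the object in place. A self-contained sketch of the pattern:

#include <atomic>

struct Demo {              // hypothetical stand-in for Task
  int (*func)(void *);     // hypothetical member types
  void *content;
  std::atomic_int status;  // an atomic member deletes Demo's copy constructor
};

Demo d = {nullptr, nullptr, {0}};  // initialized in place, no copy involved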

View File

@@ -3,7 +3,12 @@ project(Lite)
 set(BUILD_LITE "on")
-if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0)
+if(TOOLCHAIN_NAME STREQUAL "himix200")
+    set(TARGET_HIMIX200 on)
+    add_compile_definitions(SUPPORT_NNIE)
+endif()
+if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0 AND NOT TARGET_HIMIX200)
     message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
 endif()
@@ -22,6 +27,7 @@ option(MSLITE_ENABLE_AVX "enable AVX instruction set, only x86_64 support" off)
 option(MSLITE_ENABLE_CONVERTER "enable converter, only x86_64 support" on)
 option(MSLITE_ENABLE_TOOLS "enable tools" on)
 option(MSLITE_ENABLE_TESTCASES "enable testcase" off)
+option(MSLITE_ENABLE_NNIE "enable NNIE" off)
 # Option that can be configured through manually
 option(ENABLE_ARM82_A32 "if build fp16 on platform_arm32" off)
@@ -55,6 +61,9 @@ endif()
 if(DEFINED ENV{MSLITE_ENABLE_TESTCASES})
     set(MSLITE_ENABLE_TESTCASES $ENV{MSLITE_ENABLE_TESTCASES})
 endif()
+if(DEFINED ENV{MSLITE_ENABLE_NNIE})
+    set(MSLITE_ENABLE_NNIE $ENV{MSLITE_ENABLE_NNIE})
+endif()
 if(PLATFORM_ARM64 OR PLATFORM_ARM32)
     set(PLATFORM_ARM "on")
@@ -197,6 +206,9 @@ else()
         -Wno-deprecated-declarations -Wno-missing-braces ${CMAKE_C_FLAGS}")
     set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes \
         -Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}")
+    if(TARGET_HIMIX200)
+        set(CMAKE_CXX_FLAGS "-Wno-error=maybe-uninitialized ${CMAKE_CXX_FLAGS}")
+    endif()
     if(NOT WIN32)
         set(CMAKE_SHARED_LINKER_FLAGS "-Wl,-z,relro,-z,now -Wl,-z,noexecstack ${CMAKE_SHARED_LINKER_FLAGS}")
         set(CMAKE_EXE_LINKER_FLAGS "-Wl,-z,relro,-z,now -Wl,-z,noexecstack ${CMAKE_EXE_LINKER_FLAGS}")

View File

@@ -14,6 +14,9 @@ if(PLATFORM_ARM32 OR PLATFORM_ARM64)
        -fdata-sections -ffast-math -fno-rtti -fno-exceptions")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections \
        -fdata-sections -ffast-math -fno-rtti -fno-exceptions")
+    if(TARGET_HIMIX200)
+        string(REPLACE "-fno-rtti " "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
+    endif()
 endif()
 if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND APPLE)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstrict-aliasing -ffunction-sections \
@@ -25,7 +28,10 @@ if(PLATFORM_ARM32 OR PLATFORM_ARM64)
     endif()
 endif()
-set(API_SRC
+if(TARGET_HIMIX200)
+    set(API_SRC)
+else()
+    set(API_SRC
         ${CORE_DIR}/utils/status.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/cell.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/serialization.cc
@@ -35,7 +41,8 @@ set(API_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model/model_impl.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/graph/graph.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/tensor/tensor_impl.cc
         )
+endif()
 if(SUPPORT_NPU)
     include_directories(${DDK_PATH})
@@ -138,6 +145,16 @@ if(ENABLE_MINDRT)
         ${CMAKE_CURRENT_SOURCE_DIR}/lite_mindrt.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/mindrt_executor.cc
         )
+elseif(TARGET_HIMIX200)
+    include_directories(${CORE_DIR}/mindrt)
+    include_directories(${CORE_DIR}/mindrt/include)
+    include_directories(${CORE_DIR}/mindrt/src)
+    set(LITE_SRC
+        ${LITE_SRC}
+        ${CORE_DIR}/mindrt/src/thread/core_affinity.cc
+        ${CORE_DIR}/mindrt/src/thread/inter_threadpool.cc
+        ${CORE_DIR}/mindrt/src/thread/threadpool.cc
+        )
 endif()
 add_subdirectory(ops)
@@ -199,7 +216,7 @@ if(SUPPORT_NPU)
     target_link_libraries(mindspore-lite npu_kernel_mid)
     target_link_libraries(mindspore-lite_static npu_kernel_mid)
 endif()
-if(PLATFORM_ARM32 OR PLATFORM_ARM64)
+if(PLATFORM_ARM32 OR PLATFORM_ARM64 AND NOT TARGET_HIMIX200)
     target_link_libraries(mindspore-lite log)
     target_link_libraries(mindspore-lite_static log)
 endif()
@@ -222,12 +239,15 @@ if(SUPPORT_TRAIN)
     target_link_libraries(mindspore-lite-train_static minddata-lite mindspore-lite)
 endif()
-if("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
-    if(NOT APPLE AND PLATFORM_ARM)
-        set(NDK_STRIP
-            "${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip")
+if(NOT APPLE AND PLATFORM_ARM AND NOT TARGET_HIMIX200)
+    set(NDK_STRIP
+        "${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip")
+endif()
+if(NOT APPLE AND "${CMAKE_BUILD_TYPE}" STREQUAL "Release")
+    if(PLATFORM_ARM)
         add_custom_command(TARGET mindspore-lite POST_BUILD COMMAND ${NDK_STRIP}
                 ${CMAKE_BINARY_DIR}/src/libmindspore-lite*.so)
     elseif(NOT WIN32)
         add_custom_command(TARGET mindspore-lite POST_BUILD COMMAND strip ${CMAKE_BINARY_DIR}/src/libmindspore-lite*.so)
     endif()

View File

@@ -93,11 +93,9 @@ const char *EnumStrForMsLogLevel(MsLogLevel level) {
 void LogWriter::OutputLog(const std::ostringstream &msg) const {
   if (IsPrint(log_level_)) {
-#ifdef ENABLE_ARM
-#if defined(__ANDROID__) || defined(ANDROID)
+#if defined(ENABLE_ARM) && (defined(__ANDROID__) || defined(ANDROID))
     __android_log_print(GetAndroidLogLevel(log_level_), ANDROID_LOG_TAG, "[%s:%d] %s] %s", location_.file_,
                         location_.line_, location_.func_, msg.str().c_str());
-#endif
 #else
     printf("%s [%s:%d] %s] %s\n", EnumStrForMsLogLevel(log_level_), location_.file_, location_.line_, location_.func_,
            msg.str().c_str());
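This is a behavioral fix, not just tidying: with the old nested guards, a build with ENABLE_ARM defined but no __ANDROID__/ANDROID (exactly the himix200 case) compiled neither the Android call nor the printf fallback, so nothing was logged at all. The collapsed condition restores printf on non-Android ARM. The shape of the fixed guard:

#include <cstdio>
#if defined(ENABLE_ARM) && (defined(__ANDROID__) || defined(ANDROID))
#include <android/log.h>
#endif

void Emit(const char *msg) {
#if defined(ENABLE_ARM) && (defined(__ANDROID__) || defined(ANDROID))
  __android_log_print(ANDROID_LOG_INFO, "MS_LITE", "%s", msg);
#else
  printf("%s\n", msg);  // now reached on ARM builds that are not Android
#endif
}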

View File

@@ -139,7 +139,7 @@ class LiteSession : public session::LiteSession {
   std::unordered_map<Tensor *, Tensor *> graph_output_map_; /* <calculate-tensor, graph-output-tensor> */
   Executor *executor_ = nullptr;
   Model *model_ = nullptr;
-  std::atomic<bool> is_running_ = false;
+  std::atomic<bool> is_running_ = {false};
   bool is_train_session_ = false;
   friend class TransferSession;
 #if SUPPORT_NPU
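This edit, and the matching "= {0}" changes to the ref_count_ members and the static subgraph index below, works around the same language limit: "std::atomic<T> a = v;" is copy-initialization through a temporary atomic, which only becomes legal with C++17's guaranteed copy elision, whereas the braced form directly selects atomic's converting constructor and is valid from C++11 on. Minimal sketch:

#include <atomic>

struct Session {
  // std::atomic<bool> is_running_ = false;  // needs C++17 copy elision
  std::atomic<bool> is_running_ = {false};   // accepted by older toolchains
};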

View File

@@ -52,7 +52,7 @@ class DefaultAllocator : public Allocator {
   void UnLock();
   bool ReuseMemory(size_t free_size, size_t size);
   struct MemBuf {
-    std::atomic_int ref_count_ = 0;
+    std::atomic_int ref_count_ = {0};
     size_t size;
     void *buf;
   };

View File

@@ -47,7 +47,7 @@ int NonZeroCPUKernel::Run() {
   }
   auto non_zero_nums = out_tensor->shape()[1];
   int non_zero_count = 0;
-  std::vector coordiate_values(in_tensor->shape().size(), 0);
+  std::vector<int> coordiate_values(in_tensor->shape().size(), 0);
   for (int i = 0; i < in_tensor->ElementsNum(); i += 1) {
     if (input_data[i]) {
       for (size_t j = 0; j < input_dim_size; j++) {
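The missing template argument is the whole bug here: "std::vector coordiate_values(...)" relies on C++17 class template argument deduction, so older language modes (presumably what the himix200 toolchain defaults to) reject the declaration. Spelling out the element type is valid everywhere:

#include <vector>

std::vector<int> coords(4, 0);  // explicit element type: valid since C++98
// std::vector coords2(4, 0);   // CTAD: needs C++17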

View File

@@ -145,7 +145,7 @@ class CpuFp32SubGraph : public CpuSubGraph {
                   std::vector<LiteKernel *> nodes, Kernel *kernel)
       : CpuSubGraph(std::move(in_kernels), std::move(out_kernels), std::move(nodes), kernel) {
     subgraph_type_ = kCpuFP32SubGraph;
-    static std::atomic_int index = 0;
+    static std::atomic_int index = {0};
     this->set_name("CpuFP32SubGraph" + std::to_string(index++));
     desc_.data_type = kNumberTypeFloat32;
   }

View File

@@ -217,8 +217,8 @@ class Tensor : public mindspore::tensor::MSTensor {
   std::vector<int> shape_;
   schema::Format format_;
   Category category_;
-  std::atomic_int ref_count_ = 0;
-  int init_ref_count_ = 0;
+  std::atomic_int ref_count_ = {0};
+  size_t init_ref_count_ = 0;
   std::vector<QuantArg> quant_params_;
   std::vector<float> quant_clusters_;
   AllocatorPtr allocator_ = nullptr;

View File

@@ -12,7 +12,7 @@ add_executable(benchmark
 add_dependencies(benchmark fbs_src)
-if(PLATFORM_ARM32 OR PLATFORM_ARM64)
+if(PLATFORM_ARM32 OR PLATFORM_ARM64 AND NOT TARGET_HIMIX200)
     if(SUPPORT_NPU AND ANDROID_STL STREQUAL "c++_static")
         target_link_libraries(benchmark mindspore-lite mindspore::json c++_shared)
     else()

View File

@@ -125,6 +125,8 @@ set(LITE_SRC
     ${SRC_DIR}/tensor.cc
     ${SRC_DIR}/ms_tensor.cc
     ${SRC_DIR}/tensorlist.cc
+    ${SRC_DIR}/registry/kernel_interface_registry.cc
+    ${SRC_DIR}/registry/kernel_interface.cc
     ${SRC_DIR}/kernel_registry.cc
     ${SRC_DIR}/inner_kernel.cc
     ${SRC_DIR}/lite_kernel.cc

View File

@@ -45,7 +45,10 @@ Flags::Flags() {
   AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
   AddFlag(&Flags::quantWeightSizeStr, "quantWeightSize", "Weight quantization size threshold", "0");
   AddFlag(&Flags::quantWeightChannelStr, "quantWeightChannel", "Channel threshold for weight quantization", "16");
-  AddFlag(&Flags::configFile, "configFile", "Configuration for post-training, offline split op to parallel", "");
+  AddFlag(&Flags::configFile, "configFile",
+          "Configuration for post-training, offline split op to parallel,"
+          "and converter for nnie",
+          "");
   AddFlag(&Flags::trainModelIn, "trainModel",
           "whether the model is going to be trained on device. "
           "true | false",

View File

@@ -56,6 +56,7 @@
 #include "ops/op_utils.h"
 #include "ops/quant_dtype_cast.h"
 #include "ops/resize.h"
+#include "ops/roi_pooling.h"
 #include "ops/sgd.h"
 #include "ops/space_to_batch.h"
 #include "ops/space_to_batch_nd.h"
@@ -88,6 +89,7 @@ static const std::unordered_map<std::string, std::vector<size_t>> NHWCOpMap = {
     {ops::kNamePReLUFusion, {1}},
     {ops::kNameResize, {1}},
     {ops::kNameResizeGrad, {}},
+    {ops::kNameROIPooling, {1}},
     {ops::kNameSGD, {2}},
     {ops::kNameSpaceToBatch, {1}},
     {ops::kNameSpaceToBatchND, {1}},

View File

@@ -132,45 +132,49 @@ STATUS NodeInferShape::InferShape(const CNodePtr &cnode) {
     fbb.Clear();
     return lite::RET_ERROR;
   }
-  auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), lite::SCHEMA_CUR);
-  if (parameter_gen == nullptr) {
-    MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
-    FreeTensors(&inputs);
-    FreeTensors(&outputs);
-    fbb.Clear();
-    return lite::RET_ERROR;
-  }
-  auto parameter = parameter_gen(prim);
-  if (parameter == nullptr) {
-    MS_LOG(ERROR) << "parameter is nullptr.";
-    FreeTensors(&inputs);
-    FreeTensors(&outputs);
-    fbb.Clear();
-    return lite::RET_ERROR;
-  }
-  RectifyFormat(cnode, inputs, fmk_type_);
-  auto status = KernelInferShape(inputs, outputs, parameter);
-  if (status == lite::RET_OK) {
-    anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
-  }
-  if (status == lite::RET_OK || status == lite::RET_INFER_INVALID) {
-    auto set_status = SetCNodeAbstract(cnode, outputs, status);
-    if (set_status != lite::RET_OK) {
-      MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
-      FreeTensors(&inputs);
-      FreeTensors(&outputs);
-      free(parameter);
-      fbb.Clear();
-      return set_status;
-    }
-  } else {
-    MS_LOG(ERROR) << "infer shape failed.";
-  }
-  FreeTensors(&inputs);
-  FreeTensors(&outputs);
-  free(parameter);
+  auto ret = KernelInferShape(inputs, outputs, prim, {});
+  if (ret == lite::RET_NOT_SUPPORT) {
+    auto parameter_gen =
+      lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), lite::SCHEMA_CUR);
+    if (parameter_gen == nullptr) {
+      MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type());
+      FreeTensors(&inputs);
+      FreeTensors(&outputs);
+      fbb.Clear();
+      return lite::RET_ERROR;
+    }
+    auto parameter = parameter_gen(prim);
+    if (parameter == nullptr) {
+      MS_LOG(ERROR) << "parameter is nullptr.";
+      FreeTensors(&inputs);
+      FreeTensors(&outputs);
+      fbb.Clear();
+      return lite::RET_ERROR;
+    }
+    RectifyFormat(cnode, inputs, fmk_type_);
+    ret = KernelInferShape(inputs, outputs, parameter);
+    if (ret == lite::RET_OK) {
+      anf_prim->AddAttr(kInferDone, MakeValue<bool>(true));
+    }
+    if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) {
+      auto set_status = SetCNodeAbstract(cnode, outputs, ret);
+      if (set_status != lite::RET_OK) {
+        MS_LOG(ERROR) << "set CNode abstract failed: " << cnode->fullname_with_scope();
+        FreeTensors(&inputs);
+        FreeTensors(&outputs);
+        free(parameter);
+        fbb.Clear();
+        return set_status;
+      }
+    } else {
+      MS_LOG(ERROR) << "infer shape failed.";
+    }
+    FreeTensors(&inputs);
+    FreeTensors(&outputs);
+    free(parameter);
+  }
   fbb.Clear();
-  return status;
+  return ret;
 }
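The restructuring above makes shape inference try the kernel-interface registry first and fall back to the built-in populate path only when no registered provider claims the op, which is how custom NNIE ops get their output shapes. A control-flow sketch with the error handling elided:

// Registry path first; PopulateParameter path only on RET_NOT_SUPPORT.
int ret = KernelInferShape(inputs, outputs, prim, {});  // provider/custom ops
if (ret == lite::RET_NOT_SUPPORT) {
  OpParameter *parameter = parameter_gen(prim);         // built-in ops
  ret = KernelInferShape(inputs, outputs, parameter);
  free(parameter);
}
fbb.Clear();
return ret;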
std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t index) { std::vector<int> NodeInferShape::GetInputShape(const CNodePtr &cnode, size_t index) {

Binary file not shown.