add fp16 cropping

parent 9394686885
commit 05c873f06a
@@ -116,6 +116,6 @@ if(ENABLE_CPU)
 endif()
 
 ########################### arm fp16 build optimize library ########################
-if(ENABLE_FP16)
+if(PLATFORM_ARM)
     add_subdirectory(${NNACL_DIR}/optimize)
 endif()

@@ -35,15 +35,14 @@ if(NOT PLATFORM_ARM32)
     list(APPEND SDOT_FILES ${SDOT_SRC})
     add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES})
     add_dependencies(nnacl_optimize_mid fbs_src)
-endif()
 
-if(ENABLE_FP16)
-    add_library(nnacl_fp16_mid OBJECT ${FP16_FILES})
-    if(PLATFORM_ARM32)
-        target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp)
-    else()
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
 endif()
-    add_dependencies(nnacl_fp16_mid fbs_src)
+if(MSLITE_ENABLE_FP16)
+    add_library(nnacl_fp16_mid OBJECT ${FP16_FILES})
+    add_dependencies(nnacl_fp16_mid fbs_src)
+    if(PLATFORM_ARM32)
+        target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp)
+    endif()
 endif()

@@ -18,12 +18,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_L
     message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
 endif()
 
-if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
-    set(ENABLE_FP16 "off")
-    message(STATUS "If you want to build fp16 in arm82_a32, \
-    your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!")
-endif()
-
 #Options that can be configured through environment variables or manually
 set(MSLITE_GPU_BACKEND "" CACHE STRING "enable gpu backend, \
     only arm64 support opencl, only x86_64 support tensorrt, opencl/cuda/tensorrt/off")

@@ -46,6 +40,7 @@ option(MSLITE_CUSTOM_KERNEL_REGISTRY "enable extend kernel registry" on)
 option(MSLITE_ENABLE_MINDRT "enable mindrt use" on)
 option(MSLITE_DELEGATE_USE "enable delegate use" on)
 option(MSLITE_ENABLE_V0 "support v0 schema" on)
+option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
 
 #Option that can be configured through manually
 option(ENABLE_VERBOSE "" off)

@@ -120,6 +115,9 @@ endif()
 if(DEFINED ENV{MSLITE_ENABLE_V0})
     set(MSLITE_ENABLE_V0 $ENV{MSLITE_ENABLE_V0})
 endif()
+if(DEFINED ENV{MSLITE_ENABLE_FP16})
+    set(MSLITE_ENABLE_FP16 $ENV{MSLITE_ENABLE_FP16})
+endif()
 
 if(PLATFORM_ARM64)
     if(MSLITE_GPU_BACKEND STREQUAL "")

@@ -185,7 +183,7 @@ if(MSVC)
     set(MSLITE_ENABLE_CONVERTER off)
 endif()
 
-if(MSLITE_ENABLE_CONVERTER AND (
+if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TESTCASES) AND (
         NOT MSLITE_ENABLE_MINDRT
         OR NOT MSLITE_STRING_KERNEL
         OR NOT MSLITE_CONTROLFLOW_TENSORLIST

@@ -193,7 +191,13 @@ if(MSLITE_ENABLE_CONVERTER AND (
         OR NOT MSLITE_CUSTOM_KERNEL_REGISTRY))
     message(FATAL_ERROR "If one of 'MSLITE_ENABLE_MINDRT MSLITE_STRING_KERNEL "
             "MSLITE_CONTROLFLOW_TENSORLIST MSLITE_WEIGHT_DECODE MSLITE_CUSTOM_KERNEL_REGISTRY'"
-            "is configured as off, MSLITE_ENABLE_CONVERTER must also be configured as off")
+            "is configured as off, MSLITE_ENABLE_CONVERTER and MSLITE_ENABLE_TESTCASES must also be configured as off")
+endif()
+
+if(MSLITE_ENABLE_FP16 AND PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
+        AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
+    message(FATAL_ERROR "If you want to build fp16 in arm82_a32, \
+    your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!")
 endif()
 
 message(STATUS "************MindSpore Lite Build Option:************")

@@ -216,6 +220,7 @@ message(STATUS "\tENABLE_MINDRT = \t${ENABLE_MINDRT}")
 message(STATUS "\tENABLE_V0 = \t${ENABLE_V0}")
 message(STATUS "\tBUILD_MINDDATA = \t${BUILD_MINDDATA}")
 message(STATUS "\tMSLITE_DELEGATE_USE = \t${MSLITE_DELEGATE_USE}")
+message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
 
 if(MSLITE_ENABLE_HIGH_PERFORMANCE)
     add_compile_definitions(ENABLE_HIGH_PERFORMANCE)

@@ -320,7 +325,7 @@ endif()
 if(ENABLE_NEON)
     add_compile_definitions(ENABLE_NEON)
 endif()
-if(ENABLE_FP16)
+if(MSLITE_ENABLE_FP16)
     add_compile_definitions(ENABLE_FP16)
     if(PLATFORM_ARM32)
         add_compile_definitions(ENABLE_ARM82_A32)

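Taken together, the CMake changes above funnel one user-facing switch, MSLITE_ENABLE_FP16 (an option() default that the environment variable of the same name can override), into a single compile definition, ENABLE_FP16, plus ENABLE_ARM82_A32 on 32-bit ARM. A minimal, self-contained C++ sketch of how such a definition is consumed at compile time; the enum and function names here are illustrative, not the MindSpore Lite API:

#include <iostream>

// Stand-in type ids; the real project uses mindspore::TypeId.
enum TypeId { kNumberTypeFloat32, kNumberTypeFloat16 };

// When CMake is configured with MSLITE_ENABLE_FP16=on it calls
// add_compile_definitions(ENABLE_FP16), so the fp16 branch below is compiled
// in; otherwise every float16 request quietly degrades to fp32.
TypeId PickKernelDataType(TypeId requested) {
#ifdef ENABLE_FP16
  if (requested == kNumberTypeFloat16) {
    return kNumberTypeFloat16;
  }
#endif
  return kNumberTypeFloat32;
}

int main() {
  std::cout << (PickKernelDataType(kNumberTypeFloat16) == kNumberTypeFloat16
                    ? "fp16 kernels available\n"
                    : "fp16 clipped: falling back to fp32\n");
}

Compiling this sketch with -DENABLE_FP16 simulates an fp16-enabled build; compiling without it simulates the clipped build.
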
@@ -150,7 +150,7 @@ build_lite() {
             CMAKE_ANDROID_ABI="armeabi-v7a"
             CMAKE_ANDROID_TOOLCHAIN_NAME="clang"
             CMAKE_ANDROID_STL=${MSLITE_ANDROID_STL}
-            ENABLE_FP16="on"
+            MSLITE_ENABLE_FP16="on"
         fi
     fi
 

@@ -158,7 +158,7 @@ build_lite() {
         if [ "$(uname)" == "Darwin" ]; then
             pkg_name=mindspore-lite-${VERSION_STR}-ios-aarch64
             cmake -DCMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake -DARCHS="arm64" -DENABLE_BITCODE=0 \
-                -DCMAKE_BUILD_TYPE="Release" -DBUILD_MINDDATA="" -DPLATFORM_ARM64="on" -DENABLE_NEON="on" -DENABLE_FP16="on" \
+                -DCMAKE_BUILD_TYPE="Release" -DBUILD_MINDDATA="" -DPLATFORM_ARM64="on" -DENABLE_NEON="on" -DMSLITE_ENABLE_FP16="on" \
                 -DMSLITE_ENABLE_TRAIN="off" -DMSLITE_GPU_BACKEND="off" -DMSLITE_ENABLE_NPU="off" \
                 -DENABLE_ASAN=${ENABLE_ASAN} -DCMAKE_INSTALL_PREFIX=${BUILD_PATH}/output/tmp -G Xcode ..
         else

@@ -167,7 +167,7 @@ build_lite() {
             cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \
                 -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \
                 -DANDROID_STL=${MSLITE_ANDROID_STL} -DCMAKE_BUILD_TYPE=${LITE_BUILD_TYPE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-                -DPLATFORM_ARM64="on" -DENABLE_NEON="on" -DENABLE_FP16="on" -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
+                -DPLATFORM_ARM64="on" -DENABLE_NEON="on" -DMSLITE_ENABLE_FP16="on" -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
                 -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
                 -DENABLE_ASAN=${ENABLE_ASAN} -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${BASEPATH}/mindspore/lite"
         fi

@@ -184,7 +184,7 @@ build_lite() {
         cmake -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} -DTOOLCHAIN_NAME=${CMAKE_TOOLCHAIN_NAME} -DANDROID_NATIVE_API_LEVEL=${ANDROID_NATIVE_API_LEVEL} \
             -DANDROID_NDK=${CMAKE_ANDROID_NDK} -DANDROID_ABI=${CMAKE_ANDROID_ABI} -DANDROID_TOOLCHAIN_NAME=${CMAKE_ANDROID_TOOLCHAIN_NAME} \
             -DANDROID_STL=${CMAKE_ANDROID_STL} -DCMAKE_BUILD_TYPE=${LITE_BUILD_TYPE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-            -DPLATFORM_ARM32="on" -DENABLE_NEON="on" -DENABLE_FP16=${ENABLE_FP16} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
+            -DPLATFORM_ARM32="on" -DENABLE_NEON="on" -DMSLITE_ENABLE_FP16=${MSLITE_ENABLE_FP16} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
             -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
             -DENABLE_ASAN=${ENABLE_ASAN} -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${BASEPATH}/mindspore/lite"
     fi

@@ -358,7 +358,7 @@ if(PLATFORM_ARM)
     target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid)
     target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid)
 endif()
-if(ENABLE_FP16)
+if(MSLITE_ENABLE_FP16)
     target_link_libraries(mindspore-lite cpu_fp16_kernel_mid nnacl_fp16_mid)
     target_link_libraries(mindspore-lite_static cpu_fp16_kernel_mid nnacl_fp16_mid)
 endif()

@@ -37,6 +37,10 @@ const char *const unsupport_delegate_log =
   "The mindspore-lite library does not support delegate. Set environment variable "
   "MSLITE_DELEGATE_USE to on to "
   "recompile it.";
+const char *const unsupport_fp16_log =
+  "The mindspore-lite library does not support fp16. Set environment variable "
+  "MSLITE_ENABLE_FP16 to on to "
+  "recompile it.";
 }  // namespace mindspore
 #ifdef USE_GLOG
 #include "utils/log_adapter.h"

@@ -508,6 +508,11 @@ int LiteSession::CompileGraph(Model *model) {
   }
   InitGraphInputTensors(model);
   InitGraphOutputTensors(model);
+#ifndef ENABLE_FP16
+  if (context_->GetCpuInfo().enable_float16_) {
+    MS_LOG(WARNING) << unsupport_fp16_log;
+  }
+#endif
   // scheduler kernels
   Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, delegate_);
   scheduler.SetupSchedulerCb(std::move(sched_cb_));

@@ -34,7 +34,7 @@ endif()
 add_library(cpu_kernel_mid OBJECT ${KERNEL_SRC})
 add_dependencies(cpu_kernel_mid fbs_src)
 if(PLATFORM_ARM)
-    if(ENABLE_FP16)
+    if(MSLITE_ENABLE_FP16)
         file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
         if(SUPPORT_TRAIN)
             file(GLOB FP16_KERNEL_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)

@@ -81,7 +81,7 @@ int ArgMinMaxCPUKernel::Run() {
   if (input->data_type() == kNumberTypeFloat32) {
     ArgMinMaxFp32(reinterpret_cast<float *>(input_data), reinterpret_cast<void *>(output_data),
                   reinterpret_cast<float *>(output_value), shape.data(), arg_param_);
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_FP16
   } else if (input->data_type() == kNumberTypeFloat16) {
     ArgMinMaxFp16(reinterpret_cast<float16_t *>(input_data), reinterpret_cast<void *>(output_data),
                   reinterpret_cast<float16_t *>(output_value), shape.data(), arg_param_);

@@ -19,7 +19,7 @@
 #include <vector>
 #include "include/errorcode.h"
 #include "nnacl/fp32/arg_min_max_fp32.h"
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_FP16
 #include "nnacl/fp16/arg_min_max_fp16.h"
 #endif
 #include "nnacl/common_func.h"

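The guard change in the two ArgMinMax hunks above is the substance of the "cropping": the fp16 branch used to be tied to ENABLE_ARM64, so it was compiled into every arm64 build, whereas gating it on ENABLE_FP16 lets fp16 code be clipped out independently of the target architecture. A self-contained sketch of the same dispatch shape, with illustrative types rather than the real kernel classes:

#include <cstdio>

enum class DType { kFloat32, kFloat16 };

void Run(DType dt) {
  if (dt == DType::kFloat32) {
    std::puts("ArgMinMaxFp32");  // always compiled
#ifdef ENABLE_FP16
  } else if (dt == DType::kFloat16) {
    std::puts("ArgMinMaxFp16");  // present only in fp16-enabled builds
#endif
  } else {
    std::puts("unsupported data type in this build");
  }
}

int main() { Run(DType::kFloat16); }
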
@@ -892,7 +892,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
   }
 #endif
 
+#if (defined GPU_OPENCL) || (defined ENABLE_FP16)
   int kernel_thread_count = op_parameter->thread_num_;
+#endif
   op_parameter->is_train_session_ = is_train_session_;
   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
 

@@ -920,6 +922,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
     }
   }
 #endif
+#ifdef ENABLE_FP16
   if ((prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) &&
       ((is_train_session_ == false) || (sched_cb_ && sched_cb_->SchedFp16Kernel(node)))) {
     status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);

@@ -941,6 +944,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
       }
     }
   }
+#endif
   if (data_type == kNumberTypeFloat16) {
     MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
     desc.data_type = kNumberTypeFloat32;

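With the new #ifdef ENABLE_FP16 guards in place, the scheduler's lookup order is: try an fp16 CPU kernel when fp16 is compiled in and preferred, otherwise fall back to fp32. A distilled, self-contained sketch of that order; FindCpuKernel here is a stub standing in for the real Scheduler method:

#include <iostream>
#include <optional>
#include <string>

enum class DataType { kFloat32, kFloat16, kUnknown };

// Stub kernel registry: pretend only fp32 kernels were linked into this build.
std::optional<std::string> FindCpuKernel(DataType dt) {
  if (dt == DataType::kFloat32) return "fp32_kernel";
  return std::nullopt;
}

std::string SelectKernel(DataType prefer) {
  std::optional<std::string> kernel;
#ifdef ENABLE_FP16
  if (prefer == DataType::kFloat16 || prefer == DataType::kUnknown) {
    kernel = FindCpuKernel(DataType::kFloat16);  // fp16 first when available
  }
#endif
  if (!kernel) {
    // Mirrors the log above: "Get fp16 op failed, back to fp32 op."
    kernel = FindCpuKernel(DataType::kFloat32);
  }
  return *kernel;
}

int main() { std::cout << SelectKernel(DataType::kFloat16) << "\n"; }
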
@@ -1090,6 +1094,7 @@ kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *
   return subgraph_kernel;
 }
 
+#ifdef ENABLE_FP16
 int Scheduler::SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type) {
   if (!context_->IsCpuFloat16Enabled()) {
     *prefer_data_type = kNumberTypeFloat32;

@@ -1131,6 +1136,7 @@ int Scheduler::SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_
   *prefer_data_type = kNumberTypeFloat16;
   return RET_OK;
 }
+#endif
 
 std::vector<kernel::LiteKernel *> Scheduler::ScheduleMainSubGraphToKernels() {
   std::vector<kernel::LiteKernel *> kernels;

@@ -1146,10 +1152,12 @@ std::vector<kernel::LiteKernel *> Scheduler::ScheduleMainSubGraphToKernels() {
 
 kernel::LiteKernel *Scheduler::SchedulePartialToSubGraphKernel(const int &subgraph_index) {
   TypeId prefer_data_type = kTypeUnknown;
+#ifdef ENABLE_FP16
   if (SubGraphPreferDataType(subgraph_index, &prefer_data_type) != RET_OK) {
     MS_LOG(ERROR) << "SubGraphPreferDataType failed, subgraph index: " << subgraph_index;
     return nullptr;
   }
+#endif
   std::vector<kernel::LiteKernel *> kernels;
   std::vector<lite::Tensor *> in_tensors;
   std::vector<lite::Tensor *> out_tensors;

@@ -108,7 +108,9 @@ class Scheduler {
   int RestoreSubGraphInput(const lite::Model::Node *partial_node);
 
   bool IsControlFlowPattern(const lite::Model::Node &partial_node);
+#ifdef ENABLE_FP16
   int SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type);
+#endif
 #ifndef CONTROLFLOW_TENSORLIST_CLIP
   int InferSwitchShape(const Model::Node *node);
   Model::Node *NodeInputIsSwitch(const Model::Node *node);

@@ -50,7 +50,7 @@ if(SUPPORT_TRAIN)
     list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
 endif()
 
-if(ENABLE_FP16)
+if(MSLITE_ENABLE_FP16)
     list(APPEND KERNEL_OP_SRC ${FP16_KERNEL_OP_SRC})
 endif()
 

@@ -378,7 +378,7 @@ if(MSLITE_GPU_BACKEND STREQUAL opencl)
     )
 endif()
 
-if(ENABLE_FP16)
+if(MSLITE_ENABLE_FP16)
     file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC
             ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/*.cc
     )

@@ -393,7 +393,7 @@ if(ENABLE_FP16)
     )
 endif()
 
-if(ENABLE_FP16 AND SUPPORT_TRAIN)
+if(MSLITE_ENABLE_FP16 AND SUPPORT_TRAIN)
     file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
             ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
     list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})

@@ -415,7 +415,7 @@ target_link_libraries(lite-test
     mindspore::gtest
 )
 
-if(PLATFORM_ARM AND ENABLE_FP16)
+if(PLATFORM_ARM AND MSLITE_ENABLE_FP16)
    target_link_libraries(lite-test nnacl_fp16_mid)
    if(PLATFORM_ARM64)
        target_link_libraries(lite-test nnacl_optimize_mid)