add fp16 cropping

commit 05c873f06a
parent 9394686885
Author: gongdaguo
Date: 2021-08-23 15:56:48 +08:00
13 changed files with 50 additions and 27 deletions

File 1 of 13 (path not shown)

@@ -116,6 +116,6 @@ if(ENABLE_CPU)
 endif()
 ########################### arm fp16 build optimize library ########################
-if(ENABLE_FP16)
+if(PLATFORM_ARM)
     add_subdirectory(${NNACL_DIR}/optimize)
 endif()

File 2 of 13 (path not shown)

@@ -35,15 +35,14 @@ if(NOT PLATFORM_ARM32)
     list(APPEND SDOT_FILES ${SDOT_SRC})
     add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES})
     add_dependencies(nnacl_optimize_mid fbs_src)
-endif()
-if(ENABLE_FP16)
-    add_library(nnacl_fp16_mid OBJECT ${FP16_FILES})
-    if(PLATFORM_ARM32)
-        target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp)
-    else()
-        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
-    endif()
-    add_dependencies(nnacl_fp16_mid fbs_src)
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
+endif()
+if(MSLITE_ENABLE_FP16)
+    add_library(nnacl_fp16_mid OBJECT ${FP16_FILES})
+    add_dependencies(nnacl_fp16_mid fbs_src)
+    if(PLATFORM_ARM32)
+        target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp)
+    endif()
 endif()

File 3 of 13 (path not shown)

@@ -18,12 +18,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_L
     message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
 endif()
-if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
-    set(ENABLE_FP16 "off")
-    message(STATUS "If you want to build fp16 in arm82_a32, \
-    your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!")
-endif()
 #Options that can be configured through environment variables or manually
 set(MSLITE_GPU_BACKEND "" CACHE STRING "enable gpu backend, \
     only arm64 support opencl, only x86_64 support tensorrt, opencl/cuda/tensorrt/off")
@@ -46,6 +40,7 @@ option(MSLITE_CUSTOM_KERNEL_REGISTRY "enable extend kernel registry" on)
 option(MSLITE_ENABLE_MINDRT "enable mindrt use" on)
 option(MSLITE_DELEGATE_USE "enable delegate use" on)
 option(MSLITE_ENABLE_V0 "support v0 schema" on)
+option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
 #Option that can be configured through manually
 option(ENABLE_VERBOSE "" off)
@@ -120,6 +115,9 @@ endif()
 if(DEFINED ENV{MSLITE_ENABLE_V0})
     set(MSLITE_ENABLE_V0 $ENV{MSLITE_ENABLE_V0})
 endif()
+if(DEFINED ENV{MSLITE_ENABLE_FP16})
+    set(MSLITE_ENABLE_FP16 $ENV{MSLITE_ENABLE_FP16})
+endif()
 if(PLATFORM_ARM64)
     if(MSLITE_GPU_BACKEND STREQUAL "")
@@ -185,7 +183,7 @@ if(MSVC)
     set(MSLITE_ENABLE_CONVERTER off)
 endif()
-if(MSLITE_ENABLE_CONVERTER AND (
+if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TESTCASES) AND (
         NOT MSLITE_ENABLE_MINDRT
         OR NOT MSLITE_STRING_KERNEL
         OR NOT MSLITE_CONTROLFLOW_TENSORLIST
@@ -193,7 +191,13 @@ if(MSLITE_ENABLE_CONVERTER AND (
         OR NOT MSLITE_CUSTOM_KERNEL_REGISTRY))
     message(FATAL_ERROR "If one of 'MSLITE_ENABLE_MINDRT MSLITE_STRING_KERNEL "
         "MSLITE_CONTROLFLOW_TENSORLIST MSLITE_WEIGHT_DECODE MSLITE_CUSTOM_KERNEL_REGISTRY'"
-        "is configured as off, MSLITE_ENABLE_CONVERTER must also be configured as off")
+        "is configured as off, MSLITE_ENABLE_CONVERTER and MSLITE_ENABLE_TESTCASES must also be configured as off")
+endif()
+if(MSLITE_ENABLE_FP16 AND PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
+        AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
+    message(FATAL_ERROR "If you want to build fp16 in arm82_a32, \
+    your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!")
 endif()
 message(STATUS "************MindSpore Lite Build Option:************")
@@ -216,6 +220,7 @@ message(STATUS "\tENABLE_MINDRT = \t${ENABLE_MINDRT}")
 message(STATUS "\tENABLE_V0 = \t${ENABLE_V0}")
 message(STATUS "\tBUILD_MINDDATA = \t${BUILD_MINDDATA}")
 message(STATUS "\tMSLITE_DELEGATE_USE = \t${MSLITE_DELEGATE_USE}")
+message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
 if(MSLITE_ENABLE_HIGH_PERFORMANCE)
     add_compile_definitions(ENABLE_HIGH_PERFORMANCE)
@@ -320,7 +325,7 @@ endif()
 if(ENABLE_NEON)
     add_compile_definitions(ENABLE_NEON)
 endif()
-if(ENABLE_FP16)
+if(MSLITE_ENABLE_FP16)
     add_compile_definitions(ENABLE_FP16)
     if(PLATFORM_ARM32)
         add_compile_definitions(ENABLE_ARM82_A32)
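A note on the hunks above: the user-facing switch is the new MSLITE_ENABLE_FP16 option (set on the cmake command line or through the environment variable of the same name), while the C++ sources touched later in this commit still test the ENABLE_FP16 macro, because this file maps the option onto add_compile_definitions(ENABLE_FP16), plus ENABLE_ARM82_A32 on arm32. A minimal sketch of how such a definition is consumed on the C++ side follows; SupportsFp16() and the main() wrapper are illustrative only and not part of this commit.

// Sketch only: consuming the compile definition that the CMake hunk above adds.
// SupportsFp16() is a hypothetical helper, not a MindSpore Lite symbol.
#include <cstdio>

static bool SupportsFp16() {
#ifdef ENABLE_FP16   // defined when the build is configured with MSLITE_ENABLE_FP16=on
  return true;
#else
  return false;      // fp16 kernels are cropped out of this build
#endif
}

int main() {
  std::printf("fp16 kernels compiled in: %s\n", SupportsFp16() ? "yes" : "no");
  return 0;
}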

File 4 of 13 (path not shown)

@@ -150,7 +150,7 @@ build_lite() {
             CMAKE_ANDROID_ABI="armeabi-v7a"
             CMAKE_ANDROID_TOOLCHAIN_NAME="clang"
             CMAKE_ANDROID_STL=${MSLITE_ANDROID_STL}
-            ENABLE_FP16="on"
+            MSLITE_ENABLE_FP16="on"
         fi
     fi
@@ -158,7 +158,7 @@ build_lite() {
         if [ "$(uname)" == "Darwin" ]; then
             pkg_name=mindspore-lite-${VERSION_STR}-ios-aarch64
             cmake -DCMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake -DARCHS="arm64" -DENABLE_BITCODE=0 \
-                  -DCMAKE_BUILD_TYPE="Release" -DBUILD_MINDDATA="" -DPLATFORM_ARM64="on" -DENABLE_NEON="on" -DENABLE_FP16="on" \
+                  -DCMAKE_BUILD_TYPE="Release" -DBUILD_MINDDATA="" -DPLATFORM_ARM64="on" -DENABLE_NEON="on" -DMSLITE_ENABLE_FP16="on" \
                   -DMSLITE_ENABLE_TRAIN="off" -DMSLITE_GPU_BACKEND="off" -DMSLITE_ENABLE_NPU="off" \
                   -DENABLE_ASAN=${ENABLE_ASAN} -DCMAKE_INSTALL_PREFIX=${BUILD_PATH}/output/tmp -G Xcode ..
         else
@@ -167,7 +167,7 @@ build_lite() {
             cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \
                   -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \
                   -DANDROID_STL=${MSLITE_ANDROID_STL} -DCMAKE_BUILD_TYPE=${LITE_BUILD_TYPE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-                  -DPLATFORM_ARM64="on" -DENABLE_NEON="on" -DENABLE_FP16="on" -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
+                  -DPLATFORM_ARM64="on" -DENABLE_NEON="on" -DMSLITE_ENABLE_FP16="on" -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
                   -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
                   -DENABLE_ASAN=${ENABLE_ASAN} -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${BASEPATH}/mindspore/lite"
         fi
@@ -184,7 +184,7 @@ build_lite() {
             cmake -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} -DTOOLCHAIN_NAME=${CMAKE_TOOLCHAIN_NAME} -DANDROID_NATIVE_API_LEVEL=${ANDROID_NATIVE_API_LEVEL} \
                   -DANDROID_NDK=${CMAKE_ANDROID_NDK} -DANDROID_ABI=${CMAKE_ANDROID_ABI} -DANDROID_TOOLCHAIN_NAME=${CMAKE_ANDROID_TOOLCHAIN_NAME} \
                   -DANDROID_STL=${CMAKE_ANDROID_STL} -DCMAKE_BUILD_TYPE=${LITE_BUILD_TYPE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
-                  -DPLATFORM_ARM32="on" -DENABLE_NEON="on" -DENABLE_FP16=${ENABLE_FP16} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
+                  -DPLATFORM_ARM32="on" -DENABLE_NEON="on" -DMSLITE_ENABLE_FP16=${MSLITE_ENABLE_FP16} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
                   -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
                   -DENABLE_ASAN=${ENABLE_ASAN} -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${BASEPATH}/mindspore/lite"
         fi

File 5 of 13 (path not shown)

@@ -358,7 +358,7 @@ if(PLATFORM_ARM)
     target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid)
     target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid)
 endif()
-if(ENABLE_FP16)
+if(MSLITE_ENABLE_FP16)
     target_link_libraries(mindspore-lite cpu_fp16_kernel_mid nnacl_fp16_mid)
     target_link_libraries(mindspore-lite_static cpu_fp16_kernel_mid nnacl_fp16_mid)
 endif()

File 6 of 13 (path not shown)

@@ -37,6 +37,10 @@ const char *const unsupport_delegate_log =
   "The mindspore-lite library does not support delegate. Set environment variable "
   "MSLITE_DELEGATE_USE to on to "
   "recompile it.";
+const char *const unsupport_fp16_log =
+  "The mindspore-lite library does not support fp16. Set environment variable "
+  "MSLITE_ENABLE_FP16 to on to "
+  "recompile it.";
 }  // namespace mindspore
 #ifdef USE_GLOG
 #include "utils/log_adapter.h"

File 7 of 13 (path not shown)

@@ -508,6 +508,11 @@ int LiteSession::CompileGraph(Model *model) {
   }
   InitGraphInputTensors(model);
   InitGraphOutputTensors(model);
+#ifndef ENABLE_FP16
+  if (context_->GetCpuInfo().enable_float16_) {
+    MS_LOG(WARNING) << unsupport_fp16_log;
+  }
+#endif
   // scheduler kernels
   Scheduler scheduler(context_, ms_context_, model, &tensors_, inputs_, outputs_, is_train_session_, delegate_);
   scheduler.SetupSchedulerCb(std::move(sched_cb_));
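For context, the new block warns when a caller has asked for float16 inference on a build where fp16 support was cropped out. A minimal caller sketch follows, assuming the MindSpore Lite 1.x C++ API (lite::Context, CpuDeviceInfo::enable_float16_, LiteSession::CreateSession); those names come from that public API rather than from this diff, and CreateFp16Session() is a hypothetical helper.

// Sketch, assuming the MindSpore Lite 1.x C++ API. On a build compiled without
// MSLITE_ENABLE_FP16, CompileGraph() now logs unsupport_fp16_log for this request.
#include "include/context.h"
#include "include/errorcode.h"
#include "include/lite_session.h"
#include "include/model.h"

mindspore::session::LiteSession *CreateFp16Session(mindspore::lite::Model *model) {
  mindspore::lite::Context context;
  // The default device list starts with a CPU entry; ask it for float16 kernels.
  context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = true;
  auto *session = mindspore::session::LiteSession::CreateSession(&context);
  if (session == nullptr) {
    return nullptr;
  }
  if (session->CompileGraph(model) != mindspore::lite::RET_OK) {
    delete session;
    return nullptr;
  }
  return session;
}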

File 8 of 13 (path not shown)

@@ -34,7 +34,7 @@ endif()
 add_library(cpu_kernel_mid OBJECT ${KERNEL_SRC})
 add_dependencies(cpu_kernel_mid fbs_src)
 if(PLATFORM_ARM)
-    if(ENABLE_FP16)
+    if(MSLITE_ENABLE_FP16)
         file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
         if(SUPPORT_TRAIN)
             file(GLOB FP16_KERNEL_TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16_grad/*.cc)

File 9 of 13 (path not shown)

@@ -81,7 +81,7 @@ int ArgMinMaxCPUKernel::Run() {
   if (input->data_type() == kNumberTypeFloat32) {
     ArgMinMaxFp32(reinterpret_cast<float *>(input_data), reinterpret_cast<void *>(output_data),
                   reinterpret_cast<float *>(output_value), shape.data(), arg_param_);
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_FP16
   } else if (input->data_type() == kNumberTypeFloat16) {
     ArgMinMaxFp16(reinterpret_cast<float16_t *>(input_data), reinterpret_cast<void *>(output_data),
                   reinterpret_cast<float16_t *>(output_value), shape.data(), arg_param_);

File 10 of 13 (path not shown)

@@ -19,7 +19,7 @@
 #include <vector>
 #include "include/errorcode.h"
 #include "nnacl/fp32/arg_min_max_fp32.h"
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_FP16
 #include "nnacl/fp16/arg_min_max_fp16.h"
 #endif
 #include "nnacl/common_func.h"

File 11 of 13 (path not shown)

@@ -892,7 +892,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
   }
 #endif
+#if (defined GPU_OPENCL) || (defined ENABLE_FP16)
   int kernel_thread_count = op_parameter->thread_num_;
+#endif
   op_parameter->is_train_session_ = is_train_session_;
   kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, static_cast<schema::PrimitiveType>(op_parameter->type_)};
@@ -920,6 +922,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
     }
   }
 #endif
+#ifdef ENABLE_FP16
   if ((prefer_data_type == kNumberTypeFloat16 || prefer_data_type == kTypeUnknown) &&
       ((is_train_session_ == false) || (sched_cb_ && sched_cb_->SchedFp16Kernel(node)))) {
     status = FindCpuKernel(in_tensors, out_tensors, op_parameter, desc, kNumberTypeFloat16, &kernel);
@@ -941,6 +944,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
       }
     }
   }
+#endif
   if (data_type == kNumberTypeFloat16) {
     MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
     desc.data_type = kNumberTypeFloat32;
@@ -1090,6 +1094,7 @@ kernel::LiteKernel *Scheduler::SchedulePartialToKernel(const lite::Model::Node *
   return subgraph_kernel;
 }
+#ifdef ENABLE_FP16
 int Scheduler::SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type) {
   if (!context_->IsCpuFloat16Enabled()) {
     *prefer_data_type = kNumberTypeFloat32;
@@ -1131,6 +1136,7 @@ int Scheduler::SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_
   *prefer_data_type = kNumberTypeFloat16;
   return RET_OK;
 }
+#endif
 std::vector<kernel::LiteKernel *> Scheduler::ScheduleMainSubGraphToKernels() {
   std::vector<kernel::LiteKernel *> kernels;
@@ -1146,10 +1152,12 @@ std::vector<kernel::LiteKernel *> Scheduler::ScheduleMainSubGraphToKernels() {
 kernel::LiteKernel *Scheduler::SchedulePartialToSubGraphKernel(const int &subgraph_index) {
   TypeId prefer_data_type = kTypeUnknown;
+#ifdef ENABLE_FP16
   if (SubGraphPreferDataType(subgraph_index, &prefer_data_type) != RET_OK) {
     MS_LOG(ERROR) << "SubGraphPreferDataType failed, subgraph index: " << subgraph_index;
     return nullptr;
   }
+#endif
   std::vector<kernel::LiteKernel *> kernels;
   std::vector<lite::Tensor *> in_tensors;
   std::vector<lite::Tensor *> out_tensors;
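The guards above crop the whole fp16 scheduling path out of non-fp16 builds: SubGraphPreferDataType() is not compiled, prefer_data_type stays kTypeUnknown, and kernel selection goes straight to fp32. A simplified, self-contained sketch of that selection pattern follows; it is not the actual Scheduler code (which lives in FindBackendKernel/FindCpuKernel), and KernelDataType and PickCpuKernelDataType are illustrative names.

// Simplified sketch of the fp16-or-fallback selection that the #ifdef ENABLE_FP16
// guards above remove from non-fp16 builds. Types and function are illustrative.
enum class KernelDataType { kFloat32, kFloat16, kUnknown };

KernelDataType PickCpuKernelDataType(KernelDataType prefer, bool fp16_requested) {
#ifdef ENABLE_FP16
  // fp16 build: honour an explicit fp16 preference, or try fp16 when requested;
  // if no fp16 kernel is registered, the scheduler later falls back to fp32.
  if (prefer == KernelDataType::kFloat16 ||
      (prefer == KernelDataType::kUnknown && fp16_requested)) {
    return KernelDataType::kFloat16;
  }
#else
  (void)prefer;
  (void)fp16_requested;  // fp16 path cropped out: always schedule fp32 kernels
#endif
  return KernelDataType::kFloat32;
}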

File 12 of 13 (path not shown)

@@ -108,7 +108,9 @@ class Scheduler {
   int RestoreSubGraphInput(const lite::Model::Node *partial_node);
   bool IsControlFlowPattern(const lite::Model::Node &partial_node);
+#ifdef ENABLE_FP16
   int SubGraphPreferDataType(const int &subgraph_index, TypeId *prefer_data_type);
+#endif
 #ifndef CONTROLFLOW_TENSORLIST_CLIP
   int InferSwitchShape(const Model::Node *node);
   Model::Node *NodeInputIsSwitch(const Model::Node *node);

File 13 of 13 (path not shown)

@@ -50,7 +50,7 @@ if(SUPPORT_TRAIN)
     list(APPEND KERNEL_OP_SRC ${KERNEL_OP_TRAIN_SRC})
 endif()
-if(ENABLE_FP16)
+if(MSLITE_ENABLE_FP16)
     list(APPEND KERNEL_OP_SRC ${FP16_KERNEL_OP_SRC})
 endif()
@@ -378,7 +378,7 @@ if(MSLITE_GPU_BACKEND STREQUAL opencl)
     )
 endif()
-if(ENABLE_FP16)
+if(MSLITE_ENABLE_FP16)
     file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC
         ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16/*.cc
     )
@@ -393,7 +393,7 @@ if(ENABLE_FP16)
     )
 endif()
-if(ENABLE_FP16 AND SUPPORT_TRAIN)
+if(MSLITE_ENABLE_FP16 AND SUPPORT_TRAIN)
     file(GLOB_RECURSE TEST_CASE_KERNEL_FP16_SRC_GRAD
         ${TEST_DIR}/ut/src/runtime/kernel/arm/fp16_grad/*.cc)
     list(APPEND TEST_SRC ${TEST_CASE_KERNEL_FP16_SRC_GRAD})
@@ -415,7 +415,7 @@ target_link_libraries(lite-test
     mindspore::gtest
 )
-if(PLATFORM_ARM AND ENABLE_FP16)
+if(PLATFORM_ARM AND MSLITE_ENABLE_FP16)
     target_link_libraries(lite-test nnacl_fp16_mid)
     if(PLATFORM_ARM64)
         target_link_libraries(lite-test nnacl_optimize_mid)