forked from mindspore-Ecosystem/mindspore
sync lite to r0.7
This commit is contained in:
parent
b5393e6628
commit
0ce8708dee
45
build.bat
45
build.bat
|
@ -16,20 +16,20 @@
|
|||
@title mindspore_build
|
||||
|
||||
SET BASEPATH=%CD%
|
||||
IF NOT EXIST %BASEPATH%/build (
|
||||
IF NOT EXIST "%BASEPATH%/build" (
|
||||
md "build"
|
||||
)
|
||||
|
||||
cd %BASEPATH%/build
|
||||
cd "%BASEPATH%/build"
|
||||
set BUILD_PATH=%CD%
|
||||
|
||||
IF NOT EXIST %BUILD_PATH%/mindspore (
|
||||
IF NOT EXIST "%BUILD_PATH%/mindspore" (
|
||||
md "mindspore"
|
||||
)
|
||||
|
||||
cd %CD%/mindspore
|
||||
cd "%CD%/mindspore"
|
||||
|
||||
IF "%2%" == "lite" (
|
||||
IF "%1%" == "lite" (
|
||||
call :gene_gtest
|
||||
call :run_cmake
|
||||
IF errorlevel 1 (
|
||||
|
@ -47,14 +47,17 @@ IF "%2%" == "lite" (
|
|||
)
|
||||
|
||||
cd %BUILD_PATH%/mindspore
|
||||
IF "%1%" == "" (
|
||||
cmake --build . -- -j6
|
||||
IF "%2%" == "" (
|
||||
cmake --build . --target package -- -j6
|
||||
) ELSE (
|
||||
cmake --build . -- -j%1%
|
||||
cmake --build . --target package -- -j%2%
|
||||
)
|
||||
IF errorlevel 1 (
|
||||
echo "build fail."
|
||||
goto run_fail
|
||||
) ELSE (
|
||||
cd "%BASEPATH%/output"
|
||||
rd /s /q _CPack_Packages
|
||||
)
|
||||
) ELSE (
|
||||
cmake -DCMAKE_BUILD_TYPE=Release -DENABLE_CPU=ON -DENABLE_MINDDATA=ON -DUSE_GLOG=ON ^
|
||||
|
@ -75,40 +78,40 @@ IF "%2%" == "lite" (
|
|||
)
|
||||
)
|
||||
|
||||
cd %BASEPATH%
|
||||
cd "%BASEPATH%"
|
||||
|
||||
goto run_eof
|
||||
|
||||
:run_cmake
|
||||
cd %BUILD_PATH%/mindspore
|
||||
cd "%BUILD_PATH%/mindspore"
|
||||
cmake -DBUILD_DEVICE=on -DBUILD_CONVERTER=on -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=off ^
|
||||
-DCMAKE_BUILD_TYPE=Release -DSUPPORT_GPU=off -DBUILD_MINDDATA=off -DOFFLINE_COMPILE=off ^
|
||||
-G "CodeBlocks - MinGW Makefiles" %BASEPATH%/mindspore/lite
|
||||
-G "CodeBlocks - MinGW Makefiles" "%BASEPATH%/mindspore/lite"
|
||||
GOTO:EOF
|
||||
|
||||
:gene_gtest
|
||||
cd %BASEPATH%/third_party
|
||||
cd "%BASEPATH%/third_party"
|
||||
IF EXIST googletest rd /s /q googletest
|
||||
git submodule update --init --recursive googletest
|
||||
cd %BUILD_PATH%/mindspore
|
||||
cd "%BUILD_PATH%/mindspore"
|
||||
GOTO:EOF
|
||||
|
||||
:gene_protobuf
|
||||
SET PROTOC=%BASEPATH%/build/mindspore/_deps/protobuf-src/_build/protoc
|
||||
SET PROTOC="%BASEPATH%/build/mindspore/_deps/protobuf-src/_build/protoc"
|
||||
|
||||
SET PROTO_SRC_DIR=%BASEPATH%/mindspore/lite/tools/converter/parser/caffe
|
||||
SET PROTO_SRC_DIR="%BASEPATH%/mindspore/lite/tools/converter/parser/caffe"
|
||||
cd %PROTO_SRC_DIR%
|
||||
%PROTOC% *.proto --proto_path=%PROTO_SRC_DIR% --cpp_out=%PROTO_SRC_DIR%
|
||||
|
||||
SET PROTO_SRC_DIR=%BASEPATH%/mindspore/lite/tools/converter/parser/onnx
|
||||
SET PROTO_SRC_DIR="%BASEPATH%/mindspore/lite/tools/converter/parser/onnx"
|
||||
cd %PROTO_SRC_DIR%
|
||||
%PROTOC% *.proto --proto_path=%PROTO_SRC_DIR% --cpp_out=%PROTO_SRC_DIR%
|
||||
cd %BUILD_PATH%/mindspore
|
||||
GOTO:EOF
|
||||
|
||||
:gene_flatbuffer
|
||||
SET FLATC=%BASEPATH%/build/mindspore/_deps/flatbuffers-src/_build/flatc
|
||||
SET FLAT_DIR=%BASEPATH%/mindspore/lite/schema
|
||||
SET FLATC="%BASEPATH%/build/mindspore/_deps/flatbuffers-src/_build/flatc"
|
||||
SET FLAT_DIR="%BASEPATH%/mindspore/lite/schema"
|
||||
cd %FLAT_DIR%
|
||||
IF EXIST inner rd /s /q inner
|
||||
md inner
|
||||
|
@ -116,14 +119,14 @@ GOTO:EOF
|
|||
%FLATC% -c -b *.fbs
|
||||
%FLATC% -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o %FLAT_DIR%/inner *.fbs
|
||||
|
||||
SET FLAT_DIR=%BASEPATH%/mindspore/lite/tools/converter/parser/tflite
|
||||
SET FLAT_DIR="%BASEPATH%/mindspore/lite/tools/converter/parser/tflite"
|
||||
cd %FLAT_DIR%
|
||||
%FLATC% -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o %FLAT_DIR% *.fbs
|
||||
cd %BUILD_PATH%/mindspore
|
||||
cd "%BUILD_PATH%/mindspore"
|
||||
GOTO:EOF
|
||||
|
||||
:run_fail
|
||||
cd %BASEPATH%
|
||||
cd "%BASEPATH%"
|
||||
set errorlevel=1
|
||||
|
||||
:run_eof
|
||||
|
|
4
build.sh
4
build.sh
|
@ -393,7 +393,7 @@ build_mindspore()
|
|||
CMAKE_VERBOSE="--verbose"
|
||||
fi
|
||||
cmake --build . --target package ${CMAKE_VERBOSE} -j$THREAD_NUM
|
||||
echo "success to build mindspore project!"
|
||||
echo "success building mindspore project!"
|
||||
}
|
||||
|
||||
checkndk() {
|
||||
|
@ -618,10 +618,12 @@ build_lite()
|
|||
|
||||
if [[ "${COMPILE_RET}" -ne 0 ]]; then
|
||||
echo "---------------- mindspore lite: build failed ----------------"
|
||||
exit 1
|
||||
else
|
||||
mv ${BASEPATH}/output/tmp/*.tar.gz* ${BASEPATH}/output/
|
||||
rm -rf ${BASEPATH}/output/tmp/
|
||||
echo "---------------- mindspore lite: build success ----------------"
|
||||
exit 0
|
||||
fi
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,18 @@
|
|||
include(CMakePackageConfigHelpers)
|
||||
|
||||
set(LIB_DIR ${MAIN_DIR}/lib)
|
||||
set(INC_DIR ${MAIN_DIR}/include)
|
||||
set(TURBO_DIR ${MAIN_DIR}/third_party/libjpeg-turbo)
|
||||
set(OPENCV_DIR ${MAIN_DIR}/third_party/opencv)
|
||||
set(PROTOBF_DIR ${MAIN_DIR}/third_party/protobuf)
|
||||
set(FLATBF_DIR ${MAIN_DIR}/third_party/flatbuffers)
|
||||
set(LIB_DIR ${MAIN_DIR}-${COMPONENT_NAME}/lib)
|
||||
set(INC_DIR ${MAIN_DIR}-${COMPONENT_NAME}/include)
|
||||
set(TURBO_DIR ${MAIN_DIR}-${COMPONENT_NAME}/third_party/libjpeg-turbo)
|
||||
set(OPENCV_DIR ${MAIN_DIR}-${COMPONENT_NAME}/third_party/opencv)
|
||||
set(PROTOBF_DIR ${MAIN_DIR}-${COMPONENT_NAME}/third_party/protobuf)
|
||||
set(FLATBF_DIR ${MAIN_DIR}-${COMPONENT_NAME}/third_party/flatbuffers)
|
||||
|
||||
set(LIB_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/lib)
|
||||
set(INC_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/include)
|
||||
set(TURBO_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/third_party/libjpeg-turbo)
|
||||
set(OPENCV_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/third_party/opencv)
|
||||
set(PROTOBF_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/third_party/protobuf)
|
||||
set(FLATBF_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/third_party/flatbuffers)
|
||||
if (BUILD_MINDDATA)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/ DESTINATION ${INC_DIR} COMPONENT ${COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ${LIB_DIR} COMPONENT ${COMPONENT_NAME})
|
||||
|
@ -41,19 +47,40 @@ elseif (PLATFORM_ARM32)
|
|||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${INC_DIR} COMPONENT ${COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/schema/ DESTINATION ${INC_DIR}/schema COMPONENT ${COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "inner" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/third_party/flatbuffers/include DESTINATION ${FLATBF_DIR} COMPONENT ${COMPONENT_NAME})
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
get_filename_component(CXX_DIR ${CMAKE_CXX_COMPILER} PATH)
|
||||
file(GLOB LIB_LIST ${CXX_DIR}/libstdc++-6.dll ${CXX_DIR}/libwinpthread-1.dll ${CXX_DIR}/libssp-0.dll ${CXX_DIR}/libgcc_s_seh-1.dll)
|
||||
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/converter_lite.exe DESTINATION ${TOP_DIR}/build/mindspore/package COMPONENT ${COMPONENT_NAME})
|
||||
install(FILES ${LIB_LIST} DESTINATION ${TOP_DIR}/build/mindspore/package COMPONENT ${COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/libconverter_parser.a DESTINATION ${TOP_DIR}/build/mindspore/package COMPONENT ${PARSER_NAME})
|
||||
else ()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${LIB_DIR} COMPONENT ${RUN_X86_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${INC_DIR_RUN_X86} COMPONENT ${RUN_X86_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/schema/ DESTINATION ${INC_DIR_RUN_X86}/schema COMPONENT ${RUN_X86_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "inner" EXCLUDE)
|
||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${INC_DIR_RUN_X86}/ir/dtype COMPONENT ${RUN_X86_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/third_party/flatbuffers/include DESTINATION ${FLATBF_DIR_RUN_X86} COMPONENT ${RUN_X86_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${LIB_DIR_RUN_X86} COMPONENT ${RUN_X86_COMPONENT_NAME})
|
||||
|
||||
install(FILES ${TOP_DIR}/third_party/protobuf/build/lib/libprotobuf.so.19.0.0 DESTINATION ${PROTOBF_DIR}/lib RENAME libprotobuf.so.19 COMPONENT ${COMPONENT_NAME})
|
||||
endif ()
|
||||
|
||||
set(CPACK_GENERATOR TGZ)
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
set(CPACK_GENERATOR ZIP)
|
||||
else ()
|
||||
set(CPACK_GENERATOR TGZ)
|
||||
endif ()
|
||||
set(CPACK_ARCHIVE_COMPONENT_INSTALL ON)
|
||||
if (PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
set(CPACK_COMPONENTS_ALL ${COMPONENT_NAME})
|
||||
elseif (WIN32)
|
||||
set(CPACK_COMPONENTS_ALL ${COMPONENT_NAME})
|
||||
else ()
|
||||
set(CPACK_COMPONENTS_ALL ${COMPONENT_NAME} ${RUN_X86_COMPONENT_NAME})
|
||||
endif ()
|
||||
set(CPACK_PACKAGE_FILE_NAME ${MAIN_DIR})
|
||||
set(CPACK_PACKAGE_DIRECTORY ${TOP_DIR}/output/tmp)
|
||||
if (WIN32)
|
||||
set(CPACK_PACKAGE_DIRECTORY ${TOP_DIR}/output)
|
||||
else ()
|
||||
set(CPACK_PACKAGE_DIRECTORY ${TOP_DIR}/output/tmp)
|
||||
endif()
|
||||
set(CPACK_PACKAGE_CHECKSUM SHA256)
|
||||
include(CPack)
|
|
@ -5,15 +5,15 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_
|
|||
message(FATAL_ERROR "GCC vesion ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
|
||||
endif ()
|
||||
|
||||
set(MS_VERSION_MAJOY 0)
|
||||
set(MS_VERSION_MAJOR 0)
|
||||
set(MS_VERSION_MINOR 7)
|
||||
set(MS_VERSION_REVISION 0)
|
||||
|
||||
set(DIR_PREFIX mindspore-lite)
|
||||
set(MS_VERSION ${MS_VERSION_MAJOY}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
|
||||
set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
|
||||
set(MAIN_DIR ${DIR_PREFIX}-${MS_VERSION})
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
|
||||
if (SUPPORT_GPU)
|
||||
set(PROCESS_UNIT gpu)
|
||||
|
@ -25,13 +25,16 @@ if (PLATFORM_ARM64)
|
|||
set(COMPONENT_NAME runtime-arm64-${PROCESS_UNIT})
|
||||
elseif (PLATFORM_ARM32)
|
||||
set(COMPONENT_NAME runtime-arm32-${PROCESS_UNIT})
|
||||
elseif (WIN32)
|
||||
set(PARSER_NAME libconverter-parser-win-${PROCESS_UNIT})
|
||||
set(COMPONENT_NAME converter-win-${PROCESS_UNIT})
|
||||
else ()
|
||||
set(COMPONENT_NAME convert-ubuntu)
|
||||
endif()
|
||||
set(RUN_X86_COMPONENT_NAME runtime-x86-${PROCESS_UNIT})
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
|
||||
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../..)
|
||||
string(REPLACE "/mindspore/lite" "" TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
set(CORE_DIR ${TOP_DIR}/mindspore/core)
|
||||
set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc)
|
||||
include_directories(${TOP_DIR})
|
||||
|
@ -65,20 +68,20 @@ set(CMAKE_VERBOSE_MAKEFILE on)
|
|||
add_compile_definitions(USE_ANDROID_LOG)
|
||||
add_compile_definitions(NO_DLIB)
|
||||
add_compile_options(-fPIC)
|
||||
if (NOT PLATFORM_ARM64 AND NOT PLATFORM_ARM32)
|
||||
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DDebug -g")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDebug -g")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=default")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
|
||||
else ()
|
||||
## enable for binscope for release
|
||||
set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes -Wno-deprecated-declarations ${CMAKE_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes -Wno-deprecated-declarations ${CMAKE_CXX_FLAGS}")
|
||||
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DDebug -g")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDebug -g")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=default")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
|
||||
else ()
|
||||
## enable for binscope for release
|
||||
set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes -Wno-deprecated-declarations -Wno-missing-braces ${CMAKE_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes -Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}")
|
||||
if (NOT WIN32)
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,-z,relro,-z,now -Wl,-z,noexecstack ${CMAKE_SHARED_LINKER_FLAGS}")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "-Wl,-z,relro,-z,now -Wl,-z,noexecstack ${CMAKE_EXE_LINKER_FLAGS}")
|
||||
string(REPLACE " -g " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
endif ()
|
||||
endif()
|
||||
string(REPLACE " -g " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
endif ()
|
||||
|
||||
if (BUILD_DEVICE)
|
||||
|
@ -110,12 +113,11 @@ if (WIN32)
|
|||
add_compile_definitions(BUILDING_DLL)
|
||||
endif ()
|
||||
|
||||
set(ANF_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../core/ir/meta_tensor.cc
|
||||
set(CORE_SRC
|
||||
${CORE_DIR}/ir/meta_tensor.cc
|
||||
${CORE_DIR}/gvar/logging_level.cc
|
||||
${CORE_DIR}/gvar/typeid_manager.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../core/base/base.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/src/common/log_adapter.cc
|
||||
${CORE_DIR}/base/base.cc
|
||||
)
|
||||
if (BUILD_CONVERTER)
|
||||
if (PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
|
@ -163,7 +165,6 @@ if (BUILD_DEVICE)
|
|||
add_compile_definitions(ENABLE_ARM32)
|
||||
endif ()
|
||||
if (PLATFORM_ARM64)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
|
||||
add_compile_definitions(ENABLE_ARM64)
|
||||
if (ENABLE_FP16)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
|
||||
|
@ -207,4 +208,4 @@ if (BUILD_DEVICE)
|
|||
endif ()
|
||||
endif ()
|
||||
|
||||
include(${TOP_DIR}/cmake/package_lite.cmake)
|
||||
include(${TOP_DIR}/cmake/package_lite.cmake)
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
[查看中文](./README_CN.md)
|
||||
|
||||
## What Is MindSpore Lite
|
||||
|
||||
MindSpore lite is a high-performance, lightweight open source reasoning framework that can be used to meet the needs of AI applications on mobile devices. MindSpore Lite focuses on how to deploy AI technology more effectively on devices. It has been integrated into HMS (Huawei Mobile Services) to provide inferences for applications such as image classification, object detection and OCR. MindSpore Lite will promote the development and enrichment of the AI software/hardware application ecosystem.
|
||||
|
||||
<img src="../../docs/MindSpore-Lite-architecture.png" alt="MindSpore Lite Architecture" width="600"/>
|
||||
|
||||
For more details please check out our [MindSpore Lite Architecture Guide](https://www.mindspore.cn/lite/docs/en/master/architecture.html).
|
||||
|
||||
### MindSpore Lite features
|
||||
|
||||
1. Cooperative work with MindSpore training
|
||||
- Provides training, optimization, and deployment.
|
||||
- The unified IR realizes the device-cloud AI application integration.
|
||||
|
||||
2. Lightweight
|
||||
- Provides model compress, which could help to improve performance as well.
|
||||
- Provides the ultra-lightweight reasoning solution MindSpore Micro to meet the deployment requirements in extreme environments such as smart watches and headphones.
|
||||
|
||||
3. High-performance
|
||||
- The built-in high-performance kernel computing library NNACL supports multiple convolution optimization algorithms such as Slide window, im2col+gemm, winograde, etc.
|
||||
- Assembly code to improve performance of kernel operators. Supports CPU, GPU, and NPU.
|
||||
4. Versatility
|
||||
- Supports IOS, Android.
|
||||
- Supports Lite OS.
|
||||
- Supports mobile device, smart screen, pad, and IOT devices.
|
||||
- Supports third party models such as TFLite, CAFFE and ONNX.
|
||||
|
||||
## MindSpore Lite AI deployment procedure
|
||||
|
||||
1. Model selection and personalized training
|
||||
|
||||
Select a new model or use an existing model for incremental training using labeled data. When designing a model for mobile device, it is necessary to consider the model size, accuracy and calculation amount.
|
||||
|
||||
The MindSpore team provides a series of pre-training models used for image classification, object detection. You can use these pre-trained models in your application.
|
||||
|
||||
The pre-trained models provided by MindSpore include: [Image Classification](https://download.mindspore.cn/model_zoo/official/lite/) and [Object Detection](https://download.mindspore.cn/model_zoo/official/lite/). More models will be provided in the feature.
|
||||
|
||||
MindSpore allows you to retrain pre-trained models to perform other tasks. For example: using a pre-trained image classification model, it can be retrained to recognize new image types. See [Retraining](https://www.mindspore.cn/lite/tutorial/zh-CN/master/advanced_use/retraining_of_quantized_network.html).
|
||||
|
||||
2. Model converter and optimization
|
||||
|
||||
If you use MindSpore or a third-party model, you need to use [MindSpore Lite Model Converter Tool](https://www.mindspore.cn/lite/tutorial/zh-CN/master/use/converter_tool.html) to convert the model into MindSpore Lite model. The MindSpore Lite model converter tool provides the converter of TensorFlow Lite, Caffe, ONNX to MindSpore Lite model, fusion and quantization could be introduced during convert procedure.
|
||||
|
||||
MindSpore also provides a tool to convert models running on IoT devices .
|
||||
|
||||
3. Model deployment
|
||||
|
||||
This stage mainly realizes model deployment, including model management, deployment, operation and maintenance monitoring, etc.
|
||||
|
||||
4. Inference
|
||||
|
||||
Load the model and perform inference. [Inference](https://www.mindspore.cn/lite/tutorial/zh-CN/master/use/runtime.html) is the process of running input data through the model to get output.
|
||||
|
||||
MindSpore provides a series of pre-trained models that can be deployed on mobile device [example](#TODO).
|
|
@ -0,0 +1,66 @@
|
|||
|
||||
[View English](./README.md)
|
||||
|
||||
## MindSpore Lite介绍
|
||||
|
||||
MindSpore Lite是MindSpore推出的端云协同的、轻量化、高性能AI推理框架,用于满足越来越多的端测AI应用需求。MindSpore Lite聚焦AI技术在端侧设备上的部署和运行,已经在华为HMS和智能终端的图像分类、目标识别、人脸识别、文字识别等应用中广泛使用,未来MindSpore Lite将与MindSpore AI社区一起,致力于丰富AI软硬件应用生态。
|
||||
|
||||
|
||||
|
||||
<img src="../../docs/MindSpore-Lite-architecture.png" alt="MindSpore Lite Architecture" width="600"/>
|
||||
|
||||
欲了解更多详情,请查看我们的[MindSpore Lite 总体架构](https://www.mindspore.cn/lite/docs/zh-CN/master/architecture.html)。
|
||||
|
||||
## MindSpore Lite技术特点
|
||||
|
||||
1. 端云协同提供一站式训练和推理
|
||||
|
||||
- 提供模型训练、模型转换优化、部署和推理端到端流程。
|
||||
- 统一的IR实现端云AI应用一体化。
|
||||
|
||||
2. 超轻量
|
||||
|
||||
- 支持模型量化压缩,模型更小跑得更快。
|
||||
- 提供超轻量的推理解决方案MindSpore Micro,满足智能手表、耳机等极限环境下的部署要求。
|
||||
|
||||
3. 高性能
|
||||
|
||||
- 自带的高性能内核计算库NNACL,支持Sliding Windows、Im2Col+GEMM、Winograd等多种卷积优化算法。
|
||||
- 汇编级优化,支持CPU、GPU、NPU异构调度,最大化发挥硬件算力,最小化推理时延和功耗。
|
||||
|
||||
4. 广覆盖
|
||||
|
||||
- 支持iOS、Android等手机操作系统。
|
||||
- 支持LiteOS嵌入式操作系统。
|
||||
- 支持手机、大屏、平板、IoT等各种智能设备上的AI应用。
|
||||
- 支持MindSpore/TensorFlow Lite/Caffe/ONNX模型,方便用户快速部署。
|
||||
|
||||
## MindSpore Lite AI部署流程
|
||||
|
||||
1. 模型选择和个性化训练
|
||||
|
||||
包括选择新模型或对已有模型,利用标注数据进行增量训练。面向端侧设计模型时,需要考虑模型大小、精度和计算量。
|
||||
|
||||
MindSpore团队提供了一系列预训练模型,用于解决图像分类、目标检测等场景的学习问题。可以在您的应用程序中使用这些预训练模型对应的终端模型。
|
||||
|
||||
MindSpore提供的预训练模型包括:[图像分类(Image Classification)](https://download.mindspore.cn/model_zoo/official/lite/)和[目标检测(Object Detection)](https://download.mindspore.cn/model_zoo/official/lite/)。后续MindSpore团队会增加更多的预置模型。
|
||||
|
||||
MindSpore允许您重新训练预训练模型,以执行其他任务。比如:使用预训练的图像分类模型,可以重新训练来识别新的图像类型。参见[重训练](https://www.mindspore.cn/lite/tutorial/zh-CN/master/advanced_use/retraining_of_quantized_network.html)。
|
||||
|
||||
2. 模型转换/优化
|
||||
|
||||
如果您使用MindSpore或第三方训练的模型,需要使用[MindSpore Lite模型转换工具](https://www.mindspore.cn/lite/tutorial/zh-CN/master/use/converter_tool.html)转换成MindSpore Lite模型格式。MindSpore Lite模型转换工具不仅提供了将TensorFlow Lite、Caffe、ONNX等模型格式转换为MindSpore Lite模型格式,还提供了算子融合、量化等功能。
|
||||
|
||||
MindSpore还提供了将IoT设备上运行的模型转换成.C代码的生成工具。
|
||||
|
||||
经过上述两个部署,您已经得到端侧可以部署的模型。
|
||||
|
||||
3. 模型部署
|
||||
|
||||
这个阶段主要实现模型部署,包括模型管理、部署和运维监控等。
|
||||
|
||||
4. 模型推理
|
||||
|
||||
主要完成模型推理工作,即加载模型,完成模型相关的所有计算。[推理](https://www.mindspore.cn/lite/tutorial/zh-CN/master/use/runtime.html)是通过模型运行输入数据,获取预测的过程。
|
||||
|
||||
MindSpore提供了一系列预训练模型部署在智能终端的[样例](#TODO)。
|
|
@ -28,17 +28,17 @@ namespace mindspore::lite {
|
|||
class Allocator;
|
||||
|
||||
/// \brief CpuBindMode defined for holding bind cpu strategy argument.
|
||||
enum CpuBindMode {
|
||||
typedef enum {
|
||||
MID_CPU = -1, /**< bind middle cpu first */
|
||||
HIGHER_CPU = 1, /**< bind higher cpu first */
|
||||
NO_BIND = 0 /**< no bind */
|
||||
};
|
||||
} CpuBindMode;
|
||||
|
||||
/// \brief DeviceType defined for holding user's preferred backend.
|
||||
typedef enum {
|
||||
DT_CPU, /**< CPU device type */
|
||||
DT_GPU, /**< GPU device type */
|
||||
DT_NPU /**< NPU device type */
|
||||
DT_NPU /**< NPU device type, not supported yet */
|
||||
} DeviceType;
|
||||
|
||||
/// \brief DeviceContext defined for holding DeviceType.
|
||||
|
|
|
@ -86,17 +86,34 @@ class MS_API LiteSession {
|
|||
/// \return STATUS as an error code of running graph, STATUS is defined in errorcode.h.
|
||||
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) = 0;
|
||||
|
||||
/// \brief Get output MindSpore Lite MSTensors of model.
|
||||
/// \brief Get output MindSpore Lite MSTensors of model mapped by node name.
|
||||
///
|
||||
/// \return The map of output node name and MindSpore Lite MSTensor.
|
||||
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const = 0;
|
||||
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputMapByNode() const = 0;
|
||||
|
||||
/// \brief Get output MindSpore Lite MSTensors of model by node name.
|
||||
///
|
||||
/// \param[in] node_name Define node name.
|
||||
///
|
||||
/// \return The vector of MindSpore Lite MSTensor.
|
||||
virtual std::vector<tensor::MSTensor *> GetOutputsByName(const std::string &node_name) const = 0;
|
||||
virtual std::vector<tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const = 0;
|
||||
|
||||
/// \brief Get output MindSpore Lite MSTensors of model mapped by tensor name.
|
||||
///
|
||||
/// \return The map of output tensor name and MindSpore Lite MSTensor.
|
||||
virtual std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputMapByTensor() const = 0;
|
||||
|
||||
/// \brief Get name of output tensors of model compiled by this session.
|
||||
///
|
||||
/// \return The vector of string as output tensor names in order.
|
||||
virtual std::vector<std::string> GetOutputTensorNames() const = 0;
|
||||
|
||||
/// \brief Get output MindSpore Lite MSTensors of model by tensor name.
|
||||
///
|
||||
/// \param[in] tensor_name Define tensor name.
|
||||
///
|
||||
/// \return Pointer of MindSpore Lite MSTensor.
|
||||
virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const = 0;
|
||||
|
||||
/// \brief Resize inputs shape.
|
||||
///
|
||||
|
|
|
@ -24,8 +24,17 @@ namespace lite {
|
|||
/// \brief Global method to get a version string.
|
||||
///
|
||||
/// \return The version string of MindSpore Lite.
|
||||
#ifndef MS_VERSION_MAJOR
|
||||
#define MS_VERSION_MAJOR 0
|
||||
#endif
|
||||
#ifndef MS_VERSION_MINOR
|
||||
#define MS_VERSION_MINOR 7
|
||||
#endif
|
||||
#ifndef MS_VERSION_REVISION
|
||||
#define MS_VERSION_REVISION 0
|
||||
#endif
|
||||
std::string Version() {
|
||||
return "MindSpore Lite " + std::to_string(MS_VERSION_MAJOY) + "." + std::to_string(MS_VERSION_MINOR) + "." +
|
||||
return "MindSpore Lite " + std::to_string(MS_VERSION_MAJOR) + "." + std::to_string(MS_VERSION_MINOR) + "." +
|
||||
std::to_string(MS_VERSION_REVISION);
|
||||
}
|
||||
} // namespace lite
|
||||
|
|
|
@ -15,10 +15,10 @@ fi
|
|||
|
||||
# copy arm64 so
|
||||
cd ${TOP_PATH}/output/
|
||||
rm -rf mindspore-lite-0.6.0
|
||||
tar -zxvf mindspore-lite-0.6.0-runtime-arm64-cpu.tar.gz
|
||||
rm -rf mindspore-lite-0.7.0
|
||||
tar -zxvf mindspore-lite-0.7.0-runtime-arm64-cpu.tar.gz
|
||||
mkdir -p ${BASE_PATH}/lib/
|
||||
cp ${TOP_PATH}/output/mindspore-lite-0.6.0/lib/libmindspore-lite.so ${BASE_PATH}/lib/
|
||||
cp ${TOP_PATH}/output/mindspore-lite-0.7.0-runtime-arm64-cpu/lib/libmindspore-lite.so ${BASE_PATH}/lib/
|
||||
cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${BASE_PATH}/lib/
|
||||
|
||||
# build jni so
|
||||
|
|
|
@ -76,8 +76,8 @@ public class LiteSession {
|
|||
return tensors;
|
||||
}
|
||||
|
||||
public Map<String, List<MSTensor>> getOutputs() {
|
||||
Map<String, List<Long>> ret = this.getOutputs(this.sessionPtr);
|
||||
public Map<String, List<MSTensor>> getOutputMapByNode() {
|
||||
Map<String, List<Long>> ret = this.getOutputMapByNode(this.sessionPtr);
|
||||
Map<String, List<MSTensor>> tensorMap = new HashMap<>();
|
||||
Set<Map.Entry<String, List<Long>>> entrySet = ret.entrySet();
|
||||
for (Map.Entry<String, List<Long>> entry : entrySet) {
|
||||
|
@ -93,8 +93,8 @@ public class LiteSession {
|
|||
return tensorMap;
|
||||
}
|
||||
|
||||
public List<MSTensor> getOutputsByName(String nodeName) {
|
||||
List<Long> ret = this.getOutputsByName(this.sessionPtr, nodeName);
|
||||
public List<MSTensor> getOutputsByNodeName(String nodeName) {
|
||||
List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName);
|
||||
ArrayList<MSTensor> tensors = new ArrayList<>();
|
||||
for (Long msTensorAddr : ret) {
|
||||
MSTensor msTensor = new MSTensor(msTensorAddr);
|
||||
|
@ -103,6 +103,27 @@ public class LiteSession {
|
|||
return tensors;
|
||||
}
|
||||
|
||||
public Map<String, MSTensor> getOutputMapByTensor() {
|
||||
Map<String, Long> ret = this.getOutputMapByTensor(this.sessionPtr);
|
||||
Map<String, MSTensor> tensorMap = new HashMap<>();
|
||||
Set<Map.Entry<String, Long>> entrySet = ret.entrySet();
|
||||
for (Map.Entry<String, Long> entry : entrySet) {
|
||||
String name = entry.getKey();
|
||||
Long msTensorAddr = entry.getValue();
|
||||
tensorMap.put(name, new MSTensor(msTensorAddr));
|
||||
}
|
||||
return tensorMap;
|
||||
}
|
||||
|
||||
public List<String> getOutputTensorNames() {
|
||||
return getOutputTensorNames(this.sessionPtr);
|
||||
}
|
||||
|
||||
public MSTensor getOutputByTensorName(String tensorName) {
|
||||
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
|
||||
return new MSTensor(tensor_addr);
|
||||
}
|
||||
|
||||
public void free() {
|
||||
this.free(this.sessionPtr);
|
||||
this.sessionPtr = 0;
|
||||
|
@ -120,9 +141,15 @@ public class LiteSession {
|
|||
|
||||
private native List<Long> getInputsByName(long sessionPtr, String nodeName);
|
||||
|
||||
private native Map<String, List<Long>> getOutputs(long sessionPtr);
|
||||
private native Map<String, List<Long>> getOutputMapByNode(long sessionPtr);
|
||||
|
||||
private native List<Long> getOutputsByName(long sessionPtr, String nodeName);
|
||||
private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);
|
||||
|
||||
private native Map<String, Long> getOutputMapByTensor(long sessionPtr);
|
||||
|
||||
private native List<String> getOutputTensorNames(long sessionPtr);
|
||||
|
||||
private native Long getOutputByTensorName(long sessionPtr, String tensorName);
|
||||
|
||||
private native void free(long sessionPtr);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,10 @@
|
|||
|
||||
package com.mindspore.lite;
|
||||
|
||||
import android.util.Log;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
public class MSTensor {
|
||||
private long tensorPtr;
|
||||
|
||||
|
@ -27,7 +31,7 @@ public class MSTensor {
|
|||
this.tensorPtr = tensorPtr;
|
||||
}
|
||||
|
||||
public boolean init (int dataType, int[] shape) {
|
||||
public boolean init(int dataType, int[] shape) {
|
||||
this.tensorPtr = createMSTensor(dataType, shape, shape.length);
|
||||
return this.tensorPtr != 0;
|
||||
}
|
||||
|
@ -48,14 +52,30 @@ public class MSTensor {
|
|||
this.setDataType(this.tensorPtr, dataType);
|
||||
}
|
||||
|
||||
public byte[] getData() {
|
||||
return this.getData(this.tensorPtr);
|
||||
public byte[] getByteData() {
|
||||
return this.getByteData(this.tensorPtr);
|
||||
}
|
||||
|
||||
public float[] getFloatData() {
|
||||
return this.getFloatData(this.tensorPtr);
|
||||
}
|
||||
|
||||
public int[] getIntData() {
|
||||
return this.getIntData(this.tensorPtr);
|
||||
}
|
||||
|
||||
public long[] getLongData() {
|
||||
return this.getLongData(this.tensorPtr);
|
||||
}
|
||||
|
||||
public void setData(byte[] data) {
|
||||
this.setData(this.tensorPtr, data, data.length);
|
||||
}
|
||||
|
||||
public void setData(ByteBuffer data) {
|
||||
this.setByteBufferData(this.tensorPtr, data);
|
||||
}
|
||||
|
||||
public long size() {
|
||||
return this.size(this.tensorPtr);
|
||||
}
|
||||
|
@ -69,6 +89,24 @@ public class MSTensor {
|
|||
this.tensorPtr = 0;
|
||||
}
|
||||
|
||||
private float[] decodeBytes(byte[] bytes) {
|
||||
if (bytes.length % 4 != 0) {
|
||||
Log.e("MS_LITE", "Length of bytes should be multi of 4 ");
|
||||
return null;
|
||||
}
|
||||
int size = bytes.length / 4;
|
||||
float[] ret = new float[size];
|
||||
for (int i = 0; i < size; i = i + 4) {
|
||||
int accNum = 0;
|
||||
accNum = accNum | (bytes[i] & 0xff) << 0;
|
||||
accNum = accNum | (bytes[i + 1] & 0xff) << 8;
|
||||
accNum = accNum | (bytes[i + 2] & 0xff) << 16;
|
||||
accNum = accNum | (bytes[i + 3] & 0xff) << 24;
|
||||
ret[i / 4] = Float.intBitsToFloat(accNum);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
private native long createMSTensor(int dataType, int[] shape, int shapeLen);
|
||||
|
||||
private native int[] getShape(long tensorPtr);
|
||||
|
@ -79,10 +117,18 @@ public class MSTensor {
|
|||
|
||||
private native boolean setDataType(long tensorPtr, int dataType);
|
||||
|
||||
private native byte[] getData(long tensorPtr);
|
||||
private native byte[] getByteData(long tensorPtr);
|
||||
|
||||
private native long[] getLongData(long tensorPtr);
|
||||
|
||||
private native int[] getIntData(long tensorPtr);
|
||||
|
||||
private native float[] getFloatData(long tensorPtr);
|
||||
|
||||
private native boolean setData(long tensorPtr, byte[] data, long dataLen);
|
||||
|
||||
private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer);
|
||||
|
||||
private native long size(long tensorPtr);
|
||||
|
||||
private native int elementsNum(long tensorPtr);
|
||||
|
|
|
@ -80,6 +80,11 @@ public class Model {
|
|||
return ret;
|
||||
}
|
||||
|
||||
public boolean loadModel(String modelPath) {
|
||||
this.modelPtr = loadModelByPath(modelPath);
|
||||
return this.modelPtr != 0;
|
||||
}
|
||||
|
||||
public void free() {
|
||||
this.free(this.modelPtr);
|
||||
this.modelPtr = 0;
|
||||
|
@ -87,5 +92,7 @@ public class Model {
|
|||
|
||||
private native long loadModel(MappedByteBuffer buffer);
|
||||
|
||||
private native long loadModelByPath(String modelPath);
|
||||
|
||||
private native void free(long modelPtr);
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
cmake_minimum_required(VERSION 3.14)
|
||||
project (Lite-java)
|
||||
|
||||
set(MS_VERSION_MAJOY 0)
|
||||
set(MS_VERSION_MAJOR 0)
|
||||
set(MS_VERSION_MINOR 7)
|
||||
set(MS_VERSION_REVISION 0)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
|
||||
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../lite/)
|
||||
|
|
|
@ -14,12 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#include "common/jni_utils.h"
|
||||
#include <cstring>
|
||||
|
||||
char *JstringToChar(JNIEnv *env, jstring jstr) {
|
||||
char *rtn = NULL;
|
||||
char *rtn = nullptr;
|
||||
jclass clsstring = env->FindClass("java/lang/String");
|
||||
jstring strencode = env->NewStringUTF("GB2312");
|
||||
jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <jni.h>
|
||||
#include "common/ms_log.h"
|
||||
#include "include/context.h"
|
||||
#include "include/thread_pool_config.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_createContext(JNIEnv *env, jobject thiz,
|
||||
jint device_type,
|
||||
|
@ -44,13 +45,13 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_creat
|
|||
}
|
||||
switch (cpu_bind_mode) {
|
||||
case -1:
|
||||
context->cpu_bind_mode_ = mindspore::lite::MID_CPU;
|
||||
context->cpu_bind_mode_ = MID_CPU;
|
||||
break;
|
||||
case 0:
|
||||
context->cpu_bind_mode_ = mindspore::lite::NO_BIND;
|
||||
context->cpu_bind_mode_ = NO_BIND;
|
||||
break;
|
||||
case 1:
|
||||
context->cpu_bind_mode_ = mindspore::lite::HIGHER_CPU;
|
||||
context->cpu_bind_mode_ = HIGHER_CPU;
|
||||
break;
|
||||
default:
|
||||
MS_LOGE("Invalid cpu_bind_mode : %d", cpu_bind_mode);
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#include <jni.h>
|
||||
#include "common/ms_log.h"
|
||||
#include "common/jni_utils.h"
|
||||
|
@ -22,7 +21,7 @@
|
|||
#include "include/errorcode.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSession(JNIEnv *env, jobject thiz,
|
||||
jlong context_ptr) {
|
||||
jlong context_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(context_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Context pointer from java is nullptr");
|
||||
|
@ -38,8 +37,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSes
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compileGraph(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr,
|
||||
jlong model_ptr) {
|
||||
jlong session_ptr,
|
||||
jlong model_ptr) {
|
||||
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (session_pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
|
@ -58,7 +57,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compil
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr, jboolean if_bind) {
|
||||
jlong session_ptr, jboolean if_bind) {
|
||||
auto *pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
|
@ -69,7 +68,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGraph(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr) {
|
||||
jlong session_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
|
@ -81,7 +80,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGra
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputs(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr) {
|
||||
jlong session_ptr) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
|
@ -104,8 +103,8 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputsByName(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr,
|
||||
jstring node_name) {
|
||||
jlong session_ptr,
|
||||
jstring node_name) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
|
@ -127,8 +126,8 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
|
|||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputs(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr) {
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByNode(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr) {
|
||||
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
|
||||
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
|
||||
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
|
||||
|
@ -140,7 +139,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
|
|||
return hash_map;
|
||||
}
|
||||
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
|
||||
auto outputs = lite_session_ptr->GetOutputs();
|
||||
auto outputs = lite_session_ptr->GetOutputMapByNode();
|
||||
jclass long_object = env->FindClass("java/lang/Long");
|
||||
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
|
@ -159,9 +158,9 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
|
|||
return hash_map;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByName(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr,
|
||||
jstring node_name) {
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr,
|
||||
jstring node_name) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
|
@ -175,7 +174,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
|
|||
return ret;
|
||||
}
|
||||
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
|
||||
auto inputs = lite_session_ptr->GetOutputsByName(JstringToChar(env, node_name));
|
||||
auto inputs = lite_session_ptr->GetOutputsByNodeName(JstringToChar(env, node_name));
|
||||
for (auto input : inputs) {
|
||||
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
|
||||
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
|
||||
|
@ -183,8 +182,66 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
|
|||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByTensor(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr) {
|
||||
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
|
||||
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
|
||||
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
|
||||
jmethodID hash_map_put =
|
||||
env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
|
||||
auto *pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
return hash_map;
|
||||
}
|
||||
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
|
||||
auto outputs = lite_session_ptr->GetOutputMapByTensor();
|
||||
jclass long_object = env->FindClass("java/lang/Long");
|
||||
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
|
||||
for (auto output_iter : outputs) {
|
||||
auto node_name = output_iter.first;
|
||||
auto ms_tensor = output_iter.second;
|
||||
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor));
|
||||
env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr);
|
||||
}
|
||||
return hash_map;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputTensorNames(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr) {
|
||||
jclass array_list = env->FindClass("java/util/ArrayList");
|
||||
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
|
||||
jobject ret = env->NewObject(array_list, array_list_construct);
|
||||
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
|
||||
|
||||
auto *pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
return ret;
|
||||
}
|
||||
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
|
||||
auto output_names = lite_session_ptr->GetOutputTensorNames();
|
||||
for (auto output_name : output_names) {
|
||||
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getOutputByTensorName(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr,
|
||||
jstring tensor_name) {
|
||||
auto *pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
return jlong(nullptr);
|
||||
}
|
||||
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
|
||||
auto output = lite_session_ptr->GetOutputByTensorName(JstringToChar(env, tensor_name));
|
||||
return jlong(output);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_free(JNIEnv *env, jobject thiz,
|
||||
jlong session_ptr) {
|
||||
jlong session_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(session_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Session pointer from java is nullptr");
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#include <jni.h>
|
||||
#include <fstream>
|
||||
#include "common/ms_log.h"
|
||||
#include "common/jni_utils.h"
|
||||
#include "include/model.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEnv *env, jobject thiz, jobject buffer) {
|
||||
|
@ -38,6 +39,46 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEn
|
|||
return reinterpret_cast<jlong>(model);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath(JNIEnv *env, jobject thiz,
|
||||
jstring model_path) {
|
||||
auto model_path_char = JstringToChar(env, model_path);
|
||||
if (nullptr == model_path_char) {
|
||||
MS_LOGE("model_path_char is nullptr");
|
||||
return reinterpret_cast<jlong>(nullptr);
|
||||
}
|
||||
std::ifstream ifs(model_path_char);
|
||||
if (!ifs.good()) {
|
||||
MS_LOGE("file: %s is not exist", model_path_char);
|
||||
return reinterpret_cast<jlong>(nullptr);
|
||||
}
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOGE("file: %s open failed", model_path_char);
|
||||
return reinterpret_cast<jlong>(nullptr);
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
auto size = ifs.tellg();
|
||||
std::unique_ptr<char[]> buf(new (std::nothrow) char[size]);
|
||||
if (buf == nullptr) {
|
||||
MS_LOGE("malloc buf failed, file: %s", model_path_char);
|
||||
ifs.close();
|
||||
return reinterpret_cast<jlong>(nullptr);
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
ifs.read(buf.get(), size);
|
||||
ifs.close();
|
||||
delete[](model_path_char);
|
||||
MS_LOGD("Start Loading model");
|
||||
auto model = mindspore::lite::Model::Import(buf.get(), size);
|
||||
if (model == nullptr) {
|
||||
MS_LOGE("Import model failed");
|
||||
return reinterpret_cast<jlong>(nullptr);
|
||||
}
|
||||
return reinterpret_cast<jlong>(model);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_Model_free(JNIEnv *env, jobject thiz, jlong model_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(model_ptr);
|
||||
if (pointer == nullptr) {
|
||||
|
|
|
@ -14,15 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#include <jni.h>
|
||||
#include "common/ms_log.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createMSTensor(JNIEnv *env, jobject thiz,
|
||||
jint data_type, jintArray shape,
|
||||
jint shape_len) {
|
||||
jint data_type, jintArray shape,
|
||||
jint shape_len) {
|
||||
jboolean is_copy = false;
|
||||
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
|
||||
std::vector<int> local_shape(shape_len);
|
||||
|
@ -39,7 +38,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createMSTens
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr) {
|
||||
jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
|
@ -59,8 +58,8 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setShape(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr, jintArray shape,
|
||||
jint shape_len) {
|
||||
jlong tensor_ptr, jintArray shape,
|
||||
jint shape_len) {
|
||||
jboolean is_copy = false;
|
||||
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
|
@ -78,7 +77,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setShape(
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr) {
|
||||
jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
|
@ -89,7 +88,7 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(J
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataType(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr, jint data_type) {
|
||||
jlong tensor_ptr, jint data_type) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
|
@ -100,8 +99,8 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataTy
|
|||
return ret == data_type;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr) {
|
||||
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByteData(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
|
@ -113,15 +112,99 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData
|
|||
MS_LOGD("Tensor has no data");
|
||||
return env->NewByteArray(0);
|
||||
}
|
||||
auto local_data_size = ms_tensor_ptr->Size();
|
||||
auto ret = env->NewByteArray(local_data_size);
|
||||
env->SetByteArrayRegion(ret, 0, local_data_size, local_data);
|
||||
|
||||
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeUInt8) {
|
||||
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
|
||||
return env->NewByteArray(0);
|
||||
}
|
||||
|
||||
auto local_element_num = ms_tensor_ptr->ElementsNum();
|
||||
auto ret = env->NewByteArray(local_element_num);
|
||||
env->SetByteArrayRegion(ret, 0, local_element_num, local_data);
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLongData(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
return env->NewLongArray(0);
|
||||
}
|
||||
|
||||
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
|
||||
|
||||
auto *local_data = static_cast<jlong *>(ms_tensor_ptr->MutableData());
|
||||
if (local_data == nullptr) {
|
||||
MS_LOGD("Tensor has no data");
|
||||
return env->NewLongArray(0);
|
||||
}
|
||||
|
||||
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeInt64) {
|
||||
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
|
||||
return env->NewLongArray(0);
|
||||
}
|
||||
auto local_element_num = ms_tensor_ptr->ElementsNum();
|
||||
auto ret = env->NewLongArray(local_element_num);
|
||||
env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntData(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
return env->NewIntArray(0);
|
||||
}
|
||||
|
||||
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
|
||||
|
||||
auto *local_data = static_cast<jint *>(ms_tensor_ptr->MutableData());
|
||||
if (local_data == nullptr) {
|
||||
MS_LOGD("Tensor has no data");
|
||||
return env->NewIntArray(0);
|
||||
}
|
||||
|
||||
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeInt32) {
|
||||
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
|
||||
return env->NewIntArray(0);
|
||||
}
|
||||
auto local_element_num = ms_tensor_ptr->ElementsNum();
|
||||
auto ret = env->NewIntArray(local_element_num);
|
||||
env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFloatData(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
|
||||
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
|
||||
|
||||
auto *local_data = static_cast<jfloat *>(ms_tensor_ptr->MutableData());
|
||||
if (local_data == nullptr) {
|
||||
MS_LOGD("Tensor has no data");
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
|
||||
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeFloat32) {
|
||||
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
|
||||
return env->NewFloatArray(0);
|
||||
}
|
||||
auto local_element_num = ms_tensor_ptr->ElementsNum();
|
||||
auto ret = env->NewFloatArray(local_element_num);
|
||||
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
|
||||
return ret;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr, jbyteArray data,
|
||||
jlong data_len) {
|
||||
jlong tensor_ptr, jbyteArray data,
|
||||
jlong data_len) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
|
@ -139,6 +222,36 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(J
|
|||
return static_cast<jboolean>(true);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setByteBufferData(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr,
|
||||
jobject buffer) {
|
||||
jbyte *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer)); // get buffer poiter
|
||||
jlong data_len = env->GetDirectBufferCapacity(buffer); // get buffer capacity
|
||||
if (!p_data) {
|
||||
MS_LOGE("GetDirectBufferAddress return null");
|
||||
return NULL;
|
||||
}
|
||||
jbyteArray data = env->NewByteArray(data_len); // create byte[]
|
||||
env->SetByteArrayRegion(data, 0, data_len, p_data); // copy data to byte[]
|
||||
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
|
||||
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
|
||||
if (data_len != ms_tensor_ptr->Size()) {
|
||||
MS_LOGE("data_len(%ld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->Size());
|
||||
return static_cast<jboolean>(false);
|
||||
}
|
||||
jboolean is_copy = false;
|
||||
auto *data_arr = env->GetByteArrayElements(data, &is_copy);
|
||||
auto *local_data = ms_tensor_ptr->MutableData();
|
||||
memcpy(local_data, data_arr, data_len);
|
||||
return static_cast<jboolean>(true);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_size(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
|
@ -150,7 +263,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_size(JNIEnv
|
|||
}
|
||||
|
||||
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_elementsNum(JNIEnv *env, jobject thiz,
|
||||
jlong tensor_ptr) {
|
||||
jlong tensor_ptr) {
|
||||
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
|
||||
if (pointer == nullptr) {
|
||||
MS_LOGE("Tensor pointer from java is nullptr");
|
||||
|
|
|
@ -32,9 +32,11 @@ if (PLATFORM_ARM64)
|
|||
)
|
||||
set_target_properties(optimize PROPERTIES CLEAN_DIRECT_OUTPUT 1)
|
||||
|
||||
add_custom_command(TARGET optimize POST_BUILD
|
||||
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
|
||||
add_custom_command(TARGET optimize POST_BUILD
|
||||
COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
|
||||
${TOP_DIR}/build/nnacl/liboptimize.so)
|
||||
endif ()
|
||||
|
||||
add_custom_command(TARGET optimize POST_BUILD
|
||||
COMMAND rm -rf ${TOP_DIR}/output/lib/liboptimize.so
|
||||
|
|
|
@ -51,6 +51,8 @@ void TileOneDimension(float *inData, float *outData, int dim, size_t ndim, int *
|
|||
int *outStrides, int *multiple);
|
||||
void ComputeStrides(int *shape, int *strides, int ndim);
|
||||
|
||||
void CalcMultiplesAndStrides(ArithmeticParameter *param);
|
||||
|
||||
void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param);
|
||||
void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1,
|
||||
ArithmeticParameter *param);
|
||||
|
|
|
@ -29,7 +29,7 @@ mov x6, x1
|
|||
mov x7, x2
|
||||
mov x8, x4
|
||||
|
||||
LoopInputDepth16In:
|
||||
LoopDepth16In:
|
||||
cmp x8, #16
|
||||
blt L4
|
||||
sub x8, x8, #16
|
||||
|
@ -39,8 +39,8 @@ mov x8, x4
|
|||
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||
|
||||
cmp x8, #16
|
||||
blt LoopInputDepth16Out
|
||||
LoopInputDepth16:
|
||||
blt LoopDepth16Out
|
||||
LoopDepth16:
|
||||
fmla v16.4s, v0.4s, v2.4s
|
||||
fmla v17.4s, v1.4s, v3.4s
|
||||
|
||||
|
@ -61,9 +61,9 @@ mov x8, x4
|
|||
|
||||
sub x8, x8, #16
|
||||
cmp x8, #16
|
||||
bge LoopInputDepth16
|
||||
bge LoopDepth16
|
||||
|
||||
LoopInputDepth16Out:
|
||||
LoopDepth16Out:
|
||||
fmla v16.4s, v0.4s, v2.4s
|
||||
fmla v17.4s, v1.4s, v3.4s
|
||||
st1 {v16.4s, v17.4s}, [x9], #32
|
||||
|
@ -81,7 +81,7 @@ mov x8, x4
|
|||
cmp x8, #4
|
||||
blt L0
|
||||
|
||||
LoopInputDepth4:
|
||||
LoopDepth4:
|
||||
ld1 {v0.4s}, [x6], #16
|
||||
ld1 {v2.4s}, [x7], #16
|
||||
ld1 {v16.4s}, [x0], #16
|
||||
|
@ -89,13 +89,13 @@ mov x8, x4
|
|||
st1 {v16.4s}, [x9], #16
|
||||
sub x8, x8, #4
|
||||
cmp x8, #4
|
||||
bge LoopInputDepth4
|
||||
bge LoopDepth4
|
||||
|
||||
L0:
|
||||
cmp x8, #0
|
||||
beq Loop16LineEnd
|
||||
|
||||
LoopInputDepth0:
|
||||
LoopDepth0:
|
||||
ldr s0, [x6], #4
|
||||
ldr s1, [x7], #4
|
||||
ldr s2, [x0], #4
|
||||
|
@ -103,7 +103,7 @@ mov x8, x4
|
|||
fadd s2, s2, s0
|
||||
str s2, [x9], #4
|
||||
subs x8, x8, #1
|
||||
bne LoopInputDepth0
|
||||
bne LoopDepth0
|
||||
|
||||
Loop16LineEnd:
|
||||
|
||||
|
|
|
@ -90,36 +90,36 @@ ConvDwInt8Center:
|
|||
LoopKw16:
|
||||
mov x22, x21
|
||||
ld1 {v25.4h}, [x17], #8
|
||||
ld1 {v16.4h}, [x22], x13
|
||||
ld1 {v17.4h}, [x22], x13
|
||||
ld1 {v16.4h}, [x22], x11
|
||||
ld1 {v17.4h}, [x22], x11
|
||||
smlal v0.4s, v16.4h, v25.4h
|
||||
smlal v1.4s, v17.4h, v25.4h
|
||||
ld1 {v18.4h}, [x22], x13
|
||||
ld1 {v19.4h}, [x22], x13
|
||||
ld1 {v18.4h}, [x22], x11
|
||||
ld1 {v19.4h}, [x22], x11
|
||||
smlal v2.4s, v18.4h, v25.4h
|
||||
smlal v3.4s, v19.4h, v25.4h
|
||||
ld1 {v20.4h}, [x22], x13
|
||||
ld1 {v21.4h}, [x22], x13
|
||||
ld1 {v20.4h}, [x22], x11
|
||||
ld1 {v21.4h}, [x22], x11
|
||||
smlal v4.4s, v20.4h, v25.4h
|
||||
smlal v5.4s, v21.4h, v25.4h
|
||||
ld1 {v22.4h}, [x22], x13
|
||||
ld1 {v23.4h}, [x22], x13
|
||||
ld1 {v22.4h}, [x22], x11
|
||||
ld1 {v23.4h}, [x22], x11
|
||||
smlal v6.4s, v22.4h, v25.4h
|
||||
smlal v7.4s, v23.4h, v25.4h
|
||||
ld1 {v16.4h}, [x22], x13
|
||||
ld1 {v17.4h}, [x22], x13
|
||||
ld1 {v16.4h}, [x22], x11
|
||||
ld1 {v17.4h}, [x22], x11
|
||||
smlal v8.4s, v16.4h, v25.4h
|
||||
smlal v9.4s, v17.4h, v25.4h
|
||||
ld1 {v18.4h}, [x22], x13
|
||||
ld1 {v19.4h}, [x22], x13
|
||||
ld1 {v18.4h}, [x22], x11
|
||||
ld1 {v19.4h}, [x22], x11
|
||||
smlal v10.4s, v18.4h, v25.4h
|
||||
smlal v11.4s, v19.4h, v25.4h
|
||||
ld1 {v20.4h}, [x22], x13
|
||||
ld1 {v21.4h}, [x22], x13
|
||||
ld1 {v20.4h}, [x22], x11
|
||||
ld1 {v21.4h}, [x22], x11
|
||||
smlal v12.4s, v20.4h, v25.4h
|
||||
smlal v13.4s, v21.4h, v25.4h
|
||||
ld1 {v22.4h}, [x22], x13
|
||||
ld1 {v23.4h}, [x22], x13
|
||||
ld1 {v22.4h}, [x22], x11
|
||||
ld1 {v23.4h}, [x22], x11
|
||||
smlal v14.4s, v22.4h, v25.4h
|
||||
smlal v15.4s, v23.4h, v25.4h
|
||||
subs x18, x18, #1
|
||||
|
@ -420,20 +420,20 @@ ConvDwInt8Center:
|
|||
LoopKw8:
|
||||
mov x22, x21
|
||||
ld1 {v25.4h}, [x17], #8
|
||||
ld1 {v16.4h}, [x22], x13
|
||||
ld1 {v17.4h}, [x22], x13
|
||||
ld1 {v16.4h}, [x22], x11
|
||||
ld1 {v17.4h}, [x22], x11
|
||||
smlal v0.4s, v16.4h, v25.4h
|
||||
smlal v1.4s, v17.4h, v25.4h
|
||||
ld1 {v18.4h}, [x22], x13
|
||||
ld1 {v19.4h}, [x22], x13
|
||||
ld1 {v18.4h}, [x22], x11
|
||||
ld1 {v19.4h}, [x22], x11
|
||||
smlal v2.4s, v18.4h, v25.4h
|
||||
smlal v3.4s, v19.4h, v25.4h
|
||||
ld1 {v20.4h}, [x22], x13
|
||||
ld1 {v21.4h}, [x22], x13
|
||||
ld1 {v20.4h}, [x22], x11
|
||||
ld1 {v21.4h}, [x22], x11
|
||||
smlal v4.4s, v20.4h, v25.4h
|
||||
smlal v5.4s, v21.4h, v25.4h
|
||||
ld1 {v22.4h}, [x22], x13
|
||||
ld1 {v23.4h}, [x22], x13
|
||||
ld1 {v22.4h}, [x22], x11
|
||||
ld1 {v23.4h}, [x22], x11
|
||||
smlal v6.4s, v22.4h, v25.4h
|
||||
smlal v7.4s, v23.4h, v25.4h
|
||||
subs x18, x18, #1
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
#ifdef __aarch64__
|
||||
|
||||
.text
|
||||
.align 5
|
||||
.global ConvDwInt8PostAlign4
|
||||
#ifndef __APPLE__
|
||||
.type ConvDwInt8PostAlign4, %function
|
||||
#endif
|
||||
|
||||
// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
|
||||
// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
|
||||
// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier,
|
||||
// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max
|
||||
|
||||
ConvDwInt8PostAlign4:
|
||||
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
|
||||
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
|
||||
// x19 ~ x29 should be also preserved
|
||||
// whereas our coding style do not permit such amount of parameters
|
||||
ldr x8, [sp]
|
||||
|
||||
dup v26.4s, w5
|
||||
dup v27.4s, w4
|
||||
dup v28.4s, w6
|
||||
|
||||
dup v29.4s, w3
|
||||
dup v30.4s, w7
|
||||
dup v31.4s, w8
|
||||
|
||||
cmp x2, 16
|
||||
blt LoopDepth8
|
||||
|
||||
LoopDepth16:
|
||||
ld1 {v0.4s}, [x1], #16
|
||||
ld1 {v1.4s}, [x1], #16
|
||||
ld1 {v2.4s}, [x1], #16
|
||||
ld1 {v3.4s}, [x1], #16
|
||||
|
||||
sqshl v0.4s, v0.4s, v26.4s
|
||||
sqshl v1.4s, v1.4s, v26.4s
|
||||
sqshl v2.4s, v2.4s, v26.4s
|
||||
sqshl v3.4s, v3.4s, v26.4s
|
||||
|
||||
sqrdmulh v0.4s, v0.4s, v27.4s
|
||||
sqrdmulh v1.4s, v1.4s, v27.4s
|
||||
sqrdmulh v2.4s, v2.4s, v27.4s
|
||||
sqrdmulh v3.4s, v3.4s, v27.4s
|
||||
|
||||
and v16.16b, v28.16b, v0.16b
|
||||
sshr v16.4s, v16.4s, #31
|
||||
sqadd v0.4s, v0.4s, v16.4s
|
||||
srshl v0.4s, v0.4s, v28.4s
|
||||
and v17.16b, v28.16b, v1.16b
|
||||
sshr v17.4s, v17.4s, #31
|
||||
sqadd v1.4s, v1.4s, v17.4s
|
||||
srshl v1.4s, v1.4s, v28.4s
|
||||
and v18.16b, v28.16b, v2.16b
|
||||
sshr v18.4s, v18.4s, #31
|
||||
sqadd v2.4s, v2.4s, v18.4s
|
||||
srshl v2.4s, v2.4s, v28.4s
|
||||
and v19.16b, v28.16b, v3.16b
|
||||
sshr v19.4s, v19.4s, #31
|
||||
sqadd v3.4s, v3.4s, v19.4s
|
||||
srshl v3.4s, v3.4s, v28.4s
|
||||
|
||||
add v0.4s, v0.4s, v29.4s
|
||||
add v1.4s, v1.4s, v29.4s
|
||||
add v2.4s, v2.4s, v29.4s
|
||||
add v3.4s, v3.4s, v29.4s
|
||||
|
||||
smax v0.4s, v0.4s, v30.4s
|
||||
smax v1.4s, v1.4s, v30.4s
|
||||
smax v2.4s, v2.4s, v30.4s
|
||||
smax v3.4s, v3.4s, v30.4s
|
||||
|
||||
smin v0.4s, v0.4s, v31.4s
|
||||
smin v1.4s, v1.4s, v31.4s
|
||||
smin v2.4s, v2.4s, v31.4s
|
||||
smin v3.4s, v3.4s, v31.4s
|
||||
|
||||
sqxtn v0.4h, v0.4s
|
||||
sqxtn v1.4h, v1.4s
|
||||
sqxtn v2.4h, v2.4s
|
||||
sqxtn v3.4h, v3.4s
|
||||
|
||||
sqxtn v0.8b, v0.8h
|
||||
sqxtn v1.8b, v1.8h
|
||||
sqxtn v2.8b, v2.8h
|
||||
sqxtn v3.8b, v3.8h
|
||||
|
||||
st1 {v0.s}[0], [x0], #4
|
||||
st1 {v1.s}[0], [x0], #4
|
||||
st1 {v2.s}[0], [x0], #4
|
||||
st1 {v3.s}[0], [x0], #4
|
||||
|
||||
sub x2, x2, #16
|
||||
cmp x2, #16
|
||||
bge LoopDepth16
|
||||
|
||||
LoopDepth8:
|
||||
cmp x2, #8
|
||||
blt LoopDepth4
|
||||
ld1 {v0.4s}, [x1], #16
|
||||
ld1 {v1.4s}, [x1], #16
|
||||
|
||||
sqshl v0.4s, v0.4s, v26.4s
|
||||
sqshl v1.4s, v1.4s, v26.4s
|
||||
|
||||
sqrdmulh v0.4s, v0.4s, v27.4s
|
||||
sqrdmulh v1.4s, v1.4s, v27.4s
|
||||
|
||||
and v16.16b, v28.16b, v0.16b
|
||||
sshr v16.4s, v16.4s, #31
|
||||
sqadd v0.4s, v0.4s, v16.4s
|
||||
srshl v0.4s, v0.4s, v28.4s
|
||||
and v17.16b, v28.16b, v1.16b
|
||||
sshr v17.4s, v17.4s, #31
|
||||
sqadd v1.4s, v1.4s, v17.4s
|
||||
srshl v1.4s, v1.4s, v28.4s
|
||||
|
||||
add v0.4s, v0.4s, v29.4s
|
||||
add v1.4s, v1.4s, v29.4s
|
||||
|
||||
smax v0.4s, v0.4s, v30.4s
|
||||
smax v1.4s, v1.4s, v30.4s
|
||||
|
||||
smin v0.4s, v0.4s, v31.4s
|
||||
smin v1.4s, v1.4s, v31.4s
|
||||
|
||||
sqxtn v0.4h, v0.4s
|
||||
sqxtn v1.4h, v1.4s
|
||||
|
||||
sqxtn v0.8b, v0.8h
|
||||
sqxtn v1.8b, v1.8h
|
||||
|
||||
st1 {v0.s}[0], [x0], #4
|
||||
st1 {v1.s}[0], [x0], #4
|
||||
|
||||
sub x2, x2, #8
|
||||
cmp x2, #8
|
||||
bge LoopDepth8
|
||||
|
||||
LoopDepth4:
|
||||
cmp x2, #4
|
||||
blt End
|
||||
ld1 {v0.4s}, [x1], #16
|
||||
|
||||
sqshl v0.4s, v0.4s, v26.4s
|
||||
sqrdmulh v0.4s, v0.4s, v27.4s
|
||||
|
||||
and v16.16b, v28.16b, v0.16b
|
||||
sshr v16.4s, v16.4s, #31
|
||||
sqadd v0.4s, v0.4s, v16.4s
|
||||
srshl v0.4s, v0.4s, v28.4s
|
||||
|
||||
add v0.4s, v0.4s, v29.4s
|
||||
smax v0.4s, v0.4s, v30.4s
|
||||
smin v0.4s, v0.4s, v31.4s
|
||||
|
||||
sqxtn v0.4h, v0.4s
|
||||
sqxtn v0.8b, v0.8h
|
||||
|
||||
st1 {v0.s}[0], [x0], #4
|
||||
|
||||
sub x2, x2, #4
|
||||
bge LoopDepth4
|
||||
End:
|
||||
ret
|
||||
#endif
|
|
@ -0,0 +1,122 @@
|
|||
#ifdef __aarch64__
|
||||
|
||||
.text
|
||||
.align 5
|
||||
.global ConvDwInt8Row
|
||||
#ifndef __APPLE__
|
||||
.type ConvDwInt8Row, %function
|
||||
#endif
|
||||
|
||||
// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
|
||||
// int output_channel, int input_step, int8_t input_zp)
|
||||
// x0: output_ptr, x1: input_ptr, x2: weight_ptr, x3: num_pixels,
|
||||
// x4: output_channel, x5: input_step, x6: input_zp
|
||||
//
|
||||
ConvDwInt8Row:
|
||||
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
|
||||
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
|
||||
// x19 ~ x29 should be also preserved
|
||||
// whereas our coding style do not permit such amount of parameters
|
||||
cmp x3, #0
|
||||
beq End
|
||||
|
||||
mov x10, x0
|
||||
|
||||
dup v31.8b, w6
|
||||
|
||||
LoopOutPixel:
|
||||
mov x7, x1
|
||||
mov x8, x2
|
||||
mov x9, x4
|
||||
|
||||
LoopDepth16In:
|
||||
cmp x9, #16
|
||||
blt L8
|
||||
sub x9, x9, #16
|
||||
|
||||
ld1 {v0.8b, v1.8b}, [x7], #16
|
||||
ld1 {v2.8h, v3.8h}, [x8], #32
|
||||
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||
|
||||
ssubl v20.8h, v0.8b, v31.8b
|
||||
smlal v16.4s, v20.4h, v2.4h
|
||||
smlal2 v17.4s, v20.8h, v2.8h
|
||||
|
||||
|
||||
cmp x9, #16
|
||||
blt LoopDepth16Out
|
||||
LoopDepth16:
|
||||
|
||||
st1 {v16.4s, v17.4s}, [x10], #32
|
||||
ld1 {v18.4s, v19.4s}, [x0], #32
|
||||
ssubl v21.8h, v1.8b, v31.8b
|
||||
smlal v18.4s, v21.4h, v3.4h
|
||||
smlal2 v19.4s, v21.8h, v3.8h
|
||||
st1 {v18.4s, v19.4s}, [x10], #32
|
||||
|
||||
ld1 {v0.8b, v1.8b}, [x7], #16
|
||||
ld1 {v2.8h, v3.8h}, [x8], #32
|
||||
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||
|
||||
ssubl v20.8h, v0.8b, v31.8b
|
||||
smlal v16.4s, v20.4h, v2.4h
|
||||
smlal2 v17.4s, v20.8h, v2.8h
|
||||
|
||||
sub x9, x9, #16
|
||||
cmp x9, #16
|
||||
bge LoopDepth16
|
||||
|
||||
LoopDepth16Out:
|
||||
|
||||
st1 {v16.4s, v17.4s}, [x10], #32
|
||||
ld1 {v18.4s, v19.4s}, [x0], #32
|
||||
ssubl v21.8h, v1.8b, v31.8b
|
||||
smlal v18.4s, v21.4h, v3.4h
|
||||
smlal2 v19.4s, v21.8h, v3.8h
|
||||
st1 {v18.4s, v19.4s}, [x10], #32
|
||||
|
||||
L8:
|
||||
cmp x9, #8
|
||||
blt L0
|
||||
|
||||
LoopDepth8:
|
||||
ld1 {v0.8b}, [x7], #8
|
||||
ld1 {v2.8h}, [x8], #16
|
||||
ld1 {v16.4s, v17.4s}, [x0], #32
|
||||
|
||||
ssubl v20.8h, v0.8b, v31.8b
|
||||
smlal v16.4s, v20.4h, v2.4h
|
||||
smlal2 v17.4s, v20.8h, v2.8h
|
||||
st1 {v16.4s, v17.4s}, [x10], #32
|
||||
|
||||
sub x9, x9, #8
|
||||
cmp x9, #8
|
||||
bge LoopDepth8
|
||||
|
||||
L0:
|
||||
cmp x9, #0
|
||||
beq Loop16LineEnd
|
||||
|
||||
LoopDepth0:
|
||||
ldrsb w14, [x7], #1
|
||||
ldrsh w15, [x8], #2
|
||||
ldr w16, [x0], #4
|
||||
add w14, w14, w6
|
||||
|
||||
sxth w14, w14
|
||||
madd w14, w14, w15, w16
|
||||
str w14, [x10], #4
|
||||
|
||||
subs x9, x9, #1
|
||||
bne LoopDepth0
|
||||
|
||||
Loop16LineEnd:
|
||||
|
||||
subs x3, x3, #1
|
||||
add x1, x1, x5
|
||||
bne LoopOutPixel
|
||||
|
||||
End:
|
||||
ret
|
||||
|
||||
#endif
|
|
@ -0,0 +1,812 @@
|
|||
#ifdef __aarch64__
|
||||
.text
|
||||
.align 5
|
||||
.global MatmulFloatNeon64Opt
|
||||
#ifndef __APPLE__
|
||||
.type MatmulFloatNeon64Opt, %function
|
||||
#endif
|
||||
|
||||
// A: LM [row_8 * depth] col_8_major
|
||||
// B: RM [depth * col_8] row_8_major
|
||||
// C: A*B [row_8 * col_8] col_8x8_major
|
||||
// A * B -> [8 * depth] * [depth * 8] -> [8 * 4] * [4 * 8] or [8 * 1] * [1 * 8]
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
//CommLoopMul RM 1x8 block
|
||||
// /-----------------------------------------\
|
||||
// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
|
||||
// \-----------------------------------------/
|
||||
// LM 8x1 block
|
||||
// /---------------------\ /-----------------------------------------\
|
||||
// | v0.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3]|
|
||||
// | ... | | ... ... |
|
||||
// | v0.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3]|
|
||||
// | v1.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3]|
|
||||
// | ... | | ... ... |
|
||||
// | v1.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3]|
|
||||
// \---------------------/ \-----------------------------------------/
|
||||
// accumulators 8x8 block
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
//OptLoopMul4 RM 4x8 block
|
||||
// /--------------------------------------------\
|
||||
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
|
||||
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|
|
||||
// |v12.s[0] ... v12.s[3] v13.s[0] ... v13.s[3]|
|
||||
// |v14.s[0] ... v14.s[3] v15.s[0] ... v15.s[3]|
|
||||
// \--------------------------------------------/
|
||||
// LM 8x4 block
|
||||
// /---------------------------------\ /--------------------------------------------\
|
||||
// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3] |
|
||||
// | ... ... ... ... | | ... ... |
|
||||
// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3] |
|
||||
// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3] |
|
||||
// | ... ... ... ... | | ... ... |
|
||||
// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3] |
|
||||
// \---------------------------------/ \--------------------------------------------/
|
||||
// accumulators 8x8 block
|
||||
/////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
|
||||
// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino)
|
||||
// x0: a
|
||||
// x1: b
|
||||
// x2: c
|
||||
// x3: bias
|
||||
// w4: act_type
|
||||
// w5: depth
|
||||
// w6: row
|
||||
// w7: col
|
||||
// w17: stride
|
||||
// w13: c8_nhwc_c4
|
||||
|
||||
MatmulFloatNeon64Opt:
|
||||
sub sp, sp, #128
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
|
||||
ldr x9, [sp, #8]
|
||||
ldr x14, [sp, #16]
|
||||
|
||||
mov w18, #32 // sizeof(float) * 8
|
||||
mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
|
||||
mov x18, #4
|
||||
ldr x17, [sp]
|
||||
cbz x14, NoWinoSteps
|
||||
mul x8, x7, x17
|
||||
mov x11, #8
|
||||
mul x11, x11, x17
|
||||
mul x8, x8, x18
|
||||
mul x11, x11, x18
|
||||
NoWinoSteps:
|
||||
mul x17, x17, x18
|
||||
|
||||
L1:
|
||||
mov w10, w6 // reload lhs row
|
||||
mov x12, x0 // reload lhs ptr
|
||||
mov x18, x2 // reload dst ptr
|
||||
|
||||
L2:
|
||||
mov x16, x1 // reload rhs ptr
|
||||
mov w13, w5 // reload depth
|
||||
dup v8.4s, wzr
|
||||
dup v9.4s, wzr
|
||||
dup v10.4s, wzr
|
||||
dup v11.4s, wzr
|
||||
dup v12.4s, wzr
|
||||
dup v13.4s, wzr
|
||||
dup v14.4s, wzr
|
||||
dup v15.4s, wzr
|
||||
dup v16.4s, wzr
|
||||
dup v17.4s, wzr
|
||||
dup v18.4s, wzr
|
||||
dup v19.4s, wzr
|
||||
dup v20.4s, wzr
|
||||
dup v21.4s, wzr
|
||||
dup v22.4s, wzr
|
||||
dup v23.4s, wzr
|
||||
dup v24.4s, wzr
|
||||
dup v25.4s, wzr
|
||||
dup v26.4s, wzr
|
||||
dup v27.4s, wzr
|
||||
dup v28.4s, wzr
|
||||
dup v29.4s, wzr
|
||||
dup v30.4s, wzr
|
||||
dup v31.4s, wzr
|
||||
|
||||
LoopStart:
|
||||
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
|
||||
ld1 {v3.4s, v4.4s}, [x16], #32
|
||||
fmla v8.4s, v3.4s, v0.s[0]
|
||||
fmla v10.4s, v3.4s, v0.s[1]
|
||||
fmla v12.4s, v3.4s, v0.s[2]
|
||||
fmla v14.4s, v3.4s, v0.s[3]
|
||||
fmla v9.4s, v4.4s, v0.s[0]
|
||||
fmla v11.4s, v4.4s, v0.s[1]
|
||||
fmla v13.4s, v4.4s, v0.s[2]
|
||||
fmla v15.4s, v4.4s, v0.s[3]
|
||||
|
||||
subs w13, w13, #1
|
||||
beq LoopEnd
|
||||
|
||||
Loop:
|
||||
ld1 {v0.4s}, [x12], #16
|
||||
fmla v16.4s, v3.4s, v1.s[0]
|
||||
fmla v18.4s, v3.4s, v1.s[1]
|
||||
fmla v20.4s, v3.4s, v1.s[2]
|
||||
fmla v22.4s, v3.4s, v1.s[3]
|
||||
fmla v17.4s, v4.4s, v1.s[0]
|
||||
fmla v19.4s, v4.4s, v1.s[1]
|
||||
fmla v21.4s, v4.4s, v1.s[2]
|
||||
fmla v23.4s, v4.4s, v1.s[3]
|
||||
ld1 {v1.4s}, [x12], #16
|
||||
fmla v24.4s, v3.4s, v2.s[0]
|
||||
fmla v26.4s, v3.4s, v2.s[1]
|
||||
fmla v28.4s, v3.4s, v2.s[2]
|
||||
fmla v30.4s, v3.4s, v2.s[3]
|
||||
ld1 {v3.4s}, [x16], #16
|
||||
fmla v25.4s, v4.4s, v2.s[0]
|
||||
fmla v27.4s, v4.4s, v2.s[1]
|
||||
fmla v29.4s, v4.4s, v2.s[2]
|
||||
fmla v31.4s, v4.4s, v2.s[3]
|
||||
ld1 {v4.4s}, [x16], #16
|
||||
fmla v8.4s, v3.4s, v0.s[0]
|
||||
fmla v10.4s, v3.4s, v0.s[1]
|
||||
fmla v12.4s, v3.4s, v0.s[2]
|
||||
fmla v14.4s, v3.4s, v0.s[3]
|
||||
ld1 {v2.4s}, [x12], #16
|
||||
fmla v9.4s, v4.4s, v0.s[0]
|
||||
fmla v11.4s, v4.4s, v0.s[1]
|
||||
fmla v13.4s, v4.4s, v0.s[2]
|
||||
fmla v15.4s, v4.4s, v0.s[3]
|
||||
|
||||
subs w13, w13, #1
|
||||
bgt Loop
|
||||
|
||||
LoopEnd:
|
||||
fmla v16.4s, v3.4s, v1.s[0]
|
||||
fmla v18.4s, v3.4s, v1.s[1]
|
||||
fmla v20.4s, v3.4s, v1.s[2]
|
||||
fmla v22.4s, v3.4s, v1.s[3]
|
||||
fmla v17.4s, v4.4s, v1.s[0]
|
||||
fmla v19.4s, v4.4s, v1.s[1]
|
||||
fmla v21.4s, v4.4s, v1.s[2]
|
||||
fmla v23.4s, v4.4s, v1.s[3]
|
||||
fmla v24.4s, v3.4s, v2.s[0]
|
||||
fmla v26.4s, v3.4s, v2.s[1]
|
||||
fmla v28.4s, v3.4s, v2.s[2]
|
||||
fmla v30.4s, v3.4s, v2.s[3]
|
||||
fmla v25.4s, v4.4s, v2.s[0]
|
||||
fmla v27.4s, v4.4s, v2.s[1]
|
||||
fmla v29.4s, v4.4s, v2.s[2]
|
||||
fmla v31.4s, v4.4s, v2.s[3]
|
||||
|
||||
Bias:
|
||||
cbz x3, Activation
|
||||
ld1 {v0.4s}, [x3], #16
|
||||
ld1 {v1.4s}, [x3]
|
||||
sub x3, x3, #16
|
||||
fadd v8.4s, v8.4s, v0.4s
|
||||
fadd v9.4s, v9.4s, v1.4s
|
||||
fadd v10.4s, v10.4s, v0.4s
|
||||
fadd v11.4s, v11.4s, v1.4s
|
||||
fadd v12.4s, v12.4s, v0.4s
|
||||
fadd v13.4s, v13.4s, v1.4s
|
||||
fadd v14.4s, v14.4s, v0.4s
|
||||
fadd v15.4s, v15.4s, v1.4s
|
||||
fadd v16.4s, v16.4s, v0.4s
|
||||
fadd v17.4s, v17.4s, v1.4s
|
||||
fadd v18.4s, v18.4s, v0.4s
|
||||
fadd v19.4s, v19.4s, v1.4s
|
||||
fadd v20.4s, v20.4s, v0.4s
|
||||
fadd v21.4s, v21.4s, v1.4s
|
||||
fadd v22.4s, v22.4s, v0.4s
|
||||
fadd v23.4s, v23.4s, v1.4s
|
||||
fadd v24.4s, v24.4s, v0.4s
|
||||
fadd v25.4s, v25.4s, v1.4s
|
||||
fadd v26.4s, v26.4s, v0.4s
|
||||
fadd v27.4s, v27.4s, v1.4s
|
||||
fadd v28.4s, v28.4s, v0.4s
|
||||
fadd v29.4s, v29.4s, v1.4s
|
||||
fadd v30.4s, v30.4s, v0.4s
|
||||
fadd v31.4s, v31.4s, v1.4s
|
||||
|
||||
Activation:
|
||||
cmp w4, #2
|
||||
beq Relu6
|
||||
cmp w4, #1
|
||||
beq Relu
|
||||
b Write
|
||||
|
||||
Relu6:
|
||||
mov w13, #6
|
||||
dup v2.4s, w13
|
||||
scvtf v2.4s, v2.4s
|
||||
fmin v8.4s, v8.4s, v2.4s
|
||||
fmin v9.4s, v9.4s, v2.4s
|
||||
fmin v10.4s, v10.4s, v2.4s
|
||||
fmin v11.4s, v11.4s, v2.4s
|
||||
fmin v12.4s, v12.4s, v2.4s
|
||||
fmin v13.4s, v13.4s, v2.4s
|
||||
fmin v14.4s, v14.4s, v2.4s
|
||||
fmin v15.4s, v15.4s, v2.4s
|
||||
fmin v16.4s, v16.4s, v2.4s
|
||||
fmin v17.4s, v17.4s, v2.4s
|
||||
fmin v18.4s, v18.4s, v2.4s
|
||||
fmin v19.4s, v19.4s, v2.4s
|
||||
fmin v20.4s, v20.4s, v2.4s
|
||||
fmin v21.4s, v21.4s, v2.4s
|
||||
fmin v22.4s, v22.4s, v2.4s
|
||||
fmin v23.4s, v23.4s, v2.4s
|
||||
fmin v24.4s, v24.4s, v2.4s
|
||||
fmin v25.4s, v25.4s, v2.4s
|
||||
fmin v26.4s, v26.4s, v2.4s
|
||||
fmin v27.4s, v27.4s, v2.4s
|
||||
fmin v28.4s, v28.4s, v2.4s
|
||||
fmin v29.4s, v29.4s, v2.4s
|
||||
fmin v30.4s, v30.4s, v2.4s
|
||||
fmin v31.4s, v31.4s, v2.4s
|
||||
|
||||
Relu:
|
||||
dup v3.4s, wzr
|
||||
fmax v8.4s, v8.4s, v3.4s
|
||||
fmax v9.4s, v9.4s, v3.4s
|
||||
fmax v10.4s, v10.4s, v3.4s
|
||||
fmax v11.4s, v11.4s, v3.4s
|
||||
fmax v12.4s, v12.4s, v3.4s
|
||||
fmax v13.4s, v13.4s, v3.4s
|
||||
fmax v14.4s, v14.4s, v3.4s
|
||||
fmax v15.4s, v15.4s, v3.4s
|
||||
fmax v16.4s, v16.4s, v3.4s
|
||||
fmax v17.4s, v17.4s, v3.4s
|
||||
fmax v18.4s, v18.4s, v3.4s
|
||||
fmax v19.4s, v19.4s, v3.4s
|
||||
fmax v20.4s, v20.4s, v3.4s
|
||||
fmax v21.4s, v21.4s, v3.4s
|
||||
fmax v22.4s, v22.4s, v3.4s
|
||||
fmax v23.4s, v23.4s, v3.4s
|
||||
fmax v24.4s, v24.4s, v3.4s
|
||||
fmax v25.4s, v25.4s, v3.4s
|
||||
fmax v26.4s, v26.4s, v3.4s
|
||||
fmax v27.4s, v27.4s, v3.4s
|
||||
fmax v28.4s, v28.4s, v3.4s
|
||||
fmax v29.4s, v29.4s, v3.4s
|
||||
fmax v30.4s, v30.4s, v3.4s
|
||||
fmax v31.4s, v31.4s, v3.4s
|
||||
|
||||
Write:
|
||||
cbnz x14, WriteWino
|
||||
cbz x9, WriteC8
|
||||
cmp w7, #1
|
||||
beq Write1
|
||||
cmp w7, #2
|
||||
beq Write2
|
||||
cmp w7, #3
|
||||
beq Write3
|
||||
cmp w7, #4
|
||||
beq Write4
|
||||
cmp w7, #5
|
||||
beq Write5
|
||||
cmp w7, #6
|
||||
beq Write6
|
||||
cmp w7, #7
|
||||
beq Write7
|
||||
b Write8
|
||||
|
||||
Write1:
|
||||
str s8, [x18]
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s10, [x18]
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s12, [x18]
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s14, [x18]
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s16, [x18]
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s18, [x18]
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s20, [x18]
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s22, [x18]
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s24, [x18]
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s26, [x18]
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s28, [x18]
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
str s30, [x18]
|
||||
add x18, x18, x17
|
||||
b WriteEnd
|
||||
Write2:
|
||||
dup s9, v8.s[1]
|
||||
stp s8, s9, [x18]
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s11, v10.s[1]
|
||||
stp s10, s11, [x18]
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s13, v12.s[1]
|
||||
stp s12, s13, [x18]
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s15, v14.s[1]
|
||||
stp s14, s15, [x18]
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s17, v16.s[1]
|
||||
stp s16, s17, [x18]
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s19, v18.s[1]
|
||||
stp s18, s19, [x18]
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s21, v20.s[1]
|
||||
stp s20, s21, [x18]
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s23, v22.s[1]
|
||||
stp s22, s23, [x18]
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s25, v24.s[1]
|
||||
stp s24, s25, [x18]
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s27, v26.s[1]
|
||||
stp s26, s27, [x18]
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s29, v28.s[1]
|
||||
stp s28, s29, [x18]
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
add x18, x18, x17
|
||||
dup s31, v30.s[1]
|
||||
stp s30, s31, [x18]
|
||||
add x18, x18, x17
|
||||
b WriteEnd
|
||||
Write3:
|
||||
add x13, x18, #8
|
||||
dup s9, v8.s[1]
|
||||
stp s8, s9, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v8.s}[2], [x13], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
dup s11, v10.s[1]
|
||||
stp s10, s11, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v10.s}[2], [x13], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
dup s13, v12.s[1]
|
||||
stp s12, s13, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v12.s}[2], [x13], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
dup s15, v14.s[1]
|
||||
stp s14, s15, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v14.s}[2], [x13], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
dup s17, v16.s[1]
|
||||
stp s16, s17, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v16.s}[2], [x13], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
dup s19, v18.s[1]
|
||||
stp s18, s19, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v18.s}[2], [x13], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
dup s21, v20.s[1]
|
||||
stp s20, s21, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v20.s}[2], [x13], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
dup s23, v22.s[1]
|
||||
stp s22, s23, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v22.s}[2], [x13], x17
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
dup s25, v24.s[1]
|
||||
stp s24, s25, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v24.s}[2], [x13], x17
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
dup s27, v26.s[1]
|
||||
stp s26, s27, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v26.s}[2], [x13], x17
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
dup s29, v28.s[1]
|
||||
stp s28, s29, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v28.s}[2], [x13], x17
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
dup s31, v30.s[1]
|
||||
stp s30, s31, [x18]
|
||||
add x18, x18, x17
|
||||
st1 {v30.s}[2], [x13]
|
||||
b WriteEnd
|
||||
Write4:
|
||||
st1 {v8.4s}, [x18], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
st1 {v10.4s}, [x18], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
st1 {v12.4s}, [x18], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
st1 {v14.4s}, [x18], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
st1 {v16.4s}, [x18], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
st1 {v18.4s}, [x18], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
st1 {v20.4s}, [x18], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
st1 {v22.4s}, [x18], x17
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4s}, [x18], x17
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
st1 {v26.4s}, [x18], x17
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
st1 {v28.4s}, [x18], x17
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
st1 {v30.4s}, [x18], x17
|
||||
b WriteEnd
|
||||
Write5:
|
||||
add x13, x18, #16
|
||||
st1 {v8.4s}, [x18], x17
|
||||
str s9, [x13]
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v10.4s}, [x18], x17
|
||||
str s11, [x13]
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v12.4s}, [x18], x17
|
||||
str s13, [x13]
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v14.4s}, [x18], x17
|
||||
str s15, [x13]
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v16.4s}, [x18], x17
|
||||
str s17, [x13]
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v18.4s}, [x18], x17
|
||||
str s19, [x13]
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v20.4s}, [x18], x17
|
||||
str s21, [x13]
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v22.4s}, [x18], x17
|
||||
str s23, [x13]
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v24.4s}, [x18], x17
|
||||
str s25, [x13]
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v26.4s}, [x18], x17
|
||||
str s27, [x13]
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v28.4s}, [x18], x17
|
||||
str s29, [x13]
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v30.4s}, [x18], x17
|
||||
str s31, [x13]
|
||||
b WriteEnd
|
||||
Write6:
|
||||
add x13, x18, #16
|
||||
st1 {v8.4s}, [x18], x17
|
||||
dup s8, v9.s[1]
|
||||
stp s9, s8, [x13]
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v10.4s}, [x18], x17
|
||||
dup s10, v11.s[1]
|
||||
stp s11, s10, [x13]
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v12.4s}, [x18], x17
|
||||
dup s12, v13.s[1]
|
||||
stp s13, s12, [x13]
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v14.4s}, [x18], x17
|
||||
dup s14, v15.s[1]
|
||||
stp s15, s14, [x13]
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v16.4s}, [x18], x17
|
||||
dup s16, v17.s[1]
|
||||
stp s17, s16, [x13]
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v18.4s}, [x18], x17
|
||||
dup s18, v19.s[1]
|
||||
stp s19, s18, [x13]
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v20.4s}, [x18], x17
|
||||
dup s20, v21.s[1]
|
||||
stp s21, s20, [x13]
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v22.4s}, [x18], x17
|
||||
dup s22, v23.s[1]
|
||||
stp s23, s22, [x13]
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v24.4s}, [x18], x17
|
||||
dup s24, v25.s[1]
|
||||
stp s25, s24, [x13]
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v26.4s}, [x18], x17
|
||||
dup s26, v27.s[1]
|
||||
stp s27, s26, [x13]
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v28.4s}, [x18], x17
|
||||
dup s28, v29.s[1]
|
||||
stp s29, s28, [x13]
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
add x13, x13, x17
|
||||
st1 {v30.4s}, [x18], x17
|
||||
dup s30, v31.s[1]
|
||||
stp s31, s30, [x13]
|
||||
b WriteEnd
|
||||
Write7:
|
||||
add x13, x18, #16
|
||||
add x16, x18, #24
|
||||
st1 {v8.4s}, [x18], x17
|
||||
dup s8, v9.s[1]
|
||||
stp s9, s8, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v9.s}[2], [x16], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
st1 {v10.4s}, [x18], x17
|
||||
dup s10, v11.s[1]
|
||||
stp s11, s10, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v11.s}[2], [x16], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
st1 {v12.4s}, [x18], x17
|
||||
dup s12, v13.s[1]
|
||||
stp s13, s12, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v13.s}[2], [x16], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
st1 {v14.4s}, [x18], x17
|
||||
dup s14, v15.s[1]
|
||||
stp s15, s14, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v15.s}[2], [x16], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
st1 {v16.4s}, [x18], x17
|
||||
dup s16, v17.s[1]
|
||||
stp s17, s16, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v17.s}[2], [x16], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
st1 {v18.4s}, [x18], x17
|
||||
dup s18, v19.s[1]
|
||||
stp s19, s18, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v19.s}[2], [x16], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
st1 {v20.4s}, [x18], x17
|
||||
dup s20, v21.s[1]
|
||||
stp s21, s20, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v21.s}[2], [x16], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
st1 {v22.4s}, [x18], x17
|
||||
dup s22, v23.s[1]
|
||||
stp s23, s22, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v23.s}[2], [x16], x17
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4s}, [x18], x17
|
||||
dup s24, v25.s[1]
|
||||
stp s25, s24, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v25.s}[2], [x16], x17
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
st1 {v26.4s}, [x18], x17
|
||||
dup s26, v27.s[1]
|
||||
stp s27, s26, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v27.s}[2], [x16], x17
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
st1 {v28.4s}, [x18], x17
|
||||
dup s28, v29.s[1]
|
||||
stp s29, s28, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v29.s}[2], [x16], x17
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
st1 {v30.4s}, [x18], x17
|
||||
dup s30, v31.s[1]
|
||||
stp s31, s30, [x13]
|
||||
add x13, x13, x17
|
||||
st1 {v31.s}[2], [x16], x17
|
||||
b WriteEnd
|
||||
WriteC8:
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
|
||||
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
|
||||
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64
|
||||
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64
|
||||
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
|
||||
b WriteEnd
|
||||
WriteWino:
|
||||
st1 {v8.4s, v9.4s}, [x18], x8
|
||||
st1 {v10.4s, v11.4s}, [x18], x8
|
||||
st1 {v12.4s, v13.4s}, [x18], x8
|
||||
st1 {v14.4s, v15.4s}, [x18], x8
|
||||
st1 {v16.4s, v17.4s}, [x18], x8
|
||||
st1 {v18.4s, v19.4s}, [x18], x8
|
||||
st1 {v20.4s, v21.4s}, [x18], x8
|
||||
st1 {v22.4s, v23.4s}, [x18], x8
|
||||
st1 {v24.4s, v25.4s}, [x18], x8
|
||||
st1 {v26.4s, v27.4s}, [x18], x8
|
||||
st1 {v28.4s, v29.4s}, [x18], x8
|
||||
st1 {v30.4s, v31.4s}, [x18], x8
|
||||
b WriteEnd
|
||||
Write8:
|
||||
st1 {v8.4s, v9.4s}, [x18], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
st1 {v10.4s, v11.4s}, [x18], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
st1 {v12.4s, v13.4s}, [x18], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
st1 {v14.4s, v15.4s}, [x18], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
st1 {v16.4s, v17.4s}, [x18], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
st1 {v18.4s, v19.4s}, [x18], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
st1 {v20.4s, v21.4s}, [x18], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
st1 {v22.4s, v23.4s}, [x18], x17
|
||||
cmp w10, #8
|
||||
beq WriteEnd
|
||||
st1 {v24.4s, v25.4s}, [x18], x17
|
||||
cmp w10, #9
|
||||
beq WriteEnd
|
||||
st1 {v26.4s, v27.4s}, [x18], x17
|
||||
cmp w10, #10
|
||||
beq WriteEnd
|
||||
st1 {v28.4s, v29.4s}, [x18], x17
|
||||
cmp w10, #11
|
||||
beq WriteEnd
|
||||
st1 {v30.4s, v31.4s}, [x18], x17
|
||||
|
||||
WriteEnd:
|
||||
subs w10, w10, #12 // lhs row - 12
|
||||
bgt L2
|
||||
|
||||
End2:
|
||||
subs w7, w7, #8 // rhs col - 8
|
||||
add x1, x1, x15 // rhs ptr + stride
|
||||
cbz x3, NoBiasStep
|
||||
add x3, x3, #32 // bias ptr + stride
|
||||
NoBiasStep:
|
||||
cbnz x14, WinoDstStep
|
||||
cbz x9, NoDstStep
|
||||
add x2, x2, #32 // dst ptr + stride
|
||||
b NoDstStep
|
||||
WinoDstStep:
|
||||
add x2, x2, x11
|
||||
NoDstStep:
|
||||
bgt L1
|
||||
|
||||
End1:
|
||||
sub sp, sp, #128
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ret
|
||||
#endif
|
|
@ -0,0 +1,144 @@
|
|||
#ifdef __aarch64__
|
||||
.text
|
||||
.align 5
|
||||
.global MatmulFloatNeon64OptRemain
|
||||
#ifndef __APPLE__
|
||||
.type MatmulFloatNeon64OptRemain, %function
|
||||
#endif
|
||||
|
||||
// void MatmulFloatNeon64(const float *a, const float *b, float *c, int depth
|
||||
// int row, int col, size_t stride)
|
||||
// x0: a
|
||||
// x1: b
|
||||
// x2: c
|
||||
// x3: depth
|
||||
// x4: row
|
||||
// x5: col
|
||||
// x6: stride
|
||||
// only for winograd
|
||||
MatmulFloatNeon64OptRemain:
|
||||
mov x18, #32 // sizeof(float) * 8
|
||||
mul x9, x3, x18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
|
||||
mov x18, #4
|
||||
mul x8, x5, x6
|
||||
mov x11, #8
|
||||
mul x11, x11, x6
|
||||
mul x8, x8, x18
|
||||
mul x11, x11, x18
|
||||
|
||||
cmp x4, #4
|
||||
ble LoopH4
|
||||
|
||||
LoopH8:
|
||||
mov x10, x4 // reload lhs row
|
||||
mov x12, x0 // reload lhs ptr
|
||||
mov x18, x2 // reload dst ptr
|
||||
|
||||
LoopW8:
|
||||
mov x16, x1 // reload rhs ptr
|
||||
mov x13, x3 // reload depth
|
||||
dup v16.4s, wzr
|
||||
dup v17.4s, wzr
|
||||
dup v18.4s, wzr
|
||||
dup v19.4s, wzr
|
||||
dup v20.4s, wzr
|
||||
dup v21.4s, wzr
|
||||
dup v22.4s, wzr
|
||||
dup v23.4s, wzr
|
||||
dup v24.4s, wzr
|
||||
dup v25.4s, wzr
|
||||
dup v26.4s, wzr
|
||||
dup v27.4s, wzr
|
||||
dup v28.4s, wzr
|
||||
dup v29.4s, wzr
|
||||
dup v30.4s, wzr
|
||||
dup v31.4s, wzr
|
||||
|
||||
LoopD8:
|
||||
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
|
||||
ld1 {v3.4s, v4.4s}, [x16], #32
|
||||
fmla v16.4s, v3.4s, v0.s[0]
|
||||
fmla v18.4s, v3.4s, v0.s[1]
|
||||
fmla v20.4s, v3.4s, v0.s[2]
|
||||
fmla v22.4s, v3.4s, v0.s[3]
|
||||
fmla v17.4s, v4.4s, v0.s[0]
|
||||
fmla v19.4s, v4.4s, v0.s[1]
|
||||
fmla v21.4s, v4.4s, v0.s[2]
|
||||
fmla v23.4s, v4.4s, v0.s[3]
|
||||
fmla v24.4s, v3.4s, v1.s[0]
|
||||
fmla v26.4s, v3.4s, v1.s[1]
|
||||
fmla v28.4s, v3.4s, v1.s[2]
|
||||
fmla v30.4s, v3.4s, v1.s[3]
|
||||
fmla v25.4s, v4.4s, v1.s[0]
|
||||
fmla v27.4s, v4.4s, v1.s[1]
|
||||
fmla v29.4s, v4.4s, v1.s[2]
|
||||
fmla v31.4s, v4.4s, v1.s[3]
|
||||
|
||||
subs w13, w13, #1
|
||||
bgt LoopD8
|
||||
|
||||
st1 {v16.4s, v17.4s}, [x18], x8
|
||||
st1 {v18.4s, v19.4s}, [x18], x8
|
||||
st1 {v20.4s, v21.4s}, [x18], x8
|
||||
st1 {v22.4s, v23.4s}, [x18], x8
|
||||
st1 {v24.4s, v25.4s}, [x18], x8
|
||||
st1 {v26.4s, v27.4s}, [x18], x8
|
||||
st1 {v28.4s, v29.4s}, [x18], x8
|
||||
st1 {v30.4s, v31.4s}, [x18], x8
|
||||
|
||||
subs x10, x10, #8 // lhs row - 8
|
||||
bgt LoopW8
|
||||
|
||||
subs x5, x5, #8 // rhs col - 8
|
||||
add x1, x1, x9 // rhs ptr + stride
|
||||
add x2, x2, x11
|
||||
bgt LoopH8
|
||||
|
||||
ret
|
||||
|
||||
LoopH4:
|
||||
mov x10, x4 // reload lhs row
|
||||
mov x12, x0 // reload lhs ptr
|
||||
mov x18, x2 // reload dst ptr
|
||||
|
||||
LoopW4:
|
||||
mov x16, x1 // reload rhs ptr
|
||||
mov x13, x3 // reload depth
|
||||
dup v16.4s, wzr
|
||||
dup v17.4s, wzr
|
||||
dup v18.4s, wzr
|
||||
dup v19.4s, wzr
|
||||
dup v20.4s, wzr
|
||||
dup v21.4s, wzr
|
||||
dup v22.4s, wzr
|
||||
dup v23.4s, wzr
|
||||
|
||||
LoopD4:
|
||||
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
|
||||
ld1 {v3.4s, v4.4s}, [x16], #32
|
||||
fmla v16.4s, v3.4s, v0.s[0]
|
||||
fmla v18.4s, v3.4s, v0.s[1]
|
||||
fmla v20.4s, v3.4s, v0.s[2]
|
||||
fmla v22.4s, v3.4s, v0.s[3]
|
||||
fmla v17.4s, v4.4s, v0.s[0]
|
||||
fmla v19.4s, v4.4s, v0.s[1]
|
||||
fmla v21.4s, v4.4s, v0.s[2]
|
||||
fmla v23.4s, v4.4s, v0.s[3]
|
||||
|
||||
subs x13, x13, #1
|
||||
bgt LoopD4
|
||||
|
||||
st1 {v16.4s, v17.4s}, [x18], x8
|
||||
st1 {v18.4s, v19.4s}, [x18], x8
|
||||
st1 {v20.4s, v21.4s}, [x18], x8
|
||||
st1 {v22.4s, v23.4s}, [x18], x8
|
||||
|
||||
subs x10, x10, #4 // lhs row - 4
|
||||
bgt LoopW4
|
||||
|
||||
subs x5, x5, #8 // rhs col - 8
|
||||
add x1, x1, x9 // rhs ptr + stride
|
||||
add x2, x2, x11
|
||||
bgt LoopH4
|
||||
ret
|
||||
#endif
|
|
@ -24,7 +24,7 @@
|
|||
|
||||
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
|
||||
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
|
||||
// int multiplier, int left_shift, int right_shift);
|
||||
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
|
||||
|
||||
// x0: a(left matrix ptr)
|
||||
// x1: b(right matrix ptr)
|
||||
|
@ -40,13 +40,18 @@
|
|||
// w11: multiplier
|
||||
// w12: left_shift
|
||||
// w13: right_shift
|
||||
// w14: row
|
||||
// w15: col
|
||||
// w24: stride
|
||||
|
||||
MatmulInt8Neon64:
|
||||
sub sp, sp, #160
|
||||
sub sp, sp, #192
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
stp x19, x20, [sp], #16
|
||||
stp x21, x22, [sp], #16
|
||||
stp x23, x24, [sp], #16
|
||||
stp x25, x26, [sp], #16
|
||||
|
||||
ldr w8, [sp]
|
||||
ldr w9, [sp, #8]
|
||||
|
@ -54,25 +59,28 @@ MatmulInt8Neon64:
|
|||
ldr w11, [sp, #24]
|
||||
ldr w12, [sp, #32]
|
||||
ldr w13, [sp, #40]
|
||||
ldr w14, [sp, #48]
|
||||
ldr w15, [sp, #56]
|
||||
ldr w24, [sp, #64]
|
||||
|
||||
mov w15, #0 // b col index
|
||||
mov w16, #0 // a row index
|
||||
mov w17, #4 // sizeof(int8)*4
|
||||
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
|
||||
|
||||
mov w17, #1
|
||||
mov x25, x2
|
||||
L1:
|
||||
cmp w15, w4
|
||||
cmp w4, #0 // if at the end of col4
|
||||
beq End1
|
||||
|
||||
mov w16, #0 // reset a row index
|
||||
mov w16, w3 // reset a row4 counter
|
||||
mov w23, w14 // reset a row counter
|
||||
mov x17, x0 // reload a ptr
|
||||
mov x22, x6 // reload a_sums ptr
|
||||
L2:
|
||||
cmp w16, w3
|
||||
cmp w16, #0
|
||||
beq End2
|
||||
|
||||
mov x18, x1 // reload b ptr
|
||||
mov x19, x7 // reload bias ptr
|
||||
mov x19, x7 // reload bias ptr
|
||||
mov w20, w5 // reload depth
|
||||
dup v16.4s, wzr
|
||||
dup v17.4s, wzr
|
||||
|
@ -256,21 +264,128 @@ End3:
|
|||
sqxtn v15.8b, v13.8h
|
||||
sqxtn2 v15.16b, v14.8h
|
||||
|
||||
st1 {v15.16b}, [x2], #16
|
||||
add w16, w16, #4 // a row index + 4
|
||||
cmp w23, #4
|
||||
blt Write // if rows < 4
|
||||
cmp w15, #4
|
||||
blt Write // if cols < 4
|
||||
|
||||
st1 {v15.s}[0], [x2], x24
|
||||
st1 {v15.s}[1], [x2], x24
|
||||
st1 {v15.s}[2], [x2], x24
|
||||
st1 {v15.s}[3], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
Write:
|
||||
cmp w15, #4
|
||||
beq WriteCol4
|
||||
cmp w15, #3
|
||||
beq WriteCol3
|
||||
cmp w15, #2
|
||||
beq WriteCol2
|
||||
cmp w15, #1
|
||||
beq WriteCol1
|
||||
|
||||
WriteCol4:
|
||||
st1 {v15.s}[0], [x2], x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
st1 {v15.s}[1], [x2], x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
st1 {v15.s}[2], [x2], x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
st1 {v15.s}[3], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol3:
|
||||
mov x26, x2
|
||||
st1 {v15.b}[0], [x26], #1
|
||||
st1 {v15.b}[1], [x26], #1
|
||||
st1 {v15.b}[2], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v15.b}[4], [x26], #1
|
||||
st1 {v15.b}[5], [x26], #1
|
||||
st1 {v15.b}[6], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v15.b}[8], [x26], #1
|
||||
st1 {v15.b}[9], [x26], #1
|
||||
st1 {v15.b}[10], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v15.b}[12], [x26], #1
|
||||
st1 {v15.b}[13], [x26], #1
|
||||
st1 {v15.b}[14], [x26], #1
|
||||
add x2, x2, x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol2:
|
||||
mov x26, x2
|
||||
st1 {v15.b}[0], [x26], #1
|
||||
st1 {v15.b}[1], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v15.b}[4], [x26], #1
|
||||
st1 {v15.b}[5], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v15.b}[8], [x26], #1
|
||||
st1 {v15.b}[9], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v15.b}[12], [x26], #1
|
||||
st1 {v15.b}[13], [x26], #1
|
||||
add x2, x2, x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol1:
|
||||
st1 {v15.b}[0], [x2], x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
st1 {v15.b}[4], [x2], x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
st1 {v15.b}[8], [x2], x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
st1 {v15.b}[12], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
Endwrite:
|
||||
sub w16, w16, #4 // a row4 counter - 4
|
||||
sub w23, w23, #4 // a row counter - 4
|
||||
b L2
|
||||
|
||||
End2:
|
||||
add w15, w15, #4 // b col index + 4
|
||||
sub w4, w4, #4 // b col4 counter - 4
|
||||
sub w15, w15, #4 // b col counter - 4
|
||||
add x1, x1, x21 // b ptr + stride
|
||||
add x7, x7, #16 // bias ptr + stride
|
||||
add x25, x25, #4 // output + stride(4 * sizeof(int8))
|
||||
mov x2, x25
|
||||
b L1
|
||||
|
||||
End1:
|
||||
sub sp, sp, #160
|
||||
sub sp, sp, #192
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ldp x19, x20, [sp], #16
|
||||
ldp x21, x22, [sp], #16
|
||||
ldp x23, x24, [sp], #16
|
||||
ldp x25, x26, [sp], #16
|
||||
ret
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
#ifdef __aarch64__
|
||||
|
||||
.text
|
||||
.align 5
|
||||
.global ConvDwFp16Row
|
||||
#ifndef __APPLE__
|
||||
.type ConvDwFp16Row, %function
|
||||
#endif
|
||||
|
||||
// void ConvDwFp16Row(float16_t* output_ptr, const float16_t* input_ptr,const float16_t* filter_ptr,
|
||||
// size_t num_pixels, size_t input_channel, size_t input_step)
|
||||
// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels,
|
||||
// x4: input_channel, x5: input_step
|
||||
//
|
||||
ConvDwFp16Row:
|
||||
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
|
||||
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
|
||||
// x19 ~ x29 should be also preserved
|
||||
// whereas our coding style do not permit such amount of parameters
|
||||
cmp x3, #0
|
||||
beq End
|
||||
|
||||
mov x9, x0
|
||||
mov x12, #2 // sizeof(float16_t)
|
||||
mul x5, x5, x12
|
||||
|
||||
LoopOutPixel:
|
||||
mov x6, x1
|
||||
mov x7, x2
|
||||
mov x8, x4
|
||||
|
||||
LoopInputDepth32In:
|
||||
cmp x8, #32
|
||||
blt Loop8
|
||||
sub x8, x8, #32
|
||||
|
||||
ld1 {v0.8h, v1.8h}, [x6], #32
|
||||
ld1 {v2.8h, v3.8h}, [x7], #32
|
||||
ld1 {v16.8h, v17.8h}, [x0], #32
|
||||
|
||||
cmp x8, #32
|
||||
blt LoopInputDepth32Out
|
||||
LoopInputDepth32:
|
||||
fmla v16.8h, v0.8h, v2.8h
|
||||
fmla v17.8h, v1.8h, v3.8h
|
||||
|
||||
st1 {v16.8h, v17.8h}, [x9], #32
|
||||
|
||||
ld1 {v4.8h, v5.8h}, [x6], #32
|
||||
ld1 {v6.8h, v7.8h}, [x7], #32
|
||||
ld1 {v18.8h, v19.8h}, [x0], #32
|
||||
|
||||
fmla v18.8h, v4.8h, v6.8h
|
||||
fmla v19.8h, v5.8h, v7.8h
|
||||
|
||||
st1 {v18.8h, v19.8h}, [x9], #32
|
||||
|
||||
ld1 {v0.8h, v1.8h}, [x6], #32
|
||||
ld1 {v2.8h, v3.8h}, [x7], #32
|
||||
ld1 {v16.8h, v17.8h}, [x0], #32
|
||||
|
||||
sub x8, x8, #32
|
||||
cmp x8, #32
|
||||
bge LoopInputDepth32
|
||||
|
||||
LoopInputDepth32Out:
|
||||
fmla v16.8h, v0.8h, v2.8h
|
||||
fmla v17.8h, v1.8h, v3.8h
|
||||
st1 {v16.8h, v17.8h}, [x9], #32
|
||||
|
||||
ld1 {v4.8h, v5.8h}, [x6], #32
|
||||
ld1 {v6.8h, v7.8h}, [x7], #32
|
||||
ld1 {v18.8h, v19.8h}, [x0], #32
|
||||
|
||||
fmla v18.8h, v4.8h, v6.8h
|
||||
fmla v19.8h, v5.8h, v7.8h
|
||||
|
||||
st1 {v18.8h, v19.8h}, [x9], #32
|
||||
|
||||
Loop8:
|
||||
cmp x8, #8
|
||||
blt L0
|
||||
|
||||
LoopInputDepth8:
|
||||
ld1 {v0.8h}, [x6], #16
|
||||
ld1 {v2.8h}, [x7], #16
|
||||
ld1 {v16.8h}, [x0], #16
|
||||
fmla v16.8h, v0.8h, v2.8h
|
||||
st1 {v16.8h}, [x9], #16
|
||||
sub x8, x8, #8
|
||||
cmp x8, #8
|
||||
bge LoopInputDepth8
|
||||
|
||||
L0:
|
||||
cmp x8, #0
|
||||
beq Loop8LineEnd
|
||||
|
||||
LoopInputDepth0:
|
||||
ldr h0, [x6], #2
|
||||
ldr h1, [x7], #2
|
||||
ldr h2, [x0], #2
|
||||
fmul h0, h0, h1
|
||||
fadd h2, h2, h0
|
||||
str h2, [x9], #2
|
||||
subs x8, x8, #1
|
||||
bne LoopInputDepth0
|
||||
|
||||
Loop8LineEnd:
|
||||
|
||||
subs x3, x3, #1
|
||||
add x1, x1, x5
|
||||
bne LoopOutPixel
|
||||
|
||||
End:
|
||||
ret
|
||||
|
||||
#endif
|
|
@ -36,7 +36,7 @@ IndirectGemmInt8_24x4_dp:
|
|||
ld1 {v17.4s}, [x22], x23
|
||||
ld1 {v18.4s}, [x22], x23
|
||||
ld1 {v19.4s}, [x22], x23
|
||||
ld1{v20.4s}, [x22], x23
|
||||
ld1 {v20.4s}, [x22], x23
|
||||
ld1 {v21.4s}, [x22], x23
|
||||
ld1 {v22.4s}, [x22], x23
|
||||
ld1 {v23.4s}, [x22], x23
|
||||
|
@ -404,7 +404,7 @@ IndirectGemmInt8_24x4_dp:
|
|||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v8.4s, v8.4s, v0.4s
|
||||
srshl v8.4s, v8.4s, v4.4s
|
||||
and v0.16b, v4.16b, v9.16b
|
||||
and v1.16b, v4.16b, v9.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v9.4s, v9.4s, v1.4s
|
||||
srshl v9.4s, v9.4s, v4.4s
|
||||
|
@ -420,7 +420,7 @@ IndirectGemmInt8_24x4_dp:
|
|||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v12.4s, v12.4s, v0.4s
|
||||
srshl v12.4s, v12.4s, v4.4s
|
||||
and v0.16b, v4.16b, v13.16b
|
||||
and v1.16b, v4.16b, v13.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v13.4s, v13.4s, v1.4s
|
||||
srshl v13.4s, v13.4s, v4.4s
|
||||
|
@ -436,7 +436,7 @@ IndirectGemmInt8_24x4_dp:
|
|||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v16.4s, v16.4s, v0.4s
|
||||
srshl v16.4s, v16.4s, v4.4s
|
||||
and v0.16b, v4.16b, v17.16b
|
||||
and v1.16b, v4.16b, v17.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v17.4s, v17.4s, v1.4s
|
||||
srshl v17.4s, v17.4s, v4.4s
|
||||
|
@ -452,7 +452,7 @@ IndirectGemmInt8_24x4_dp:
|
|||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v20.4s, v20.4s, v0.4s
|
||||
srshl v20.4s, v20.4s, v4.4s
|
||||
and v0.16b, v4.16b, v21.16b
|
||||
and v1.16b, v4.16b, v21.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v21.4s, v21.4s, v1.4s
|
||||
srshl v21.4s, v21.4s, v4.4s
|
||||
|
@ -468,7 +468,7 @@ IndirectGemmInt8_24x4_dp:
|
|||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v24.4s, v24.4s, v0.4s
|
||||
srshl v24.4s, v24.4s, v4.4s
|
||||
and v0.16b, v4.16b, v25.16b
|
||||
and v1.16b, v4.16b, v25.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v25.4s, v25.4s, v1.4s
|
||||
srshl v25.4s, v25.4s, v4.4s
|
||||
|
@ -484,7 +484,7 @@ IndirectGemmInt8_24x4_dp:
|
|||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v28.4s, v28.4s, v0.4s
|
||||
srshl v28.4s, v28.4s, v4.4s
|
||||
and v0.16b, v4.16b, v29.16b
|
||||
and v1.16b, v4.16b, v29.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v29.4s, v29.4s, v1.4s
|
||||
srshl v29.4s, v29.4s, v4.4s
|
||||
|
|
|
@ -0,0 +1,820 @@
|
|||
#ifdef __aarch64__
|
||||
.text
|
||||
.align 5
|
||||
.global MatmulInt8DpNeon64
|
||||
#ifndef __APPLE__
|
||||
.type MatmulInt8DpNeon64, %function
|
||||
#endif
|
||||
|
||||
//
|
||||
// int8 RHS 4x8 block
|
||||
// /-----------------------------------------|
|
||||
// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
|
||||
// | ... ... |
|
||||
// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
|
||||
// \-----------------------------------------/
|
||||
// int8 LHS 8x4 block
|
||||
// /---------------------\ /-------------------------------------------|
|
||||
// |v0.b[0] ... v0.b[3] | |v16.s[0] ... v16.s[3] v17.s[0] ... v17.s[3]|
|
||||
// |v0.b[4] ... v0.b[7] |v18.s[0] ... v18.s[3] v19.s[0] ... v19.s[3]|
|
||||
// |v0.b[8] ... v0.b[11] |v20.s[0] ... v20.s[3] v21.s[0] ... v21.s[3]|
|
||||
// |v0.b[12] ... v0.b[15]| |v22.s[0] ... v22.s[3] v23.s[0] ... v23.s[3]|
|
||||
// |v1.b[0] ... v1.b[3] | |v24.s[0] ... v24.s[3] v25.s[0] ... v25.s[3]|
|
||||
// |v1.b[4] ... v1.b[7] | |v26.s[0] ... v26.s[3] v27.s[0] ... v27.s[3]|
|
||||
// |v1.b[8] ... v1.b[11]| |v28.s[0] ... v28.s[3] v29.s[0] ... v29.s[3]|
|
||||
// |v1.b[12] ... v1.b[15]| |v30.s[0] ... v30.s[3] v31.s[0] ... v31.s[3]|
|
||||
// \---------------------/ \-------------------------------------------/
|
||||
// int32 accumulators 8x8 block
|
||||
|
||||
|
||||
|
||||
// int8 RHS 16x8 block
|
||||
// /-------------|
|
||||
// |v2 v3 |
|
||||
// |v6 v7 |
|
||||
// |v10 v11 |
|
||||
// |v14 v15 |
|
||||
// \-------------/
|
||||
// int8 LHS 8x16 block
|
||||
// /--------------------\ /-------------|
|
||||
// |v0 v4 v8 v12| | |
|
||||
// |v1 v5 v9 v13| | |
|
||||
// \--------------------/ \-------------/
|
||||
|
||||
|
||||
|
||||
//void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
|
||||
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
|
||||
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
|
||||
|
||||
// x0: a(left matrix ptr)
|
||||
// x1: b(right matrix ptr)
|
||||
// x2: out ptr
|
||||
// w3: row8
|
||||
// w4: col8
|
||||
// w5: deep4
|
||||
// x6: a_sums
|
||||
// x7: bias
|
||||
// w8: act_min
|
||||
// w9: act_max
|
||||
// w10: out_zp
|
||||
// w11: multiplier
|
||||
// w12: left_shift
|
||||
// w13: right_shift
|
||||
// w14: row
|
||||
// w15: col
|
||||
// w24: stride
|
||||
|
||||
MatmulInt8DpNeon64:
|
||||
sub sp, sp, #192
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
stp x19, x20, [sp], #16
|
||||
stp x21, x22, [sp], #16
|
||||
stp x23, x24, [sp], #16
|
||||
stp x25, x26, [sp], #16
|
||||
|
||||
ldr w8, [sp]
|
||||
ldr w9, [sp, #8]
|
||||
ldr w10, [sp, #16]
|
||||
ldr w11, [sp, #24]
|
||||
ldr w12, [sp, #32]
|
||||
ldr w13, [sp, #40]
|
||||
ldr w14, [sp, #48]
|
||||
ldr w15, [sp, #56]
|
||||
ldr w24, [sp, #64]
|
||||
|
||||
mov w17, #8 // sizeof(int8)*8
|
||||
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4
|
||||
mov x25, x2
|
||||
L1:
|
||||
cmp w4, #0 // if at the end of col8
|
||||
beq End1
|
||||
|
||||
mov w16, w3 // reset a row8 counter
|
||||
mov w23, w14 // reset a row counter
|
||||
mov x17, x0 // reload a ptr
|
||||
mov x22, x6 // reload a_sums ptr
|
||||
L2:
|
||||
cmp w16, #0
|
||||
beq End2
|
||||
|
||||
mov x18, x1 // reload b ptr
|
||||
mov x19, x7 // reload bias ptr
|
||||
mov w20, w5 // reload depth
|
||||
dup v16.4s, wzr
|
||||
dup v17.4s, wzr
|
||||
dup v18.4s, wzr
|
||||
dup v19.4s, wzr
|
||||
dup v20.4s, wzr
|
||||
dup v21.4s, wzr
|
||||
dup v22.4s, wzr
|
||||
dup v23.4s, wzr
|
||||
dup v24.4s, wzr
|
||||
dup v25.4s, wzr
|
||||
dup v26.4s, wzr
|
||||
dup v27.4s, wzr
|
||||
dup v28.4s, wzr
|
||||
dup v29.4s, wzr
|
||||
dup v30.4s, wzr
|
||||
dup v31.4s, wzr
|
||||
L3:
|
||||
cmp w20, #16
|
||||
blt LoopD4
|
||||
|
||||
LoopD16:
|
||||
ld1 {v0.16b, v1.16b}, [x17], #32
|
||||
ld1 {v2.16b, v3.16b}, [x18], #32
|
||||
|
||||
sdot v16.4s, v2.16b, v0.4b[0]
|
||||
sdot v18.4s, v2.16b, v0.4b[1]
|
||||
sdot v20.4s, v2.16b, v0.4b[2]
|
||||
sdot v22.4s, v2.16b, v0.4b[3]
|
||||
|
||||
ld1 {v4.16b, v5.16b}, [x17], #32
|
||||
sdot v24.4s, v2.16b, v1.4b[0]
|
||||
sdot v26.4s, v2.16b, v1.4b[1]
|
||||
sdot v28.4s, v2.16b, v1.4b[2]
|
||||
sdot v30.4s, v2.16b, v1.4b[3]
|
||||
|
||||
ld1 {v6.16b, v7.16b}, [x18], #32
|
||||
sdot v17.4s, v3.16b, v0.4b[0]
|
||||
sdot v19.4s, v3.16b, v0.4b[1]
|
||||
sdot v21.4s, v3.16b, v0.4b[2]
|
||||
sdot v23.4s, v3.16b, v0.4b[3]
|
||||
|
||||
sdot v25.4s, v3.16b, v1.4b[0]
|
||||
sdot v27.4s, v3.16b, v1.4b[1]
|
||||
sdot v29.4s, v3.16b, v1.4b[2]
|
||||
sdot v31.4s, v3.16b, v1.4b[3]
|
||||
|
||||
ld1 {v8.16b, v9.16b}, [x17], #32
|
||||
sdot v16.4s, v6.16b, v4.4b[0]
|
||||
sdot v18.4s, v6.16b, v4.4b[1]
|
||||
sdot v20.4s, v6.16b, v4.4b[2]
|
||||
sdot v22.4s, v6.16b, v4.4b[3]
|
||||
|
||||
sdot v24.4s, v6.16b, v5.4b[0]
|
||||
sdot v26.4s, v6.16b, v5.4b[1]
|
||||
sdot v28.4s, v6.16b, v5.4b[2]
|
||||
sdot v30.4s, v6.16b, v5.4b[3]
|
||||
|
||||
ld1 {v10.16b, v11.16b}, [x18], #32
|
||||
sdot v17.4s, v7.16b, v4.4b[0]
|
||||
sdot v19.4s, v7.16b, v4.4b[1]
|
||||
sdot v21.4s, v7.16b, v4.4b[2]
|
||||
sdot v23.4s, v7.16b, v4.4b[3]
|
||||
|
||||
sdot v25.4s, v7.16b, v5.4b[0]
|
||||
sdot v27.4s, v7.16b, v5.4b[1]
|
||||
sdot v29.4s, v7.16b, v5.4b[2]
|
||||
sdot v31.4s, v7.16b, v5.4b[3]
|
||||
|
||||
ld1 {v12.16b, v13.16b}, [x17], #32
|
||||
sdot v16.4s, v10.16b, v8.4b[0]
|
||||
sdot v18.4s, v10.16b, v8.4b[1]
|
||||
sdot v20.4s, v10.16b, v8.4b[2]
|
||||
sdot v22.4s, v10.16b, v8.4b[3]
|
||||
|
||||
sdot v24.4s, v10.16b, v9.4b[0]
|
||||
sdot v26.4s, v10.16b, v9.4b[1]
|
||||
sdot v28.4s, v10.16b, v9.4b[2]
|
||||
sdot v30.4s, v10.16b, v9.4b[3]
|
||||
|
||||
ld1 {v14.16b, v15.16b}, [x18], #32
|
||||
sdot v17.4s, v11.16b, v8.4b[0]
|
||||
sdot v19.4s, v11.16b, v8.4b[1]
|
||||
sdot v21.4s, v11.16b, v8.4b[2]
|
||||
sdot v23.4s, v11.16b, v8.4b[3]
|
||||
|
||||
sdot v25.4s, v11.16b, v9.4b[0]
|
||||
sdot v27.4s, v11.16b, v9.4b[1]
|
||||
sdot v29.4s, v11.16b, v9.4b[2]
|
||||
sdot v31.4s, v11.16b, v9.4b[3]
|
||||
|
||||
sdot v16.4s, v14.16b, v12.4b[0]
|
||||
sdot v18.4s, v14.16b, v12.4b[1]
|
||||
sdot v20.4s, v14.16b, v12.4b[2]
|
||||
sdot v22.4s, v14.16b, v12.4b[3]
|
||||
|
||||
sdot v24.4s, v14.16b, v13.4b[0]
|
||||
sdot v26.4s, v14.16b, v13.4b[1]
|
||||
sdot v28.4s, v14.16b, v13.4b[2]
|
||||
sdot v30.4s, v14.16b, v13.4b[3]
|
||||
|
||||
sdot v17.4s, v15.16b, v12.4b[0]
|
||||
sdot v19.4s, v15.16b, v12.4b[1]
|
||||
sdot v21.4s, v15.16b, v12.4b[2]
|
||||
sdot v23.4s, v15.16b, v12.4b[3]
|
||||
|
||||
sdot v25.4s, v15.16b, v13.4b[0]
|
||||
sdot v27.4s, v15.16b, v13.4b[1]
|
||||
sdot v29.4s, v15.16b, v13.4b[2]
|
||||
sdot v31.4s, v15.16b, v13.4b[3]
|
||||
|
||||
subs w20, w20, #16 // depth - 16
|
||||
b L3
|
||||
|
||||
LoopD4:
|
||||
cmp w20, #0
|
||||
beq End3
|
||||
|
||||
ld1 {v0.16b, v1.16b}, [x17], #32
|
||||
ld1 {v2.16b, v3.16b}, [x18], #32
|
||||
|
||||
sdot v16.4s, v2.16b, v0.4b[0]
|
||||
sdot v18.4s, v2.16b, v0.4b[1]
|
||||
sdot v20.4s, v2.16b, v0.4b[2]
|
||||
sdot v22.4s, v2.16b, v0.4b[3]
|
||||
sdot v24.4s, v2.16b, v1.4b[0]
|
||||
sdot v26.4s, v2.16b, v1.4b[1]
|
||||
sdot v28.4s, v2.16b, v1.4b[2]
|
||||
sdot v30.4s, v2.16b, v1.4b[3]
|
||||
sdot v17.4s, v3.16b, v0.4b[0]
|
||||
sdot v19.4s, v3.16b, v0.4b[1]
|
||||
sdot v21.4s, v3.16b, v0.4b[2]
|
||||
sdot v23.4s, v3.16b, v0.4b[3]
|
||||
sdot v25.4s, v3.16b, v1.4b[0]
|
||||
sdot v27.4s, v3.16b, v1.4b[1]
|
||||
sdot v29.4s, v3.16b, v1.4b[2]
|
||||
sdot v31.4s, v3.16b, v1.4b[3]
|
||||
|
||||
subs w20, w20, #4 // depth - 4
|
||||
b LoopD4
|
||||
|
||||
End3:
|
||||
// Add (Bias+Depth*Za*Zb-Za*Bsums)
|
||||
ld1 {v15.4s}, [x19], #16
|
||||
ld1 {v14.4s}, [x19], #16
|
||||
add v16.4s, v16.4s, v15.4s
|
||||
add v18.4s, v18.4s, v15.4s
|
||||
add v20.4s, v20.4s, v15.4s
|
||||
add v22.4s, v22.4s, v15.4s
|
||||
add v24.4s, v24.4s, v15.4s
|
||||
add v26.4s, v26.4s, v15.4s
|
||||
add v28.4s, v28.4s, v15.4s
|
||||
add v30.4s, v30.4s, v15.4s
|
||||
add v17.4s, v17.4s, v14.4s
|
||||
add v19.4s, v19.4s, v14.4s
|
||||
add v21.4s, v21.4s, v14.4s
|
||||
add v23.4s, v23.4s, v14.4s
|
||||
add v25.4s, v25.4s, v14.4s
|
||||
add v27.4s, v27.4s, v14.4s
|
||||
add v29.4s, v29.4s, v14.4s
|
||||
add v31.4s, v31.4s, v14.4s
|
||||
|
||||
// Subtract (Asums*Zb)
|
||||
ld1 {v13.4s}, [x22], #16
|
||||
ld1 {v12.4s}, [x22], #16
|
||||
dup v0.4s, v13.s[0]
|
||||
dup v1.4s, v13.s[1]
|
||||
dup v2.4s, v13.s[2]
|
||||
dup v3.4s, v13.s[3]
|
||||
dup v4.4s, v12.s[0]
|
||||
dup v5.4s, v12.s[1]
|
||||
dup v6.4s, v12.s[2]
|
||||
dup v7.4s, v12.s[3]
|
||||
sub v16.4s, v16.4s, v0.4s
|
||||
sub v17.4s, v17.4s, v0.4s
|
||||
sub v18.4s, v18.4s, v1.4s
|
||||
sub v19.4s, v19.4s, v1.4s
|
||||
sub v20.4s, v20.4s, v2.4s
|
||||
sub v21.4s, v21.4s, v2.4s
|
||||
sub v22.4s, v22.4s, v3.4s
|
||||
sub v23.4s, v23.4s, v3.4s
|
||||
sub v24.4s, v24.4s, v4.4s
|
||||
sub v25.4s, v25.4s, v4.4s
|
||||
sub v26.4s, v26.4s, v5.4s
|
||||
sub v27.4s, v27.4s, v5.4s
|
||||
sub v28.4s, v28.4s, v6.4s
|
||||
sub v29.4s, v29.4s, v6.4s
|
||||
sub v30.4s, v30.4s, v7.4s
|
||||
sub v31.4s, v31.4s, v7.4s
|
||||
|
||||
// Apply left shift
|
||||
dup v11.4s, w12
|
||||
sqshl v16.4s, v16.4s, v11.4s
|
||||
sqshl v17.4s, v17.4s, v11.4s
|
||||
sqshl v18.4s, v18.4s, v11.4s
|
||||
sqshl v19.4s, v19.4s, v11.4s
|
||||
sqshl v20.4s, v20.4s, v11.4s
|
||||
sqshl v21.4s, v21.4s, v11.4s
|
||||
sqshl v22.4s, v22.4s, v11.4s
|
||||
sqshl v23.4s, v23.4s, v11.4s
|
||||
sqshl v24.4s, v24.4s, v11.4s
|
||||
sqshl v25.4s, v25.4s, v11.4s
|
||||
sqshl v26.4s, v26.4s, v11.4s
|
||||
sqshl v27.4s, v27.4s, v11.4s
|
||||
sqshl v28.4s, v28.4s, v11.4s
|
||||
sqshl v29.4s, v29.4s, v11.4s
|
||||
sqshl v30.4s, v30.4s, v11.4s
|
||||
sqshl v31.4s, v31.4s, v11.4s
|
||||
|
||||
// Apply the fixed-point part of the multiplier.
|
||||
dup v10.4s, w11
|
||||
sqrdmulh v16.4s, v16.4s, v10.4s
|
||||
sqrdmulh v17.4s, v17.4s, v10.4s
|
||||
sqrdmulh v18.4s, v18.4s, v10.4s
|
||||
sqrdmulh v19.4s, v19.4s, v10.4s
|
||||
sqrdmulh v20.4s, v20.4s, v10.4s
|
||||
sqrdmulh v21.4s, v21.4s, v10.4s
|
||||
sqrdmulh v22.4s, v22.4s, v10.4s
|
||||
sqrdmulh v23.4s, v23.4s, v10.4s
|
||||
sqrdmulh v24.4s, v24.4s, v10.4s
|
||||
sqrdmulh v25.4s, v25.4s, v10.4s
|
||||
sqrdmulh v26.4s, v26.4s, v10.4s
|
||||
sqrdmulh v27.4s, v27.4s, v10.4s
|
||||
sqrdmulh v28.4s, v28.4s, v10.4s
|
||||
sqrdmulh v29.4s, v29.4s, v10.4s
|
||||
sqrdmulh v30.4s, v30.4s, v10.4s
|
||||
sqrdmulh v31.4s, v31.4s, v10.4s
|
||||
|
||||
// Apply right shift
|
||||
dup v9.4s, w13
|
||||
and v0.16b, v9.16b, v16.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v16.4s, v16.4s, v0.4s
|
||||
srshl v16.4s, v16.4s, v9.4s
|
||||
and v1.16b, v9.16b, v17.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v17.4s, v17.4s, v1.4s
|
||||
srshl v17.4s, v17.4s, v9.4s
|
||||
and v2.16b, v9.16b, v18.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v18.4s, v18.4s, v2.4s
|
||||
srshl v18.4s, v18.4s, v9.4s
|
||||
and v3.16b, v9.16b, v19.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v19.4s, v19.4s, v3.4s
|
||||
srshl v19.4s, v19.4s, v9.4s
|
||||
and v0.16b, v9.16b, v20.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v20.4s, v20.4s, v0.4s
|
||||
srshl v20.4s, v20.4s, v9.4s
|
||||
and v1.16b, v9.16b, v21.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v21.4s, v21.4s, v1.4s
|
||||
srshl v21.4s, v21.4s, v9.4s
|
||||
and v2.16b, v9.16b, v22.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v22.4s, v22.4s, v2.4s
|
||||
srshl v22.4s, v22.4s, v9.4s
|
||||
and v3.16b, v9.16b, v23.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v23.4s, v23.4s, v3.4s
|
||||
srshl v23.4s, v23.4s, v9.4s
|
||||
and v0.16b, v9.16b, v24.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v24.4s, v24.4s, v0.4s
|
||||
srshl v24.4s, v24.4s, v9.4s
|
||||
and v1.16b, v9.16b, v25.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v25.4s, v25.4s, v1.4s
|
||||
srshl v25.4s, v25.4s, v9.4s
|
||||
and v2.16b, v9.16b, v26.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v26.4s, v26.4s, v2.4s
|
||||
srshl v26.4s, v26.4s, v9.4s
|
||||
and v3.16b, v9.16b, v27.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v27.4s, v27.4s, v3.4s
|
||||
srshl v27.4s, v27.4s, v9.4s
|
||||
and v0.16b, v9.16b, v28.16b
|
||||
sshr v0.4s, v0.4s, #31
|
||||
sqadd v28.4s, v28.4s, v0.4s
|
||||
srshl v28.4s, v28.4s, v9.4s
|
||||
and v1.16b, v9.16b, v29.16b
|
||||
sshr v1.4s, v1.4s, #31
|
||||
sqadd v29.4s, v29.4s, v1.4s
|
||||
srshl v29.4s, v29.4s, v9.4s
|
||||
and v2.16b, v9.16b, v30.16b
|
||||
sshr v2.4s, v2.4s, #31
|
||||
sqadd v30.4s, v30.4s, v2.4s
|
||||
srshl v30.4s, v30.4s, v9.4s
|
||||
and v3.16b, v9.16b, v31.16b
|
||||
sshr v3.4s, v3.4s, #31
|
||||
sqadd v31.4s, v31.4s, v3.4s
|
||||
srshl v31.4s, v31.4s, v9.4s
|
||||
|
||||
// Add the destination zero point
|
||||
dup v8.4s, w10
|
||||
add v16.4s, v16.4s, v8.4s
|
||||
add v17.4s, v17.4s, v8.4s
|
||||
add v18.4s, v18.4s, v8.4s
|
||||
add v19.4s, v19.4s, v8.4s
|
||||
add v20.4s, v20.4s, v8.4s
|
||||
add v21.4s, v21.4s, v8.4s
|
||||
add v22.4s, v22.4s, v8.4s
|
||||
add v23.4s, v23.4s, v8.4s
|
||||
add v24.4s, v24.4s, v8.4s
|
||||
add v25.4s, v25.4s, v8.4s
|
||||
add v26.4s, v26.4s, v8.4s
|
||||
add v27.4s, v27.4s, v8.4s
|
||||
add v28.4s, v28.4s, v8.4s
|
||||
add v29.4s, v29.4s, v8.4s
|
||||
add v30.4s, v30.4s, v8.4s
|
||||
add v31.4s, v31.4s, v8.4s
|
||||
|
||||
// Apply the act_min bound
|
||||
dup v7.4s, w8
|
||||
smax v16.4s, v16.4s, v7.4s
|
||||
smax v17.4s, v17.4s, v7.4s
|
||||
smax v18.4s, v18.4s, v7.4s
|
||||
smax v19.4s, v19.4s, v7.4s
|
||||
|
||||
// Apply the act_min bound
|
||||
dup v6.4s, w9
|
||||
smin v16.4s, v16.4s, v6.4s
|
||||
smin v17.4s, v17.4s, v6.4s
|
||||
smin v18.4s, v18.4s, v6.4s
|
||||
smin v19.4s, v19.4s, v6.4s
|
||||
|
||||
// int32 -> int16
|
||||
sqxtn v0.4h, v16.4s
|
||||
sqxtn2 v0.8h, v17.4s
|
||||
sqxtn v1.4h, v18.4s
|
||||
sqxtn2 v1.8h, v19.4s
|
||||
sqxtn v2.4h, v20.4s
|
||||
sqxtn2 v2.8h, v21.4s
|
||||
sqxtn v3.4h, v22.4s
|
||||
sqxtn2 v3.8h, v23.4s
|
||||
sqxtn v4.4h, v24.4s
|
||||
sqxtn2 v4.8h, v25.4s
|
||||
sqxtn v5.4h, v26.4s
|
||||
sqxtn2 v5.8h, v27.4s
|
||||
sqxtn v6.4h, v28.4s
|
||||
sqxtn2 v6.8h, v29.4s
|
||||
sqxtn v7.4h, v30.4s
|
||||
sqxtn2 v7.8h, v31.4s
|
||||
|
||||
// int16 -> int8
|
||||
sqxtn v8.8b, v0.8h
|
||||
sqxtn2 v8.16b, v1.8h
|
||||
sqxtn v9.8b, v2.8h
|
||||
sqxtn2 v9.16b, v3.8h
|
||||
sqxtn v10.8b, v4.8h
|
||||
sqxtn2 v10.16b, v5.8h
|
||||
sqxtn v11.8b, v6.8h
|
||||
sqxtn2 v11.16b, v7.8h
|
||||
|
||||
cmp w23, #8
|
||||
blt Write // if rows < 8
|
||||
cmp w15, #8
|
||||
blt Write // if cols < 8
|
||||
|
||||
st1 {v8.d}[0], [x2], x24
|
||||
st1 {v8.d}[1], [x2], x24
|
||||
st1 {v9.d}[0], [x2], x24
|
||||
st1 {v9.d}[1], [x2], x24
|
||||
st1 {v10.d}[0], [x2], x24
|
||||
st1 {v10.d}[1], [x2], x24
|
||||
st1 {v11.d}[0], [x2], x24
|
||||
st1 {v11.d}[1], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
Write:
|
||||
cmp w15, #8
|
||||
bge WriteCol8
|
||||
cmp w15, #7
|
||||
beq WriteCol7
|
||||
cmp w15, #6
|
||||
beq WriteCol6
|
||||
cmp w15, #5
|
||||
beq WriteCol5
|
||||
cmp w15, #4
|
||||
beq WriteCol4
|
||||
cmp w15, #3
|
||||
beq WriteCol3
|
||||
cmp w15, #2
|
||||
beq WriteCol2
|
||||
cmp w15, #1
|
||||
beq WriteCol1
|
||||
|
||||
WriteCol8:
|
||||
st1 {v8.d}[0], [x2], x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
st1 {v8.d}[1], [x2], x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
st1 {v9.d}[0], [x2], x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
st1 {v9.d}[1], [x2], x24
|
||||
cmp w23, #4
|
||||
beq Endwrite
|
||||
st1 {v10.d}[0], [x2], x24
|
||||
cmp w23, #5
|
||||
beq Endwrite
|
||||
st1 {v10.d}[1], [x2], x24
|
||||
cmp w23, #6
|
||||
beq Endwrite
|
||||
st1 {v11.d}[0], [x2], x24
|
||||
cmp w23, #7
|
||||
beq Endwrite
|
||||
st1 {v11.d}[1], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol7:
|
||||
mov x26, x2
|
||||
st1 {v8.s}[0], [x26], #4
|
||||
st1 {v8.h}[2], [x26], #2
|
||||
st1 {v8.b}[6], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v8.s}[2], [x26], #4
|
||||
st1 {v8.h}[6], [x26], #2
|
||||
st1 {v8.b}[14], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v9.s}[0], [x26], #4
|
||||
st1 {v9.h}[2], [x26], #2
|
||||
st1 {v9.b}[6], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v9.s}[2], [x26], #4
|
||||
st1 {v9.h}[6], [x26], #2
|
||||
st1 {v9.b}[14], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #4
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v10.s}[0], [x26], #4
|
||||
st1 {v10.h}[2], [x26], #2
|
||||
st1 {v10.b}[6], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #5
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v10.s}[2], [x26], #4
|
||||
st1 {v10.h}[6], [x26], #2
|
||||
st1 {v10.b}[14], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #6
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v11.s}[0], [x26], #4
|
||||
st1 {v11.h}[2], [x26], #2
|
||||
st1 {v11.b}[6], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #7
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v11.s}[2], [x26], #4
|
||||
st1 {v11.h}[6], [x26], #2
|
||||
st1 {v11.b}[14], [x26], #1
|
||||
add x2, x2, x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol6:
|
||||
mov x26, x2
|
||||
st1 {v8.s}[0], [x26], #4
|
||||
st1 {v8.h}[2], [x26], #2
|
||||
add x2, x2, x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v8.s}[2], [x26], #4
|
||||
st1 {v8.h}[6], [x26], #2
|
||||
add x2, x2, x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v9.s}[0], [x26], #4
|
||||
st1 {v9.h}[2], [x26], #2
|
||||
add x2, x2, x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v9.s}[2], [x26], #4
|
||||
st1 {v9.h}[6], [x26], #2
|
||||
add x2, x2, x24
|
||||
cmp w23, #4
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v10.s}[0], [x26], #4
|
||||
st1 {v10.h}[2], [x26], #2
|
||||
add x2, x2, x24
|
||||
cmp w23, #5
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v10.s}[2], [x26], #4
|
||||
st1 {v10.h}[6], [x26], #2
|
||||
add x2, x2, x24
|
||||
cmp w23, #6
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v11.s}[0], [x26], #4
|
||||
st1 {v11.h}[2], [x26], #2
|
||||
add x2, x2, x24
|
||||
cmp w23, #7
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v11.s}[2], [x26], #4
|
||||
st1 {v11.h}[6], [x26], #2
|
||||
add x2, x2, x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol5:
|
||||
mov x26, x2
|
||||
st1 {v8.s}[0], [x26], #4
|
||||
st1 {v8.b}[4], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v8.s}[2], [x26], #4
|
||||
st1 {v8.b}[12], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v9.s}[0], [x26], #4
|
||||
st1 {v9.b}[4], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v9.s}[2], [x26], #4
|
||||
st1 {v9.b}[12], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #4
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v10.s}[0], [x26], #4
|
||||
st1 {v10.b}[4], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #5
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v10.s}[2], [x26], #4
|
||||
st1 {v10.b}[12], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #6
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v11.s}[0], [x26], #4
|
||||
st1 {v11.b}[4], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #7
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v11.s}[2], [x26], #4
|
||||
st1 {v11.b}[12], [x26], #1
|
||||
add x2, x2, x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol4:
|
||||
st1 {v8.s}[0], [x2], x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
st1 {v8.s}[2], [x2], x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
st1 {v9.s}[0], [x2], x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
st1 {v9.s}[2], [x2], x24
|
||||
cmp w23, #4
|
||||
beq Endwrite
|
||||
st1 {v10.s}[0], [x2], x24
|
||||
cmp w23, #5
|
||||
beq Endwrite
|
||||
st1 {v10.s}[2], [x2], x24
|
||||
cmp w23, #6
|
||||
beq Endwrite
|
||||
st1 {v11.s}[0], [x2], x24
|
||||
cmp w23, #7
|
||||
beq Endwrite
|
||||
st1 {v11.s}[2], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol3:
|
||||
mov x26, x2
|
||||
st1 {v8.h}[0], [x26], #2
|
||||
st1 {v8.b}[2], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v8.h}[4], [x26], #2
|
||||
st1 {v8.b}[10], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v9.h}[0], [x26], #2
|
||||
st1 {v9.b}[2], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v9.h}[4], [x26], #2
|
||||
st1 {v9.b}[10], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #4
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v10.h}[0], [x26], #2
|
||||
st1 {v10.b}[2], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #5
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v10.h}[4], [x26], #2
|
||||
st1 {v10.b}[10], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #6
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v11.h}[0], [x26], #2
|
||||
st1 {v11.b}[2], [x26], #1
|
||||
add x2, x2, x24
|
||||
cmp w23, #7
|
||||
beq Endwrite
|
||||
mov x26, x2
|
||||
st1 {v11.h}[4], [x26], #2
|
||||
st1 {v11.b}[10], [x26], #1
|
||||
add x2, x2, x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol2:
|
||||
st1 {v8.h}[0], [x2], x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
st1 {v8.h}[4], [x2], x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
st1 {v9.h}[0], [x2], x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
st1 {v9.h}[4], [x2], x24
|
||||
cmp w23, #4
|
||||
beq Endwrite
|
||||
st1 {v10.h}[0], [x2], x24
|
||||
cmp w23, #5
|
||||
beq Endwrite
|
||||
st1 {v10.h}[4], [x2], x24
|
||||
cmp w23, #6
|
||||
beq Endwrite
|
||||
st1 {v11.h}[0], [x2], x24
|
||||
cmp w23, #7
|
||||
beq Endwrite
|
||||
st1 {v11.h}[4], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
WriteCol1:
|
||||
st1 {v8.b}[0], [x2], x24
|
||||
cmp w23, #1
|
||||
beq Endwrite
|
||||
st1 {v8.b}[8], [x2], x24
|
||||
cmp w23, #2
|
||||
beq Endwrite
|
||||
st1 {v9.b}[0], [x2], x24
|
||||
cmp w23, #3
|
||||
beq Endwrite
|
||||
st1 {v9.b}[8], [x2], x24
|
||||
cmp w23, #4
|
||||
beq Endwrite
|
||||
st1 {v10.b}[0], [x2], x24
|
||||
cmp w23, #5
|
||||
beq Endwrite
|
||||
st1 {v10.b}[8], [x2], x24
|
||||
cmp w23, #6
|
||||
beq Endwrite
|
||||
st1 {v11.b}[0], [x2], x24
|
||||
cmp w23, #7
|
||||
beq Endwrite
|
||||
st1 {v11.b}[8], [x2], x24
|
||||
b Endwrite
|
||||
|
||||
Endwrite:
|
||||
sub w16, w16, #8 // a row8 counter - 8
|
||||
sub w23, w23, #8 // a row counter - 8
|
||||
b L2
|
||||
|
||||
End2:
|
||||
sub w4, w4, #8 // b col8 counter - 8
|
||||
sub w15, w15, #8 // b col counter - 8
|
||||
add x1, x1, x21 // b ptr + stride
|
||||
add x7, x7, #32 // bias ptr + stride
|
||||
add x25, x25, #8 // output + stride(8 * sizeof(int8))
|
||||
mov x2, x25
|
||||
b L1
|
||||
|
||||
End1:
|
||||
sub sp, sp, #192
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ldp x19, x20, [sp], #16
|
||||
ldp x21, x22, [sp], #16
|
||||
ldp x23, x24, [sp], #16
|
||||
ldp x25, x26, [sp], #16
|
||||
ret
|
||||
#endif
|
|
@ -228,19 +228,3 @@ void IndirectGemmFp32_Comm(float *output, const float *input, const float *weigh
|
|||
return;
|
||||
}
|
||||
|
||||
void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
|
||||
int32_t left_shift, int32_t right_shift, int32_t zp) {
|
||||
/* (int32_t)row8x8-major * multiplier => (int8_t)row-major */
|
||||
for (int r = 0; r < plane; r++) {
|
||||
for (int c = 0; c < oc; c++) {
|
||||
int c8div = c / 8, c8mod = c % 8;
|
||||
int src_index = c8div * plane8 * 8 + r * 8 + c8mod;
|
||||
int dst_index = r * oc + c;
|
||||
int32_t value = in[src_index];
|
||||
value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp;
|
||||
value = MSMIN(CHAR_MAX, value);
|
||||
value = MSMAX(CHAR_MIN, value);
|
||||
out[dst_index] = (int8_t)value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,8 +32,6 @@ typedef struct ConvParameter {
|
|||
int stride_w_;
|
||||
int dilation_h_;
|
||||
int dilation_w_;
|
||||
int pad_h_;
|
||||
int pad_w_;
|
||||
int pad_u_;
|
||||
int pad_d_;
|
||||
int pad_l_;
|
||||
|
@ -51,8 +49,7 @@ typedef struct ConvParameter {
|
|||
int thread_num_;
|
||||
int input_unit_;
|
||||
int output_unit_;
|
||||
bool is_relu_;
|
||||
bool is_relu6_;
|
||||
ActType act_type_;
|
||||
} ConvParameter;
|
||||
|
||||
typedef struct SlidingWindowParam {
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp16/activation_fp16.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
int i;
|
||||
for (i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu_src = vld1q_f16(src + index);
|
||||
float16x8_t zero_src = vdupq_n_f16(0);
|
||||
relu_src = vmaxq_f16(relu_src, zero_src);
|
||||
vst1q_f16(dst + index, relu_src);
|
||||
#else
|
||||
int j;
|
||||
for (j = 0; j < C8NUM; j++) {
|
||||
dst[index + j] = src[index + j] < 0 ? 0 : src[index + j];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
dst[j] = src[j] < 0 ? 0 : src[j];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
int i;
|
||||
for (i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu6_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
float16x8_t six_data = vdupq_n_f16(6);
|
||||
relu6_data = vmaxq_f16(relu6_data, zero_data);
|
||||
relu6_data = vminq_f16(relu6_data, six_data);
|
||||
vst1q_f16(dst + index, relu6_data);
|
||||
#else
|
||||
int j;
|
||||
for (j = 0; j < C8NUM; ++j) {
|
||||
dst[index + j] = data[index + j] < 0 ? 0 : data[index + j];
|
||||
dst[index + j] = dst[index + j] > 6 ? 6 : dst[index + j];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
dst[j] = data[j] < 0 ? 0 : data[j];
|
||||
dst[j] = dst[j] > 6 ? 6 : dst[j];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
dst[i] = (float16_t)1.0f / (float16_t)(1.0f + exp(-src[i]));
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
dst[i] = (float16_t)1.0f - (float16_t)2.0f / (float16_t)(exp(2 * src[i]) + 1);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
float16_t in = src[i];
|
||||
float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6);
|
||||
dst[i] = in * relu6 / (float16_t)6.0f;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/fixed_point.h"
|
||||
|
||||
typedef struct ActivationParameter {
|
||||
OpParameter op_parameter_;
|
||||
int type_;
|
||||
float alpha_;
|
||||
} ActivationParameter;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num);
|
||||
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha);
|
||||
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
|
|
@ -74,33 +74,48 @@ int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, i
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#endif
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
|
||||
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vst1q_f16(output, vout);
|
||||
float16x8_t vin0 = vin0_opt;
|
||||
float16x8_t vin1 = vld1q_f16(input1);
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
|
||||
output[i] = in0 * in1;
|
||||
}
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = in0_opt * input1[i];
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
|
||||
output[index] = in0 * in1;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = in0_opt * input1[index];
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = vld1q_f16(input0);
|
||||
float16x8_t vin1 = vin1_opt;
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = input0[i] * in1_opt;
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = input0[index] * in1_opt;
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
|
@ -113,7 +128,6 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
|
|||
#ifdef ENABLE_NEON
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
#endif
|
||||
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = vld1q_f16(input0);
|
||||
|
@ -143,39 +157,58 @@ int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
#endif
|
||||
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
|
||||
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vout = vmaxq_f16(vout, zeros);
|
||||
vst1q_f16(output, vout);
|
||||
float16x8_t vin0 = vin0_opt;
|
||||
float16x8_t vin1 = vld1q_f16(input1);
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vout = vmaxq_f16(vout, zeros);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
float16_t res;
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
|
||||
res = in0 * in1;
|
||||
output[i] = res > 0 ? res : 0;
|
||||
}
|
||||
float16_t res;
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
res = in0_opt * input1[i];
|
||||
output[i] = res > 0 ? res : 0;
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
|
||||
float16_t res = in0 * in1;
|
||||
output[index] = res > 0 ? res : 0;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t res = in0_opt * input1[index];
|
||||
output[index] = res > 0 ? res : 0;
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = vld1q_f16(input0);
|
||||
float16x8_t vin1 = vin1_opt;
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vout = vmaxq_f16(vout, zeros);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
float16_t res;
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
res = input0[i] * in1_opt;
|
||||
output[i] = res > 0 ? res : 0;
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t res = input0[index] * in1_opt;
|
||||
output[index] = res > 0 ? res : 0;
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
|
@ -216,37 +249,52 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
|
||||
#endif
|
||||
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
|
||||
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
|
||||
vst1q_f16(output, vout);
|
||||
float16x8_t vin0 = vin0_opt;
|
||||
float16x8_t vin1 = vld1q_f16(input1);
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
|
||||
output[i] = MSMIN(MSMAX(in0 * in1, 0), 6);
|
||||
}
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6);
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
|
||||
output[index] = MSMIN(MSMAX(in0 * in1, 0), 6);
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6);
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = vld1q_f16(input0);
|
||||
float16x8_t vin1 = vin1_opt;
|
||||
float16x8_t vout = vmulq_f16(vin0, vin1);
|
||||
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6);
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
|
@ -255,7 +303,6 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
|
|||
int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = vld1q_f16(input0);
|
||||
|
@ -280,34 +327,50 @@ int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, i
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#endif
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
|
||||
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vst1q_f16(output, vout);
|
||||
float16x8_t vin0 = vin0_opt;
|
||||
float16x8_t vin1 = vld1q_f16(input1);
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
|
||||
output[i] = in0 + in1;
|
||||
}
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = in0_opt + input1[i];
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
|
||||
output[index] = in0 + in1;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = in0_opt + input1[index];
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = vld1q_f16(input0);
|
||||
float16x8_t vin1 = vin1_opt;
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = input0[i] + in1_opt;
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = input0[index] + in1_opt;
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
|
@ -345,37 +408,54 @@ int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
#endif
|
||||
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
|
||||
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vout = vmaxq_f16(vout, zeros);
|
||||
vst1q_f16(output, vout);
|
||||
float16x8_t vin0 = vin0_opt;
|
||||
float16x8_t vin1 = vld1q_f16(input1);
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vout = vmaxq_f16(vout, zeros);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
|
||||
output[i] = MSMAX(in0 + in1, 0);
|
||||
}
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = MSMAX(in0_opt + input1[i], 0);
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
|
||||
float16_t res = in0 + in1;
|
||||
output[index] = res > 0 ? res : 0;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t res = in0_opt + input1[index];
|
||||
output[index] = res > 0 ? res : 0;
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = vld1q_f16(input0);
|
||||
float16x8_t vin1 = vin1_opt;
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vout = vmaxq_f16(vout, zeros);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = MSMAX(input0[i] + in1_opt, 0);
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t res = input0[index] + in1_opt;
|
||||
output[index] = res > 0 ? res : 0;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
@ -415,39 +495,54 @@ int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
|
||||
#endif
|
||||
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
|
||||
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
|
||||
vst1q_f16(output, vout);
|
||||
float16x8_t vin0 = vin0_opt;
|
||||
float16x8_t vin1 = vld1q_f16(input1);
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
|
||||
output[i] = MSMIN(MSMAX(in0 + in1, 0), 6);
|
||||
}
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(in0_opt + input1[i], 0), 6);
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
input1 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(in0_opt + input1[index], 0), 6);
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0 = vld1q_f16(input0);
|
||||
float16x8_t vin1 = vin1_opt;
|
||||
float16x8_t vout = vaddq_f16(vin0, vin1);
|
||||
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
|
||||
vst1q_f16(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(input0[i] + in1_opt, 0), 6);
|
||||
}
|
||||
#endif
|
||||
input0 += C8NUM;
|
||||
output += C8NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(input0[index] + in1_opt, 0), 6);
|
||||
}
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
|
||||
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
|
||||
output[index] = MSMIN(MSMAX(in0 + in1, 0), 6);
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
|
@ -479,11 +574,11 @@ int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, i
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#endif
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
|
@ -542,11 +637,11 @@ int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
#endif
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
|
@ -609,11 +704,11 @@ int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
|
||||
#endif
|
||||
|
@ -680,11 +775,11 @@ int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, i
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#endif
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
if (param->in_elements_num1_ == 1) {
|
||||
|
@ -765,12 +860,11 @@ int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
|
|||
ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
#endif
|
||||
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||
|
@ -855,11 +949,11 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
|
|||
int block_mod = element_size % C8NUM;
|
||||
int block_c8 = element_size - block_mod;
|
||||
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
|
||||
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
|
||||
float16_t in0_opt = input0[0];
|
||||
float16_t in1_opt = input1[0];
|
||||
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
|
||||
#endif
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
#include "nnacl/fp16/batchnorm_fp16.h"
|
||||
#include <math.h>
|
||||
|
||||
void BatchNormFp16(const void *input, const void *mean, const void *variance,
|
||||
BatchNormParameter *param, int task_id, void *output) {
|
||||
void BatchNormFp16(const float16_t *input, const void *mean, const void *variance,
|
||||
BatchNormParameter *param, int task_id, float16_t *output) {
|
||||
int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
|
||||
int completed_units = task_id * units_per_thread;
|
||||
int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
|
||||
|
@ -27,8 +27,9 @@ void BatchNormFp16(const void *input, const void *mean, const void *variance,
|
|||
for (int i = 0; i < cur_unit; i++) {
|
||||
for (int c = 0; c < param->channel_; c++) {
|
||||
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
|
||||
((float16_t *)output)[cur_offset + c] =
|
||||
(((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
|
||||
if (variance_sqrt != 0) {
|
||||
output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
|
||||
}
|
||||
}
|
||||
cur_offset += param->channel_;
|
||||
}
|
||||
|
@ -44,8 +45,12 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset
|
|||
for (int i = 0; i < cur_unit; i++) {
|
||||
for (int c = 0; c < param->channel_; c++) {
|
||||
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
|
||||
float16_t norm_val = (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
|
||||
((float16_t *)output)[cur_offset + c] = norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
|
||||
if (variance_sqrt != 0) {
|
||||
float16_t norm_val =
|
||||
(((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
|
||||
((float16_t *)output)[cur_offset + c] =
|
||||
norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
|
||||
}
|
||||
}
|
||||
cur_offset += param->channel_;
|
||||
}
|
||||
|
|
|
@ -25,8 +25,8 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
void BatchNormFp16(const void *input, const void *mean, const void *variance, BatchNormParameter *param, int task_id,
|
||||
void *output);
|
||||
void BatchNormFp16(const float16_t *input, const void *mean, const void *variance, BatchNormParameter *param,
|
||||
int task_id, float16_t *output);
|
||||
void FusedBatchNormFp16(const void *input, const void *scale, const void *offset, const void *mean,
|
||||
const void *variance, BatchNormParameter *param, int task_id, void *output);
|
||||
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "nnacl/fp16/common_func.h"
|
||||
|
||||
void ReluFp16(float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
for (int i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
relu_data = vmaxq_f16(relu_data, zero_data);
|
||||
vst1q_f16(dst + index, relu_data);
|
||||
#else
|
||||
data[index] = data[index] < 0 ? 0 : data[index];
|
||||
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
|
||||
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
|
||||
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
data[j] = data[j] < 0 ? 0 : data[j];
|
||||
}
|
||||
}
|
||||
|
||||
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num) {
|
||||
int eight_block = UP_DIV(ele_num, C8NUM);
|
||||
for (int i = 0; i < eight_block - 1; i++) {
|
||||
int index = i * C8NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float16x8_t relu6_data = vld1q_f16(data + index);
|
||||
float16x8_t zero_data = vdupq_n_f16(0);
|
||||
float16x8_t six_data = vdupq_n_f16(6);
|
||||
relu6_data = vmaxq_f16(relu6_data, zero_data);
|
||||
relu6_data = vminq_f16(relu6_data, six_data);
|
||||
vst1q_f16(dst + index, relu6_data);
|
||||
#else
|
||||
for (int j = 0; j < C8NUM; ++j) {
|
||||
data[index + j] = data[index + j] < 0 ? 0 : data[index + j];
|
||||
data[index + j] = data[index + j] > 6 ? 6 : data[index + j];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
|
||||
data[j] = data[j] < 0 ? 0 : data[j];
|
||||
data[j] = data[j] > 6 ? 6 : data[j];
|
||||
}
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP16_COMMON_FUNC_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP16_COMMON_FUNC_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
|
||||
size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu,
|
||||
size_t relu6);
|
||||
void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
|
||||
size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step,
|
||||
size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step,
|
||||
size_t relu, size_t relu6);
|
||||
void DeconvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width,
|
||||
size_t in_kh_step, size_t in_kw_step, size_t kernel_w);
|
||||
void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width,
|
||||
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
|
||||
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
|
||||
#endif
|
||||
void ReluFp16(float16_t *data, float16_t *dst, int ele_num);
|
||||
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_ */
|
|
@ -15,8 +15,62 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/fp16/conv_depthwise_fp16.h"
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/fp16/common_func.h"
|
||||
#include <string.h>
|
||||
#include "nnacl/fp16/activation_fp16.h"
|
||||
|
||||
void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
|
||||
const float16_t *bias_data, const ConvParameter *conv_param, int task_id) {
|
||||
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
|
||||
int h_start = h_step * task_id;
|
||||
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
for (int b = 0; b < conv_param->output_batch_; b++) {
|
||||
const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
|
||||
float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
|
||||
for (int oh = h_start; oh < h_end; oh++) {
|
||||
float16_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
|
||||
|
||||
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
|
||||
|
||||
for (int ow = 0; ow < conv_param->output_w_; ow++) {
|
||||
memcpy(dst_data + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(float16_t));
|
||||
}
|
||||
for (int kh = start_kh; kh < end_kh; kh++) {
|
||||
int ih = ih_origin + conv_param->dilation_w_ * kh;
|
||||
|
||||
const float16_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
|
||||
const float16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;
|
||||
|
||||
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
|
||||
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
|
||||
int out_w_start = MSMAX(
|
||||
0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
|
||||
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
|
||||
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
|
||||
conv_param->stride_w_);
|
||||
|
||||
float16_t *dst_w = dst_data + out_w_start * conv_param->output_channel_;
|
||||
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;
|
||||
|
||||
const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_;
|
||||
int num_pixels = out_w_end - out_w_start;
|
||||
|
||||
ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
|
||||
weight_kh += conv_param->output_channel_;
|
||||
}
|
||||
}
|
||||
if (relu) {
|
||||
ReluFp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
|
||||
}
|
||||
if (relu6) {
|
||||
Relu6Fp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*conv depthwise fp16 begin*/
|
||||
void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
|
||||
|
@ -53,16 +107,18 @@ void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float1
|
|||
void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
|
||||
int bottom, int left, int right, const ConvParameter *conv_param,
|
||||
const SlidingWindowParam *sliding) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
float16_t *dst_h = dst + top * sliding->out_h_step_;
|
||||
for (int oh = top; oh < bottom; oh++) {
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
|
||||
const float16_t *src_h = src + ih * sliding->in_h_step_;
|
||||
|
||||
float16_t *dst_kernel = dst_h + left * sliding->block_channel_;
|
||||
for (int ow = left; ow < right; ow++) {
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
|
||||
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
|
||||
const float16_t *src_w = src_h + iw * sliding->block_channel_;
|
||||
|
@ -72,11 +128,10 @@ void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *
|
|||
#ifdef ENABLE_ARM64
|
||||
ConvDwFp16Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t),
|
||||
conv_param->kernel_w_ * C8NUM * sizeof(float16_t), conv_param->is_relu_, conv_param->is_relu6_);
|
||||
conv_param->kernel_w_ * C8NUM * sizeof(float16_t), relu, relu6);
|
||||
#else
|
||||
DepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM,
|
||||
conv_param->is_relu_, conv_param->is_relu6_);
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM, relu, relu6);
|
||||
#endif
|
||||
dst_kernel += sliding->block_channel_;
|
||||
} // width loop
|
||||
|
@ -139,6 +194,8 @@ void DepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *
|
|||
void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
|
||||
const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
|
||||
int task_id) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
const float16_t *src = input_data;
|
||||
float16_t *dst = output_data;
|
||||
for (int b = 0; b < conv_param->output_batch_; b++) {
|
||||
|
@ -157,8 +214,8 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo
|
|||
conv_param->output_w_, conv_param, sliding);
|
||||
|
||||
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
|
||||
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
const float16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
|
||||
float16_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
|
||||
#ifdef ENABLE_ARM64
|
||||
|
@ -166,12 +223,12 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo
|
|||
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t),
|
||||
sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t),
|
||||
sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t),
|
||||
sliding->in_kw_step_ * sizeof(float16_t), conv_param->is_relu_, conv_param->is_relu6_);
|
||||
sliding->in_kw_step_ * sizeof(float16_t), relu, relu6);
|
||||
#else
|
||||
DepthwiseCenterFp16(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_,
|
||||
sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_,
|
||||
sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_,
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, relu, relu6);
|
||||
#endif
|
||||
}
|
||||
} // output C8 loop
|
||||
|
@ -210,14 +267,14 @@ void DeconvDepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float
|
|||
const SlidingWindowParam *sliding) {
|
||||
const float16_t *src_h = src + top * sliding->out_h_step_;
|
||||
for (int ih = top; ih < bottom; ih++) {
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
|
||||
float16_t *dst_h = dst + oh * sliding->in_h_step_;
|
||||
|
||||
const float16_t *src_kernel = src_h + left * sliding->block_channel_;
|
||||
for (int iw = left; iw < right; iw++) {
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
|
||||
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
|
||||
float16_t *dst_w = dst_h + ow * sliding->block_channel_;
|
||||
|
@ -282,12 +339,14 @@ void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float
|
|||
|
||||
void DeconvDepthwisePostFuncFp16(float16_t *dst, const float16_t *bias, int block_channel,
|
||||
const ConvParameter *conv_param) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
float16_t *dst_k = dst;
|
||||
for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
|
||||
for (int c = 0; c < C8NUM; c++) {
|
||||
dst_k[c] += bias[c];
|
||||
dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
|
||||
dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
|
||||
dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
|
||||
dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
|
||||
}
|
||||
dst_k += block_channel;
|
||||
}
|
||||
|
@ -315,8 +374,8 @@ void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const f
|
|||
conv_param->input_w_, conv_param, sliding);
|
||||
|
||||
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
|
||||
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
float16_t *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
|
||||
const float16_t *in_t =
|
||||
src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
|
||||
|
|
|
@ -23,6 +23,26 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#ifdef ENABLE_ARM64
|
||||
void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels,
|
||||
size_t input_channel, size_t input_step);
|
||||
void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
|
||||
size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu,
|
||||
size_t relu6);
|
||||
void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
|
||||
size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step,
|
||||
size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step,
|
||||
size_t relu, size_t relu6);
|
||||
void DeconvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width,
|
||||
size_t in_kh_step, size_t in_kw_step, size_t kernel_w);
|
||||
void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width,
|
||||
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
|
||||
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
|
||||
#endif
|
||||
|
||||
void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
|
||||
const float16_t *bias_data, const ConvParameter *conv_param, int task_id);
|
||||
|
||||
void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
|
||||
const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
|
||||
int task_id);
|
||||
|
|
|
@ -173,16 +173,18 @@ void SWBorderPixel(float16_t *dst, const float16_t *src, const float16_t *weight
|
|||
|
||||
void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
|
||||
int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
float16_t *dst_h = dst + top * sliding->out_h_step_;
|
||||
for (int oh = top; oh < bottom; oh++) {
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
|
||||
const float16_t *src_h = src + ih * sliding->in_h_step_;
|
||||
|
||||
float16_t *dst_kernel = dst_h + left * sliding->block_channel_;
|
||||
for (int ow = left; ow < right; ow++) {
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
|
||||
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
|
||||
const float16_t *src_w = src_h + iw * sliding->ic4_channel_;
|
||||
|
@ -192,7 +194,7 @@ void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight,
|
|||
|
||||
SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_,
|
||||
sliding->ic4_channel_, conv_param->is_relu_, conv_param->is_relu6_);
|
||||
sliding->ic4_channel_, relu, relu6);
|
||||
|
||||
dst_kernel += sliding->block_channel_;
|
||||
} // width loop
|
||||
|
@ -273,6 +275,8 @@ void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight,
|
|||
void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data,
|
||||
float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param,
|
||||
SlidingWindowParam *slidingWindow_param) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
int oc4_res = conv_param->output_channel_ % C4NUM;
|
||||
const float16_t *src = input_data;
|
||||
float16_t *dst;
|
||||
|
@ -299,8 +303,8 @@ void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, con
|
|||
|
||||
if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
|
||||
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
|
||||
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
const float16_t *in_t =
|
||||
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
|
||||
float16_t *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
|
||||
|
@ -310,7 +314,7 @@ void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, con
|
|||
conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_,
|
||||
slidingWindow_param->ic4_channel_, slidingWindow_param->in_sh_step_,
|
||||
slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
|
||||
slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
|
||||
slidingWindow_param->in_kw_step_, relu, relu6);
|
||||
}
|
||||
} // output C4 loop
|
||||
src += slidingWindow_param->in_step_;
|
||||
|
@ -330,8 +334,8 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
|
|||
int out_h = conv_param->output_h_;
|
||||
int out_w = conv_param->output_w_;
|
||||
int out_channel = conv_param->output_channel_;
|
||||
bool relu = conv_param->is_relu_;
|
||||
bool relu6 = conv_param->is_relu6_;
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
int thread_count = conv_param->thread_num_;
|
||||
const int tile_n = 16;
|
||||
int output_count = out_h * out_w;
|
||||
|
@ -365,9 +369,10 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
|
|||
out_channel * sizeof(float16_t), 0, 0, relu, relu6);
|
||||
} else {
|
||||
// res part
|
||||
IndirectGemmFp16_16x8(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
|
||||
float16_t *tmp_out_ptr = tmp_out_block + task_id * tile_n * out_channel;
|
||||
IndirectGemmFp16_16x8(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
|
||||
out_channel * sizeof(float16_t), 0, 0, relu, relu6);
|
||||
memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float16_t));
|
||||
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float16_t));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -395,7 +400,6 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
|
|||
|
||||
int input_batch = conv_param->input_batch_;
|
||||
for (int batch = 0; batch < input_batch; batch++) {
|
||||
int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
|
||||
int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit;
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
|
||||
int start_index = thread_id * tile_num;
|
||||
|
|
|
@ -73,8 +73,8 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
|
|||
|
||||
for (int ih = 0; ih < conv_param->input_h_; ih++) {
|
||||
for (int iw = 0; iw < conv_param->input_w_; iw++) {
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
|
||||
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
|
||||
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
|
||||
|
@ -112,7 +112,7 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
|
|||
} /*ih*/
|
||||
} /*oc8*/
|
||||
|
||||
PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
|
||||
conv_param->is_relu6_);
|
||||
PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
|
||||
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -21,14 +21,14 @@
|
|||
void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param) {
|
||||
/* support nhwc */
|
||||
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
|
||||
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
if (src_h < 0 || src_h >= conv_param->input_h_) {
|
||||
continue;
|
||||
}
|
||||
const float16_t *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_;
|
||||
float16_t *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_;
|
||||
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
|
||||
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
if (src_w < 0 || src_w >= conv_param->input_w_) {
|
||||
continue;
|
||||
}
|
||||
|
@ -46,44 +46,40 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
|
|||
int kernel_w = conv_param->kernel_w_;
|
||||
int stride_h = conv_param->stride_h_;
|
||||
int stride_w = conv_param->stride_w_;
|
||||
int pad_h = conv_param->pad_h_;
|
||||
int pad_w = conv_param->pad_w_;
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int dilation_h = conv_param->dilation_h_;
|
||||
int dilation_w = conv_param->dilation_w_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int in_h = conv_param->input_h_;
|
||||
int in_w = conv_param->input_w_;
|
||||
int out_w = conv_param->output_w_;
|
||||
int channel_block = UP_DIV(in_channel, 4);
|
||||
int kernel_plane = kernel_h * kernel_w;
|
||||
int ic4 = UP_DIV(in_channel, 4);
|
||||
memset(packed_input, 0, kernel_w * kernel_h * ic4 * C4NUM * 16 * sizeof(float16_t));
|
||||
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int block_start = block_index + i;
|
||||
int input_h = block_start / out_w * stride_h - pad_h;
|
||||
int input_w = block_start % out_w * stride_w - pad_w;
|
||||
for (int j = 0; j < kernel_h; j++) {
|
||||
int input_y = input_h + j * dilation_h;
|
||||
if (input_y < 0 || input_y >= in_h) {
|
||||
continue;
|
||||
}
|
||||
int input_y_stride = input_y * in_w * channel_block * C4NUM;
|
||||
for (int n = 0; n < kernel_w; n++) {
|
||||
int input_x = input_w + n * dilation_w;
|
||||
if (input_x < 0 || input_x >= in_w) {
|
||||
continue;
|
||||
}
|
||||
int input_x_stride = input_y_stride + input_x * channel_block * C4NUM;
|
||||
int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM;
|
||||
for (int m = 0; m < channel_block; m++) {
|
||||
int input_stride = input_h * in_w * ic4 * C4NUM + input_w * ic4 * C4NUM;
|
||||
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
|
||||
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
|
||||
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
|
||||
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
|
||||
for (int j = kh_s; j < kh_e; j++) {
|
||||
int input_y_stride = j * dilation_h * in_w * ic4 * C4NUM + input_stride;
|
||||
for (int n = kw_s; n < kw_e; n++) {
|
||||
int input_x_stride = input_y_stride + n * dilation_w * ic4 * C4NUM;
|
||||
int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * ic4 + i * C4NUM;
|
||||
for (int m = 0; m < ic4; m++) {
|
||||
int channel_block_stride = input_x_stride + m * C4NUM;
|
||||
int channel_block_offset = input_plane_offset + m * 16 * C4NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
vst1_f16(packed_input + channel_block_offset, vld1_f16(input_data + channel_block_stride));
|
||||
#else
|
||||
(packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0];
|
||||
(packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1];
|
||||
(packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2];
|
||||
(packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3];
|
||||
for (int l = 0; l < C4NUM; ++l) {
|
||||
(packed_input + channel_block_offset)[l] = (input_data + channel_block_stride)[l];
|
||||
}
|
||||
#endif
|
||||
} // channel_block loop
|
||||
} // kernel_w loop
|
||||
|
@ -221,6 +217,19 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
|
|||
}
|
||||
}
|
||||
|
||||
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
for (int n = 0; n < batch; n++) {
|
||||
for (int c = 0; c < channel; c++) {
|
||||
for (int hw = 0; hw < plane; hw++) {
|
||||
int nhwc_index = n * channel * plane + hw * channel + c;
|
||||
int nchw_index = n * channel * plane + c * plane + hw;
|
||||
((float16_t *)(dst))[nhwc_index] = ((const float16_t *)(src))[nchw_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int ic4 = UP_DIV(channel, C4NUM);
|
||||
int nhwc4_batch_unit_offset = ic4 * C4NUM * plane;
|
||||
|
|
|
@ -41,6 +41,8 @@ void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
|
|||
|
||||
void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp16/softmax_fp16.h"
|
||||
#include <math.h>
|
||||
#include <float.h>
|
||||
|
||||
// output = exp(input) / reduce_sum(exp(input), axis)
|
||||
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter) {
|
||||
int32_t axis = parameter->axis_;
|
||||
int n_dim = parameter->n_dim_;
|
||||
int ele_size = parameter->element_size_;
|
||||
int *input_shape = parameter->input_shape_;
|
||||
|
||||
float16_t max_data = input_ptr[0];
|
||||
for (int i = 0; i < ele_size; i++) {
|
||||
max_data = max_data > input_ptr[i] ? max_data : input_ptr[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < ele_size; i++) {
|
||||
output_ptr[i] = exp(input_ptr[i] - max_data);
|
||||
}
|
||||
int inner_size = 1, outter_size = 1;
|
||||
for (int i = 0; i < axis; i++) {
|
||||
outter_size *= input_shape[i];
|
||||
}
|
||||
for (int i = axis + 1; i < n_dim; i++) {
|
||||
inner_size *= input_shape[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < outter_size; i++) {
|
||||
int outter_offset = i * input_shape[axis] * inner_size;
|
||||
int sum_outter_offset = i * inner_size;
|
||||
for (int k = 0; k < inner_size; k++) {
|
||||
int inner_offset = outter_offset + k;
|
||||
for (int j = 0; j < input_shape[axis]; j++) {
|
||||
int axis_offset = inner_offset + j * inner_size;
|
||||
sum_data[k + sum_outter_offset] += output_ptr[axis_offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < outter_size; i++) {
|
||||
int outter_offset = i * input_shape[axis] * inner_size;
|
||||
int sum_outter_offset = i * inner_size;
|
||||
for (int j = 0; j < input_shape[axis]; j++) {
|
||||
int axis_offset = outter_offset + j * inner_size;
|
||||
for (int k = 0; k < inner_size; k++) {
|
||||
int inner_offset = axis_offset + k;
|
||||
output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
|
|
@ -230,8 +230,8 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp
|
|||
int input_channel = conv_param->input_channel_;
|
||||
int input_width = conv_param->input_w_;
|
||||
int input_height = conv_param->input_h_;
|
||||
int pad_w = conv_param->pad_w_;
|
||||
int pad_h = conv_param->pad_h_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int ic8 = UP_DIV(input_channel, C8NUM);
|
||||
if (out_w_block == 0) {
|
||||
return;
|
||||
|
@ -576,8 +576,8 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
|
|||
int output_unit = conv_param->output_unit_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int ic8 = UP_DIV(in_channel, C8NUM);
|
||||
int pad_h = conv_param->pad_h_;
|
||||
int pad_w = conv_param->pad_w_;
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int input_h = conv_param->input_h_;
|
||||
int input_w = conv_param->input_w_;
|
||||
if (out_w_block_num == 0) {
|
||||
|
@ -607,7 +607,7 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
|
|||
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
|
||||
int src_x_offset = src_y_offset + j * ic8 * C8NUM;
|
||||
int dst_x_offset = dst_y_offset + j * C8NUM;
|
||||
float16_t *src_addr = input_data + src_x_offset;
|
||||
const float16_t *src_addr = input_data + src_x_offset;
|
||||
float16_t *dst_addr = tmp_data + dst_x_offset;
|
||||
#ifdef ENABLE_NEON
|
||||
vst1q_f16(dst_addr, vld1q_f16(src_addr));
|
||||
|
|
|
@ -43,15 +43,33 @@ int LRelu(const float *src, int length, float *dst, float alpha) {
|
|||
}
|
||||
|
||||
int Sigmoid(const float *src, int length, float *dst) {
|
||||
const float upper_bound = 16.619047164916992188f;
|
||||
const float lower_bound = -9.0f;
|
||||
for (int i = 0; i < length; ++i) {
|
||||
dst[i] = 1.0f / (1.0f + exp(-src[i]));
|
||||
float input_val = src[i];
|
||||
float result;
|
||||
if (input_val > upper_bound) {
|
||||
result = 1.0f;
|
||||
} else if (input_val < lower_bound) {
|
||||
result = exp(input_val);
|
||||
} else {
|
||||
result = 1.0f / (1.0f + exp(-input_val));
|
||||
}
|
||||
dst[i] = result;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int Tanh(const float *src, int length, float *dst) {
|
||||
for (int i = 0; i < length; ++i) {
|
||||
dst[i] = 1.0f - 2.0f / (exp(2 * src[i]) + 1);
|
||||
float tmp_in = src[i];
|
||||
if (tmp_in > 5.0) {
|
||||
dst[i] = 1.0f;
|
||||
} else if (tmp_in < -5.0) {
|
||||
dst[i] = -1.0f;
|
||||
} else {
|
||||
dst[i] = 1.0f - 2.0f / (exp(2 * tmp_in) + 1);
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -20,53 +20,453 @@
|
|||
#define ACCURACY_DATA 0.00000001
|
||||
|
||||
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[0] * input1[i];
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vmulq_f32(vin0, vin1);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = in0_opt * input1[i];
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
} else if (param->in_elements_num1_ == 1) {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[i] * input1[0];
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = in0_opt * input1[index];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[i] * input1[i];
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vmulq_f32(vin0, vin1);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = input0[i] * in1_opt;
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = input0[index] * in1_opt;
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
float32x4_t zeros = {0, 0, 0, 0};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1), zeros);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMAX(in0_opt * input1[i], 0);
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMAX(in0_opt * input1[index], 0);
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1), zeros);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMAX(input0[i] * in1_opt, 0);
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMAX(input0[index] * in1_opt, 0);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
float32x4_t zeros = {0, 0, 0, 0};
|
||||
float32x4_t bounds = {6, 6, 6, 6};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6);
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6);
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6);
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[0] - input1[i];
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vsubq_f32(vin0, vin1);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = in0_opt - input1[i];
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
} else if (param->in_elements_num1_ == 1) {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[i] - input1[0];
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = in0_opt - input1[index];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[i] - input1[i];
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vsubq_f32(vin0, vin1);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = input0[i] - in1_opt;
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = input0[index] - in1_opt;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
float32x4_t zeros = {0, 0, 0, 0};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[0] + input1[i];
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vmaxq_f32(vsubq_f32(vin0, vin1), zeros);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMAX(in0_opt - input1[i], 0);
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
} else if (param->in_elements_num1_ == 1) {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[i] + input1[0];
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMAX(in0_opt - input1[index], 0);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < element_size; ++i) {
|
||||
output[i] = input0[i] + input1[i];
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vmaxq_f32(vsubq_f32(vin0, vin1), zeros);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMAX(input0[i] - in1_opt, 0);
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMAX(input0[index] - in1_opt, 0);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
float32x4_t zeros = {0, 0, 0, 0};
|
||||
float32x4_t bounds = {6, 6, 6, 6};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(in0_opt - input1[i], 0), 6);
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(in0_opt - input1[index], 0), 6);
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(input0[i] - in1_opt, 0), 6);
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(input0[index] - in1_opt, 0), 6);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vaddq_f32(vin0, vin1);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = in0_opt + input1[i];
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = in0_opt + input1[index];
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vaddq_f32(vin0, vin1);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = input0[i] + in1_opt;
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = input0[index] + in1_opt;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
float32x4_t zeros = {0, 0, 0, 0};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1), zeros);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMAX(in0_opt + input1[i], 0);
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMAX(in0_opt + input1[index], 0);
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1), zeros);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMAX(input0[i] + in1_opt, 0);
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMAX(input0[index] + in1_opt, 0);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
|
||||
int block_mod = element_size % C4NUM;
|
||||
int block_c4 = element_size - block_mod;
|
||||
float in0_opt = input0[0];
|
||||
float in1_opt = input1[0];
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
|
||||
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
|
||||
float32x4_t zeros = {0, 0, 0, 0};
|
||||
float32x4_t bounds = {6, 6, 6, 6};
|
||||
#endif
|
||||
if (param->in_elements_num0_ == 1) {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vin0_opt;
|
||||
float32x4_t vin1 = vld1q_f32(input1);
|
||||
float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(in0_opt + input1[i], 0), 6);
|
||||
}
|
||||
#endif
|
||||
input1 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(in0_opt + input1[index], 0), 6);
|
||||
}
|
||||
} else {
|
||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t vin0 = vld1q_f32(input0);
|
||||
float32x4_t vin1 = vin1_opt;
|
||||
float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds);
|
||||
vst1q_f32(output, vout);
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; ++i) {
|
||||
output[i] = MSMIN(MSMAX(input0[i] + in1_opt, 0), 6);
|
||||
}
|
||||
#endif
|
||||
input0 += C4NUM;
|
||||
output += C4NUM;
|
||||
}
|
||||
for (int index = 0; index < block_mod; ++index) {
|
||||
output[index] = MSMIN(MSMAX(input0[index] + in1_opt, 0), 6);
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -27,8 +27,14 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
|
||||
int ElementMul(float *input0, float *input1, float *output, int element_size);
|
||||
int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
|
||||
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "nnacl/fp32/batchnorm.h"
|
||||
#include "nnacl/fp16/batchnorm_fp16.h"
|
||||
#include <math.h>
|
||||
#include "nnacl/batchnorm_parameter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <string.h>
|
||||
#include "nnacl/fp32/common_func.h"
|
||||
#include "nnacl/winograd_transform.h"
|
||||
#include "nnacl/fp32/matmul.h"
|
||||
|
||||
void SWBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
|
||||
int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic4, bool is_relu, bool is_relu6) {
|
||||
|
@ -57,16 +58,18 @@ void SWBorderPixel(float *dst, const float *src, const float *weight, const floa
|
|||
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
|
||||
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
|
||||
int ic4 = sliding->ic4_channel_ / C4NUM;
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
float *dst_h = dst + top * sliding->out_h_step_;
|
||||
for (int oh = top; oh < bottom; oh++) {
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
|
||||
const float *src_h = src + ih * sliding->in_h_step_;
|
||||
|
||||
float *dst_kernel = dst_h + left * sliding->block_channel_;
|
||||
for (int ow = left; ow < right; ow++) {
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
|
||||
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
|
||||
const float *src_w = src_h + iw * sliding->ic4_channel_;
|
||||
|
@ -75,8 +78,8 @@ void SWBorder(float *dst, const float *src, const float *weight, const float *bi
|
|||
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_;
|
||||
|
||||
SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4,
|
||||
conv_param->is_relu_, conv_param->is_relu6_);
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4, relu,
|
||||
relu6);
|
||||
|
||||
dst_kernel += sliding->block_channel_;
|
||||
} // width loop
|
||||
|
@ -144,6 +147,8 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
|
|||
float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param) {
|
||||
int ic4 = slidingWindow_param->ic4_channel_ / C4NUM;
|
||||
int oc4_res = conv_param->output_channel_ % C4NUM;
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
const float *src = input_data;
|
||||
float *dst = NULL;
|
||||
if (oc4_res == 0) {
|
||||
|
@ -169,28 +174,26 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
|
|||
|
||||
if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
|
||||
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
|
||||
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
const float *in_t =
|
||||
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
|
||||
float *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
|
||||
slidingWindow_param->left_ * slidingWindow_param->block_channel_;
|
||||
#ifdef ENABLE_ARM64
|
||||
ConvSwFp32Center(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
|
||||
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
|
||||
conv_param->kernel_w_, slidingWindow_param->out_h_step_ * sizeof(float),
|
||||
slidingWindow_param->block_channel_ * sizeof(float), ic4,
|
||||
slidingWindow_param->in_sh_step_ * sizeof(float),
|
||||
slidingWindow_param->in_sw_step_ * sizeof(float),
|
||||
slidingWindow_param->in_kh_step_ * sizeof(float),
|
||||
slidingWindow_param->in_kw_step_ * sizeof(float),
|
||||
conv_param->is_relu_, conv_param->is_relu6_);
|
||||
ConvSwFp32Center(
|
||||
out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
|
||||
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
|
||||
slidingWindow_param->out_h_step_ * sizeof(float), slidingWindow_param->block_channel_ * sizeof(float), ic4,
|
||||
slidingWindow_param->in_sh_step_ * sizeof(float), slidingWindow_param->in_sw_step_ * sizeof(float),
|
||||
slidingWindow_param->in_kh_step_ * sizeof(float), slidingWindow_param->in_kw_step_ * sizeof(float), relu,
|
||||
relu6);
|
||||
#else
|
||||
SWCenter(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
|
||||
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
|
||||
conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
|
||||
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
|
||||
slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
|
||||
slidingWindow_param->in_sh_step_, slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
|
||||
slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
|
||||
slidingWindow_param->in_kw_step_, relu, relu6);
|
||||
#endif
|
||||
}
|
||||
} // output C4 loop
|
||||
|
@ -219,6 +222,8 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
|
|||
int kernel_plane = kernel_h * kernel_w;
|
||||
int unit_size = kernel_plane * ic4 * C4NUM;
|
||||
int packed_input_size = output_tile_count * TILE_NUM * unit_size;
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
|
||||
// we accumulate 4 channels per time for input blocks
|
||||
int conv_depth = kernel_h * kernel_w;
|
||||
|
@ -240,23 +245,18 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
|
|||
if (real_cal_num == TILE_NUM) {
|
||||
float *gemm_output = output_data + out_offset;
|
||||
gemm_func(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0,
|
||||
conv_param->is_relu_, conv_param->is_relu6_);
|
||||
relu, relu6);
|
||||
} else {
|
||||
// res part
|
||||
gemm_func(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0,
|
||||
0, conv_param->is_relu_, conv_param->is_relu6_);
|
||||
memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float));
|
||||
float *tmp_out_ptr = tmp_out_block + task_id * TILE_NUM * out_channel;
|
||||
gemm_func(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0,
|
||||
relu, relu6);
|
||||
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fp32 conv1x1 strassen matmul
|
||||
int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
|
||||
StrassenMatMulParameter matmul_param) {
|
||||
return StrassenMatmul(input_data, weight_data, output_data, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_ptr);
|
||||
}
|
||||
|
||||
// fp32 conv winograd
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
|
||||
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
|
||||
|
@ -270,38 +270,46 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
|
|||
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
|
||||
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
|
||||
int output_count = out_w_block * out_h_block;
|
||||
int output_tile_count = UP_DIV(output_count, TILE_NUM);
|
||||
int output_tile_count = UP_DIV(output_count, C12NUM);
|
||||
int out_channel = conv_param->output_channel_;
|
||||
int oc4 = UP_DIV(out_channel, C4NUM);
|
||||
int oc8 = UP_DIV(out_channel, C8NUM);
|
||||
int input_unit_square = input_unit * input_unit;
|
||||
size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float);
|
||||
|
||||
float *trans_input = buffer_list[0];
|
||||
float *gemm_out = buffer_list[1];
|
||||
float *tmp_out_data = buffer_list[2];
|
||||
float *tmp_data = buffer_list[3];
|
||||
int trans_input_offset = TILE_NUM * input_unit_square * ic4 * C4NUM;
|
||||
int gemm_out_offset = TILE_NUM * input_unit_square * oc4 * C4NUM;
|
||||
float *col_buffer = buffer_list[4];
|
||||
int trans_input_offset = C12NUM * input_unit_square * ic4 * C4NUM;
|
||||
int gemm_out_offset = C12NUM * input_unit_square * oc8 * C8NUM;
|
||||
int tmp_data_offset = input_unit_square * C4NUM;
|
||||
int col_buffer_offset = C12NUM * ic4 * C4NUM;
|
||||
// step 1 : filter transform (pre-processed offline)
|
||||
// step 2 : input transform (online)
|
||||
for (int b = 0; b < in_batch; b++) {
|
||||
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
|
||||
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
|
||||
int out_tile_index = thread_id * TILE_NUM;
|
||||
int cal_num = output_count - thread_id * TILE_NUM;
|
||||
cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num;
|
||||
int out_tile_index = thread_id * C12NUM;
|
||||
int cal_num = output_count - thread_id * C12NUM;
|
||||
cal_num = cal_num > C12NUM ? C12NUM : cal_num;
|
||||
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
|
||||
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
input_trans_func);
|
||||
// step 3 : gemm
|
||||
gemm_func(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, trans_weight, NULL,
|
||||
input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0);
|
||||
float *src_ptr = trans_input + task_id * trans_input_offset;
|
||||
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
|
||||
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
|
||||
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
|
||||
cal_num, oc8 * C8NUM, input_unit_square, 2);
|
||||
}
|
||||
|
||||
// step 4 : output transform
|
||||
WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
|
||||
cal_num, out_tile_index, out_w_block, conv_param, output_trans_func);
|
||||
WinogradOutputTransform(dst_ptr, tmp_out_data + tmp_out_batch_offset, bias_data, cal_num, out_tile_index,
|
||||
out_w_block, conv_param, output_trans_func);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -442,24 +450,28 @@ void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int heig
|
|||
}
|
||||
|
||||
// fp32 conv3x3
|
||||
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
|
||||
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func) {
|
||||
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
|
||||
int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func) {
|
||||
int thread_count = conv_param->thread_num_;
|
||||
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
|
||||
int output_channel = conv_param->output_channel_;
|
||||
int oc4 = UP_DIV(output_channel, C4NUM);
|
||||
int oc8 = UP_DIV(output_channel, C8NUM);
|
||||
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
|
||||
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
|
||||
int output_count = out_w_block * out_h_block;
|
||||
int output_tile_count = UP_DIV(output_count, TILE_NUM);
|
||||
int output_tile_count = UP_DIV(output_count, C12NUM);
|
||||
const int input_unit_square = 4 * 4;
|
||||
|
||||
float *tile_buffer = buffer_list[0];
|
||||
float *block_unit_buffer = buffer_list[1];
|
||||
float *tmp_dst_buffer = buffer_list[2];
|
||||
float *nc4hw4_out = buffer_list[3];
|
||||
int tile_buffer_offset = TILE_NUM * input_unit_square * ic4 * C4NUM;
|
||||
float *col_buffer = buffer_list[4];
|
||||
int tile_buffer_offset = C12NUM * input_unit_square * ic4 * C4NUM;
|
||||
int block_unit_buffer_offset = input_unit_square * C4NUM;
|
||||
int tmp_dst_buffer_offset = TILE_NUM * input_unit_square * oc4 * C4NUM;
|
||||
int tmp_dst_buffer_offset = C12NUM * input_unit_square * oc8 * C8NUM;
|
||||
int col_buffer_offset = C12NUM * ic4 * C4NUM;
|
||||
|
||||
int input_batch = conv_param->input_batch_;
|
||||
for (int batch = 0; batch < input_batch; batch++) {
|
||||
|
@ -467,18 +479,22 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
|
|||
int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;
|
||||
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
|
||||
int start_index = thread_id * TILE_NUM;
|
||||
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
|
||||
int start_index = thread_id * C12NUM;
|
||||
int real_cal_num = (output_count - start_index) < C12NUM ? (output_count - start_index) : C12NUM;
|
||||
Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
|
||||
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
|
||||
out_w_block, conv_param);
|
||||
|
||||
gemm_func(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
|
||||
transed_weight, NULL, input_unit_square, ic4, oc4 * C4NUM,
|
||||
oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0);
|
||||
|
||||
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
|
||||
bias_data, start_index, real_cal_num, out_w_block, conv_param);
|
||||
float *src_ptr = tile_buffer + task_id * tile_buffer_offset;
|
||||
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
||||
float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
|
||||
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
|
||||
ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
|
||||
}
|
||||
Conv3x3Fp32OutputTransform(dst_ptr, nc4hw4_out + nc4hw4_buffer_offset, bias_data, start_index, real_cal_num,
|
||||
out_w_block, conv_param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/fp32/strassen_matmul.h"
|
||||
#include "nnacl/winograd_utils.h"
|
||||
#include "nnacl/fp32/conv_depthwise.h"
|
||||
|
||||
|
@ -52,10 +51,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
|
|||
float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
|
||||
GEMM_FUNC_FP32 gemm_func);
|
||||
|
||||
// fp32 conv1x1 strassen matmul
|
||||
int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
|
||||
StrassenMatMulParameter matmul_param);
|
||||
|
||||
// fp32 convolution winograd
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
|
||||
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
|
||||
|
@ -70,8 +65,8 @@ void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int heig
|
|||
int output_unit);
|
||||
|
||||
// fp32 conv3x3
|
||||
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
|
||||
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func);
|
||||
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
|
||||
int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -38,13 +38,15 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
|
|||
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
|
||||
int h_start = h_step * task_id;
|
||||
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
for (int b = 0; b < conv_param->output_batch_; b++) {
|
||||
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
|
||||
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
|
||||
for (int oh = h_start; oh < h_end; oh++) {
|
||||
float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
|
||||
|
||||
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
|
||||
|
||||
|
@ -60,13 +62,13 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
|
|||
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
|
||||
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
|
||||
int out_w_start = MSMAX(
|
||||
0, (conv_param->pad_w_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
|
||||
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_w_ -
|
||||
0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
|
||||
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
|
||||
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
|
||||
conv_param->stride_w_);
|
||||
|
||||
float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
|
||||
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_w_ + conv_param->dilation_w_ * kw;
|
||||
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;
|
||||
|
||||
const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
|
||||
int num_pixels = out_w_end - out_w_start;
|
||||
|
@ -75,10 +77,10 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
|
|||
weight_kh += conv_param->output_channel_;
|
||||
}
|
||||
}
|
||||
if (conv_param->is_relu_) {
|
||||
if (relu) {
|
||||
ReluFp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
|
||||
}
|
||||
if (conv_param->is_relu6_) {
|
||||
if (relu6) {
|
||||
Relu6Fp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
|
||||
}
|
||||
}
|
||||
|
@ -91,16 +93,16 @@ void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_par
|
|||
int top = 0;
|
||||
int bottom = conv_param->output_h_;
|
||||
|
||||
for (; left * conv_param->stride_w_ < conv_param->pad_w_; left++) {
|
||||
for (; left * conv_param->stride_w_ < conv_param->pad_l_; left++) {
|
||||
}
|
||||
for (; (right - 1) * conv_param->stride_w_ - conv_param->pad_w_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
|
||||
for (; (right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
|
||||
conv_param->input_w_ &&
|
||||
right > left;
|
||||
right--) {
|
||||
}
|
||||
for (; top * conv_param->stride_h_ < conv_param->pad_h_; top++) {
|
||||
for (; top * conv_param->stride_h_ < conv_param->pad_u_; top++) {
|
||||
}
|
||||
for (; (bottom - 1) * conv_param->stride_h_ - conv_param->pad_h_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
|
||||
for (; (bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
|
||||
conv_param->input_h_ &&
|
||||
bottom > top;
|
||||
bottom--) {
|
||||
|
@ -181,16 +183,18 @@ void DepthwiseBorderPixel(float *dst, const float *src, const float *weight, con
|
|||
|
||||
void DepthwiseBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom,
|
||||
int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
float *dst_h = dst + top * sliding->out_h_step_;
|
||||
for (int oh = top; oh < bottom; oh++) {
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
|
||||
const float *src_h = src + ih * sliding->in_h_step_;
|
||||
|
||||
float *dst_kernel = dst_h + left * sliding->block_channel_;
|
||||
for (int ow = left; ow < right; ow++) {
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
|
||||
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
|
||||
const float *src_w = src_h + iw * sliding->block_channel_;
|
||||
|
@ -201,11 +205,10 @@ void DepthwiseBorder(float *dst, const float *src, const float *weight, const fl
|
|||
#ifdef ENABLE_ARM64
|
||||
ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
|
||||
conv_param->kernel_w_ * C4NUM * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_);
|
||||
conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
|
||||
#else
|
||||
DepthwiseBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM,
|
||||
conv_param->is_relu_, conv_param->is_relu6_);
|
||||
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6);
|
||||
#endif
|
||||
dst_kernel += sliding->block_channel_;
|
||||
} // width loop
|
||||
|
@ -259,6 +262,8 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl
|
|||
// conv depthwise fp32: sliding window
|
||||
void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
const float *src = input_data;
|
||||
float *dst = output_data;
|
||||
for (int b = 0; b < conv_param->output_batch_; b++) {
|
||||
|
@ -277,8 +282,8 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
|
|||
conv_param->output_w_, conv_param, sliding);
|
||||
|
||||
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
|
||||
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
|
||||
float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
|
||||
#ifdef ENABLE_ARM64
|
||||
|
@ -286,12 +291,12 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
|
|||
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
|
||||
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
|
||||
sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
|
||||
sliding->in_kw_step_ * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_);
|
||||
sliding->in_kw_step_ * sizeof(float), relu, relu6);
|
||||
#else
|
||||
DepthwiseCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
|
||||
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
|
||||
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_,
|
||||
conv_param->is_relu_, conv_param->is_relu6_);
|
||||
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu,
|
||||
relu6);
|
||||
#endif
|
||||
}
|
||||
} // output C4 loop
|
||||
|
@ -302,399 +307,6 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
|
|||
}
|
||||
/*conv depthwise fp32 end*/
|
||||
|
||||
/*conv depthwise 3x3 fp32 begin*/
|
||||
void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4) {
|
||||
for (int c = 0; c < oc4; c++) {
|
||||
float *src = weight + c * C4NUM * 9;
|
||||
float *dst = trans_weight + c * C4NUM * 16;
|
||||
#ifdef ENABLE_ARM
|
||||
float32x4_t g00 = vld1q_f32(src);
|
||||
float32x4_t g01 = vld1q_f32(src + 4);
|
||||
float32x4_t g02 = vld1q_f32(src + 2 * 4);
|
||||
float32x4_t g10 = vld1q_f32(src + 3 * 4);
|
||||
float32x4_t g11 = vld1q_f32(src + 4 * 4);
|
||||
float32x4_t g12 = vld1q_f32(src + 5 * 4);
|
||||
float32x4_t g20 = vld1q_f32(src + 6 * 4);
|
||||
float32x4_t g21 = vld1q_f32(src + 7 * 4);
|
||||
float32x4_t g22 = vld1q_f32(src + 8 * 4);
|
||||
|
||||
float32x4_t dst00 = g00;
|
||||
float32x4_t dst01 = g01;
|
||||
float32x4_t dst02 = g02;
|
||||
|
||||
float32x4_t dst10 = vaddq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5));
|
||||
dst10 = vaddq_f32(dst10, vmulq_n_f32(g20, 0.5));
|
||||
float32x4_t dst11 = vaddq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5));
|
||||
dst11 = vaddq_f32(dst11, vmulq_n_f32(g21, 0.5));
|
||||
float32x4_t dst12 = vaddq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5));
|
||||
dst12 = vaddq_f32(dst12, vmulq_n_f32(g22, 0.5));
|
||||
|
||||
float32x4_t dst20 = vsubq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5));
|
||||
dst20 = vaddq_f32(dst20, vmulq_n_f32(g20, 0.5));
|
||||
float32x4_t dst21 = vsubq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5));
|
||||
dst21 = vaddq_f32(dst21, vmulq_n_f32(g21, 0.5));
|
||||
float32x4_t dst22 = vsubq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5));
|
||||
dst22 = vaddq_f32(dst22, vmulq_n_f32(g22, 0.5));
|
||||
|
||||
float32x4_t dst30 = g20;
|
||||
float32x4_t dst31 = g21;
|
||||
float32x4_t dst32 = g22;
|
||||
|
||||
float32x4_t m00 = dst00;
|
||||
float32x4_t m01 = vaddq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5));
|
||||
m01 = vaddq_f32(m01, vmulq_n_f32(dst02, 0.5));
|
||||
float32x4_t m02 = vsubq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5));
|
||||
m02 = vaddq_f32(m02, vmulq_n_f32(dst02, 0.5));
|
||||
float32x4_t m03 = dst02;
|
||||
|
||||
float32x4_t m10 = dst10;
|
||||
float32x4_t m11 = vaddq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5));
|
||||
m11 = vaddq_f32(m11, vmulq_n_f32(dst12, 0.5));
|
||||
float32x4_t m12 = vsubq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5));
|
||||
m12 = vaddq_f32(m12, vmulq_n_f32(dst12, 0.5));
|
||||
float32x4_t m13 = dst12;
|
||||
|
||||
float32x4_t m20 = dst20;
|
||||
float32x4_t m21 = vaddq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5));
|
||||
m21 = vaddq_f32(m21, vmulq_n_f32(dst22, 0.5));
|
||||
float32x4_t m22 = vsubq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5));
|
||||
m22 = vaddq_f32(m22, vmulq_n_f32(dst22, 0.5));
|
||||
float32x4_t m23 = dst22;
|
||||
|
||||
float32x4_t m30 = dst30;
|
||||
float32x4_t m31 = vaddq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5));
|
||||
m31 = vaddq_f32(m31, vmulq_n_f32(dst32, 0.5));
|
||||
float32x4_t m32 = vsubq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5));
|
||||
m32 = vaddq_f32(m32, vmulq_n_f32(dst32, 0.5));
|
||||
float32x4_t m33 = dst32;
|
||||
|
||||
vst1q_f32(dst, m00);
|
||||
vst1q_f32(dst + 4, m01);
|
||||
vst1q_f32(dst + 8, m02);
|
||||
vst1q_f32(dst + 12, m03);
|
||||
vst1q_f32(dst + 16, m10);
|
||||
vst1q_f32(dst + 20, m11);
|
||||
vst1q_f32(dst + 24, m12);
|
||||
vst1q_f32(dst + 28, m13);
|
||||
vst1q_f32(dst + 32, m20);
|
||||
vst1q_f32(dst + 36, m21);
|
||||
vst1q_f32(dst + 40, m22);
|
||||
vst1q_f32(dst + 44, m23);
|
||||
vst1q_f32(dst + 48, m30);
|
||||
vst1q_f32(dst + 52, m31);
|
||||
vst1q_f32(dst + 56, m32);
|
||||
vst1q_f32(dst + 60, m33);
|
||||
#else
|
||||
for (int j = 0; j < C4NUM; j++) {
|
||||
float *local_ptr = src + j;
|
||||
float dst00 = local_ptr[0];
|
||||
float dst01 = (local_ptr + 4)[0];
|
||||
float dst02 = (local_ptr + 8)[0];
|
||||
|
||||
const float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
|
||||
const float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
|
||||
const float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
|
||||
|
||||
const float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
|
||||
const float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
|
||||
const float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
|
||||
|
||||
float dst30 = (local_ptr + 24)[0];
|
||||
float dst31 = (local_ptr + 28)[0];
|
||||
float dst32 = (local_ptr + 32)[0];
|
||||
|
||||
float m00 = dst00;
|
||||
const float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02;
|
||||
const float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02;
|
||||
float m03 = dst02;
|
||||
|
||||
float m10 = dst10;
|
||||
const float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12;
|
||||
const float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12;
|
||||
float m13 = dst12;
|
||||
|
||||
float m20 = dst20;
|
||||
const float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22;
|
||||
const float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22;
|
||||
float m23 = dst22;
|
||||
|
||||
float m30 = dst30;
|
||||
const float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32;
|
||||
const float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32;
|
||||
float m33 = dst32;
|
||||
|
||||
*(dst + j) = m00;
|
||||
*(dst + j + 4) = m01;
|
||||
*(dst + j + 8) = m02;
|
||||
*(dst + j + 12) = m03;
|
||||
|
||||
*(dst + j + 16) = m10;
|
||||
*(dst + j + 20) = m11;
|
||||
*(dst + j + 24) = m12;
|
||||
*(dst + j + 28) = m13;
|
||||
|
||||
*(dst + j + 32) = m20;
|
||||
*(dst + j + 36) = m21;
|
||||
*(dst + j + 40) = m22;
|
||||
*(dst + j + 44) = m23;
|
||||
|
||||
*(dst + j + 48) = m30;
|
||||
*(dst + j + 52) = m31;
|
||||
*(dst + j + 56) = m32;
|
||||
*(dst + j + 60) = m33;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDw3x3Fp32InputTrans(const float *input_data, float *trans_input, float *block_buffer, int out_h_block,
|
||||
int out_w_block, const ConvParameter *conv_param) {
|
||||
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
|
||||
const int input_unit = 4;
|
||||
memset(trans_input, 0, out_h_block * out_h_block * 16 * C4NUM * sizeof(float));
|
||||
|
||||
for (int oh = 0; oh < out_h_block; oh++) {
|
||||
int ih = oh * 2 - conv_param->pad_h_;
|
||||
int real_h_start = ih > 0 ? 0 : -ih;
|
||||
int real_h_end = (ih + input_unit) < conv_param->input_h_ ? input_unit : (conv_param->input_h_ - ih);
|
||||
for (int ow = 0; ow < out_w_block; ow++) {
|
||||
int iw = ow * 2 - conv_param->pad_w_;
|
||||
int real_w_start = iw > 0 ? 0 : -iw;
|
||||
int real_w_end = (iw + input_unit) < conv_param->input_w_ ? input_unit : (conv_param->input_w_ - iw);
|
||||
|
||||
memset(block_buffer, 0, 16 * C4NUM * sizeof(float));
|
||||
int src_plane_offset = ic4 * C4NUM * (ih * conv_param->input_w_ + iw);
|
||||
for (int h = real_h_start; h < real_h_end; h++) {
|
||||
int src_h_offset = src_plane_offset + (h * conv_param->input_w_) * ic4 * C4NUM;
|
||||
int dst_h_offset = (h * input_unit) * C4NUM;
|
||||
for (int w = real_w_start; w < real_w_end; w++) {
|
||||
int src_w_offset = src_h_offset + w * ic4 * C4NUM;
|
||||
int dst_w_offset = dst_h_offset + w * C4NUM;
|
||||
float *src_addr = (float *)(input_data) + src_w_offset;
|
||||
float *dst_addr = block_buffer + dst_w_offset;
|
||||
#ifdef ENABLE_NEON
|
||||
vst1q_f32(dst_addr, vld1q_f32(src_addr));
|
||||
#else
|
||||
for (int k = 0; k < C4NUM; k++) {
|
||||
(dst_addr + k)[0] = (src_addr + k)[0];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
int trans_offset = (oh * out_w_block + ow) * 16 * C4NUM;
|
||||
Conv3x3Fp32InputUnit(block_buffer, trans_input + trans_offset, C4NUM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDw3x3Fp32Winograd(float *trans_buffer, const float *weight, int out_h_block, int out_w_block) {
|
||||
const int unit = 4;
|
||||
for (int oh = 0; oh < out_h_block; oh++) {
|
||||
float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM;
|
||||
for (int ow = 0; ow < out_w_block; ow++) {
|
||||
float *buf_ow = buf_oh + ow * 16 * C4NUM;
|
||||
for (int kh = 0; kh < unit; kh++) {
|
||||
float *buf_kh = buf_ow + kh * unit * C4NUM;
|
||||
const float *weight_kh = weight + kh * unit * C4NUM;
|
||||
for (int kw = 0; kw < unit; kw++) {
|
||||
float *buf_kw = buf_kh + kw * C4NUM;
|
||||
const float *weight_kw = weight_kh + kw * C4NUM;
|
||||
for (int c = 0; c < C4NUM; c++) {
|
||||
buf_kw[c] = buf_kw[c] * weight_kw[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDw3x3Fp32OutputUnit(float *src_buf, float *dst_output, const float *bias, int channel, int output_w,
|
||||
bool h_in_range, bool w_in_range, bool is_relu, bool is_relu6) {
|
||||
#ifdef ENABLE_ARM
|
||||
float32x4_t bias_ptr = vld1q_f32(bias);
|
||||
|
||||
float32x4_t s00 = vld1q_f32(src_buf);
|
||||
float32x4_t s01 = vld1q_f32(src_buf + 4);
|
||||
float32x4_t s02 = vld1q_f32(src_buf + 8);
|
||||
float32x4_t s03 = vld1q_f32(src_buf + 12);
|
||||
|
||||
float32x4_t s10 = vld1q_f32(src_buf + 16);
|
||||
float32x4_t s11 = vld1q_f32(src_buf + 20);
|
||||
float32x4_t s12 = vld1q_f32(src_buf + 24);
|
||||
float32x4_t s13 = vld1q_f32(src_buf + 28);
|
||||
|
||||
float32x4_t s20 = vld1q_f32(src_buf + 32);
|
||||
float32x4_t s21 = vld1q_f32(src_buf + 36);
|
||||
float32x4_t s22 = vld1q_f32(src_buf + 40);
|
||||
float32x4_t s23 = vld1q_f32(src_buf + 44);
|
||||
|
||||
float32x4_t s30 = vld1q_f32(src_buf + 48);
|
||||
float32x4_t s31 = vld1q_f32(src_buf + 52);
|
||||
float32x4_t s32 = vld1q_f32(src_buf + 56);
|
||||
float32x4_t s33 = vld1q_f32(src_buf + 60);
|
||||
|
||||
float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20);
|
||||
float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21);
|
||||
float32x4_t t02 = vaddq_f32(vaddq_f32(s02, s12), s22);
|
||||
float32x4_t t03 = vaddq_f32(vaddq_f32(s03, s13), s23);
|
||||
|
||||
float32x4_t t10 = vsubq_f32(vsubq_f32(s10, s20), s30);
|
||||
float32x4_t t11 = vsubq_f32(vsubq_f32(s11, s21), s31);
|
||||
float32x4_t t12 = vsubq_f32(vsubq_f32(s12, s22), s32);
|
||||
float32x4_t t13 = vsubq_f32(vsubq_f32(s13, s23), s33);
|
||||
|
||||
float32x4_t d00 = vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), bias_ptr);
|
||||
float32x4_t d01 = vaddq_f32(vsubq_f32(vsubq_f32(t01, t02), t03), bias_ptr);
|
||||
float32x4_t d10 = vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), bias_ptr);
|
||||
float32x4_t d11 = vaddq_f32(vsubq_f32(vsubq_f32(t11, t12), t13), bias_ptr);
|
||||
|
||||
float32x4_t zeros = {0, 0, 0, 0};
|
||||
float32x4_t bounds = {6, 6, 6, 6};
|
||||
if (is_relu) {
|
||||
d00 = vmaxq_f32(d00, zeros);
|
||||
d01 = vmaxq_f32(d01, zeros);
|
||||
d10 = vmaxq_f32(d10, zeros);
|
||||
d11 = vmaxq_f32(d11, zeros);
|
||||
}
|
||||
if (is_relu6) {
|
||||
d00 = vminq_f32(vmaxq_f32(d00, zeros), bounds);
|
||||
d01 = vminq_f32(vmaxq_f32(d01, zeros), bounds);
|
||||
d10 = vminq_f32(vmaxq_f32(d10, zeros), bounds);
|
||||
d11 = vminq_f32(vmaxq_f32(d11, zeros), bounds);
|
||||
}
|
||||
|
||||
vst1q_f32(dst_output, d00);
|
||||
if (w_in_range) {
|
||||
vst1q_f32(dst_output + channel, d01);
|
||||
}
|
||||
if (h_in_range) {
|
||||
vst1q_f32(dst_output + output_w * channel, d10);
|
||||
if (w_in_range) {
|
||||
vst1q_f32(dst_output + output_w * channel + channel, d11);
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
const float *local_ptr = src_buf + i;
|
||||
const float *bias_ptr = bias + i;
|
||||
|
||||
float s00 = local_ptr[0];
|
||||
float s01 = (local_ptr + 4)[0];
|
||||
float s02 = (local_ptr + 8)[0];
|
||||
float s03 = (local_ptr + 12)[0];
|
||||
|
||||
float s10 = (local_ptr + 16)[0];
|
||||
float s11 = (local_ptr + 20)[0];
|
||||
float s12 = (local_ptr + 24)[0];
|
||||
float s13 = (local_ptr + 28)[0];
|
||||
|
||||
float s20 = (local_ptr + 32)[0];
|
||||
float s21 = (local_ptr + 36)[0];
|
||||
float s22 = (local_ptr + 40)[0];
|
||||
float s23 = (local_ptr + 44)[0];
|
||||
|
||||
float s30 = (local_ptr + 48)[0];
|
||||
float s31 = (local_ptr + 52)[0];
|
||||
float s32 = (local_ptr + 56)[0];
|
||||
float s33 = (local_ptr + 60)[0];
|
||||
|
||||
float t00 = s00 + s10 + s20;
|
||||
float t01 = s01 + s11 + s21;
|
||||
float t02 = s02 + s12 + s22;
|
||||
float t03 = s03 + s13 + s23;
|
||||
|
||||
float t10 = s10 - s20 - s30;
|
||||
float t11 = s11 - s21 - s31;
|
||||
float t12 = s12 - s22 - s32;
|
||||
float t13 = s13 - s23 - s33;
|
||||
|
||||
float d00 = t00 + t01 + t02 + bias_ptr[0];
|
||||
float d01 = t01 - t02 - t03 + bias_ptr[0];
|
||||
float d10 = t10 + t11 + t12 + bias_ptr[0];
|
||||
float d11 = t11 - t12 - t13 + bias_ptr[0];
|
||||
|
||||
if (is_relu) {
|
||||
d00 = MSMAX(d00, 0);
|
||||
d01 = MSMAX(d01, 0);
|
||||
d10 = MSMAX(d10, 0);
|
||||
d11 = MSMAX(d11, 0);
|
||||
}
|
||||
if (is_relu6) {
|
||||
d00 = MSMIN(MSMAX(d00, 0), 6);
|
||||
d01 = MSMIN(MSMAX(d01, 0), 6);
|
||||
d10 = MSMIN(MSMAX(d10, 0), 6);
|
||||
d11 = MSMIN(MSMAX(d11, 0), 6);
|
||||
}
|
||||
|
||||
(dst_output + i)[0] = d00;
|
||||
if (w_in_range) {
|
||||
(dst_output + i + channel)[0] = d01;
|
||||
}
|
||||
if (h_in_range) {
|
||||
(dst_output + i + output_w * channel)[0] = d10;
|
||||
if (w_in_range) {
|
||||
(dst_output + i + output_w * channel + channel)[0] = d11;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const float *bias, int out_h_block,
|
||||
int out_w_block, const ConvParameter *conv_param) {
|
||||
int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
|
||||
bool h_in_range = true;
|
||||
for (int oh = 0; oh < out_h_block; oh++) {
|
||||
const int real_oh = 2 * oh;
|
||||
if ((oh + 1) * 2 > conv_param->output_h_) {
|
||||
h_in_range = false;
|
||||
}
|
||||
bool w_in_range = true;
|
||||
float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM;
|
||||
float *output_oh = output_data + real_oh * conv_param->output_w_ * oc4 * C4NUM;
|
||||
|
||||
for (int ow = 0; ow < out_w_block; ow++) {
|
||||
const int real_ow = 2 * ow;
|
||||
if ((ow + 1) * 2 > conv_param->output_w_) {
|
||||
w_in_range = false;
|
||||
}
|
||||
float *buf_ow = buf_oh + ow * 16 * C4NUM;
|
||||
float *output_ow = output_oh + real_ow * oc4 * C4NUM;
|
||||
|
||||
ConvDw3x3Fp32OutputUnit(buf_ow, output_ow, bias, oc4 * C4NUM, conv_param->output_w_, h_in_range, w_in_range,
|
||||
conv_param->is_relu_, conv_param->is_relu6_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id) {
|
||||
int thread_count = conv_param->thread_num_;
|
||||
int output_channel = conv_param->output_channel_;
|
||||
int oc4 = UP_DIV(output_channel, C4NUM);
|
||||
int out_h_block = UP_DIV(conv_param->output_h_, 2);
|
||||
int out_w_block = UP_DIV(conv_param->output_w_, 2);
|
||||
|
||||
int input_batch = conv_param->input_batch_;
|
||||
for (int batch = 0; batch < input_batch; batch++) {
|
||||
const float *input = input_data + batch * conv_param->input_h_ * conv_param->input_w_ *
|
||||
UP_DIV(conv_param->input_channel_, C4NUM) * C4NUM;
|
||||
float *output = output_data + batch * conv_param->output_h_ * conv_param->output_w_ *
|
||||
UP_DIV(conv_param->output_channel_, C4NUM) * C4NUM;
|
||||
for (int oc = task_id; oc < oc4; oc += thread_count) {
|
||||
const float *weight = weight_data + oc * 16 * C4NUM;
|
||||
const float *bias = bias_data + oc * C4NUM;
|
||||
|
||||
ConvDw3x3Fp32InputTrans(input + oc * C4NUM, trans_buffer, block_buffer, out_h_block, out_w_block, conv_param);
|
||||
|
||||
ConvDw3x3Fp32Winograd(trans_buffer, weight, out_h_block, out_w_block);
|
||||
|
||||
ConvDw3x3Fp32OutputTrans(trans_buffer, output + oc * C4NUM, bias, out_h_block, out_w_block, conv_param);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*conv depthwise 3x3 fp32 end*/
|
||||
|
||||
/*deconv depthwise fp32 begin*/
|
||||
void DeconvDepthwiseBorderPixel(float *dst, const float *src, const float *weight, int height, int width,
|
||||
int in_kh_step, int in_kw_step, int kernel_w_step) {
|
||||
|
@ -727,14 +339,14 @@ void DeconvDepthwiseBorder(float *dst, const float *src, const float *weight, in
|
|||
const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
|
||||
const float *src_h = src + top * sliding->out_h_step_;
|
||||
for (int ih = top; ih < bottom; ih++) {
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
|
||||
float *dst_h = dst + oh * sliding->in_h_step_;
|
||||
|
||||
const float *src_kernel = src_h + left * sliding->block_channel_;
|
||||
for (int iw = left; iw < right; iw++) {
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
|
||||
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
|
||||
float *dst_w = dst_h + ow * sliding->block_channel_;
|
||||
|
@ -790,12 +402,14 @@ void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, in
|
|||
#endif
|
||||
|
||||
void DeconvDepthwisePostFunc(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
float *dst_k = dst;
|
||||
for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
|
||||
for (int c = 0; c < C4NUM; c++) {
|
||||
dst_k[c] += bias[c];
|
||||
dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
|
||||
dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
|
||||
dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
|
||||
dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
|
||||
}
|
||||
dst_k += block_channel;
|
||||
}
|
||||
|
@ -821,8 +435,8 @@ void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *we
|
|||
conv_param->input_w_, conv_param, sliding);
|
||||
|
||||
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
|
||||
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
|
||||
const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
|
||||
|
||||
|
|
|
@ -48,11 +48,6 @@ void DepthwiseBorder(float *dst, const float *src, const float *weight, const fl
|
|||
void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
|
||||
|
||||
void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4);
|
||||
|
||||
void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id);
|
||||
|
||||
void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
|
||||
|
||||
|
|
|
@ -33,18 +33,18 @@ void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, in
|
|||
return;
|
||||
}
|
||||
|
||||
int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
|
||||
ConvParameter *conv_param) {
|
||||
/* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
|
||||
int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
|
||||
ConvParameter *conv_param) {
|
||||
/* row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
|
||||
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
|
||||
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
|
||||
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
|
||||
int oc8 = UP_ROUND(output_channel, C8NUM);
|
||||
int in_plane8 = UP_ROUND(input_plane, C8NUM);
|
||||
int in_plane12 = UP_ROUND(input_plane, C12NUM);
|
||||
int src_iw_stride = C8NUM;
|
||||
int src_ih_stride = conv_param->input_w_ * C8NUM;
|
||||
int src_kw_stride = in_plane8 * C8NUM;
|
||||
int src_kh_stride = in_plane8 * conv_param->kernel_w_ * C8NUM;
|
||||
int src_kw_stride = in_plane12 * C8NUM;
|
||||
int src_kh_stride = in_plane12 * conv_param->kernel_w_ * C8NUM;
|
||||
int dst_oh_stride = conv_param->output_w_ * C8NUM;
|
||||
int dst_ow_stride = C8NUM;
|
||||
int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
|
||||
|
@ -52,13 +52,13 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
|
|||
|
||||
for (int c = 0; c < oc8; c += 8) {
|
||||
float *dst_ptr = tmp + c * output_plane;
|
||||
const float *src_ptr = src + c * in_plane8 * kernel_plane;
|
||||
const float *src_ptr = src + c * in_plane12 * kernel_plane;
|
||||
memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float));
|
||||
|
||||
for (int ih = 0; ih < conv_param->input_h_; ih++) {
|
||||
for (int iw = 0; iw < conv_param->input_w_; iw++) {
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
|
||||
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
|
||||
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
|
||||
|
@ -97,45 +97,7 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
|
|||
} /*ih*/
|
||||
} /*oc8*/
|
||||
|
||||
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
|
||||
conv_param->is_relu6_);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
|
||||
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param) {
|
||||
int oc4 = UP_DIV(output_channel, C4NUM);
|
||||
for (int c = 0; c < oc4; c++) {
|
||||
float *dst_ptr = tmp_c4 + c * output_plane * C4NUM;
|
||||
const float *src_ptr = src + c * input_plane * kernel_plane * C4NUM;
|
||||
memset(dst_ptr, 0, output_plane * C4NUM * sizeof(float));
|
||||
|
||||
for (int ih = 0; ih < conv_param->input_h_; ih++) {
|
||||
for (int iw = 0; iw < conv_param->input_w_; iw++) {
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
|
||||
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
|
||||
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
|
||||
int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
|
||||
int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
|
||||
for (int kh = kh_start; kh < kh_end; kh++) {
|
||||
for (int kw = kw_start; kw < kw_end; kw++) {
|
||||
int src_index = ih * conv_param->input_w_ * C4NUM + iw * C4NUM +
|
||||
kh * input_plane * conv_param->kernel_w_ * C4NUM + kw * input_plane * C4NUM;
|
||||
int dst_index = oh * conv_param->output_w_ * C4NUM + ow * C4NUM +
|
||||
kh * conv_param->dilation_h_ * conv_param->output_w_ * C4NUM +
|
||||
kw * conv_param->dilation_w_ * C4NUM;
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
dst_ptr[dst_index + i] += src_ptr[src_index + i];
|
||||
}
|
||||
} /*kw*/
|
||||
} /*kh*/
|
||||
} /*iw*/
|
||||
} /*ih*/
|
||||
} /*oc4*/
|
||||
|
||||
PostConvFuncFp32C4(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
|
||||
conv_param->is_relu6_);
|
||||
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
|
||||
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -16,20 +16,19 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_FP32_DECONV_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP32_DECONV_H_
|
||||
|
||||
#include <string.h>
|
||||
#include "nnacl/pack.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/fp32/strassen_matmul.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/fp32/common_func.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane);
|
||||
|
||||
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
|
||||
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param);
|
||||
int DeConvPostFp32C8x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
|
||||
ConvParameter *conv_param);
|
||||
int DeConvPostFp32C12x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
|
||||
ConvParameter *conv_param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -19,19 +19,13 @@
|
|||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef struct GatherParameter {
|
||||
OpParameter op_parameter_;
|
||||
int axis_;
|
||||
int batchDims_;
|
||||
} GatherParameter;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
|
||||
float *output);
|
||||
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices,
|
||||
int indices_element_size, int32_t *output);
|
||||
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
|
||||
int32_t *output);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -28,6 +28,129 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
|
|||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col) {
|
||||
for (int r = 0; r < row; r++) {
|
||||
float *src = src_ptr + r * col;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd8 = c / C12NUM;
|
||||
int cm8 = c % C12NUM;
|
||||
dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
|
||||
size_t row_up_12 = UP_ROUND(row, C12NUM);
|
||||
size_t row12 = row / C12NUM * C12NUM;
|
||||
size_t col4 = col / C4NUM * C4NUM;
|
||||
float *src_r = src_ptr;
|
||||
float *dst_r = dst_ptr;
|
||||
|
||||
size_t ri = 0;
|
||||
for (; ri < row12; ri += C12NUM) {
|
||||
size_t ci = 0;
|
||||
for (; ci < col4; ci += C4NUM) {
|
||||
float *src_c = src_r + ci;
|
||||
float *dst_c = dst_r + ci * C12NUM;
|
||||
|
||||
/* 12x4 row-major to col-major */
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t stride = col * sizeof(float);
|
||||
asm volatile(
|
||||
"mov x10, %[src_c]\n"
|
||||
"mov x11, %[dst_c]\n"
|
||||
|
||||
"ld1 {v0.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v1.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v2.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v3.4s}, [x10], %[stride]\n"
|
||||
|
||||
"ld1 {v4.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v5.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v6.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v7.4s}, [x10], %[stride]\n"
|
||||
|
||||
"zip1 v12.4s, v0.4s, v1.4s\n"
|
||||
"zip2 v13.4s, v0.4s, v1.4s\n"
|
||||
"zip1 v14.4s, v2.4s, v3.4s\n"
|
||||
"zip2 v15.4s, v2.4s, v3.4s\n"
|
||||
|
||||
"ld1 {v8.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v9.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v10.4s}, [x10], %[stride]\n"
|
||||
"ld1 {v11.4s}, [x10], %[stride]\n"
|
||||
|
||||
"zip1 v16.4s, v4.4s, v5.4s\n"
|
||||
"zip2 v17.4s, v4.4s, v5.4s\n"
|
||||
"zip1 v18.4s, v6.4s, v7.4s\n"
|
||||
"zip2 v19.4s, v6.4s, v7.4s\n"
|
||||
|
||||
"trn1 v20.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v23.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v26.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v29.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"trn1 v21.2d, v16.2d, v18.2d\n"
|
||||
"trn2 v24.2d, v16.2d, v18.2d\n"
|
||||
"trn1 v27.2d, v17.2d, v19.2d\n"
|
||||
"trn2 v30.2d, v17.2d, v19.2d\n"
|
||||
|
||||
"zip1 v12.4s, v8.4s, v9.4s\n"
|
||||
"zip2 v13.4s, v8.4s, v9.4s\n"
|
||||
"zip1 v14.4s, v10.4s, v11.4s\n"
|
||||
"zip2 v15.4s, v10.4s, v11.4s\n"
|
||||
|
||||
"trn1 v22.2d, v12.2d, v14.2d\n"
|
||||
"trn2 v25.2d, v12.2d, v14.2d\n"
|
||||
"trn1 v28.2d, v13.2d, v15.2d\n"
|
||||
"trn2 v31.2d, v13.2d, v15.2d\n"
|
||||
|
||||
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
|
||||
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
|
||||
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
|
||||
|
||||
:
|
||||
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
|
||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
|
||||
"v30", "v31");
|
||||
#else
|
||||
for (int tr = 0; tr < C12NUM; tr++) {
|
||||
for (int tc = 0; tc < C4NUM; tc++) {
|
||||
dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (; ci < col; ci++) {
|
||||
float *src_c = src_r + ci;
|
||||
float *dst_c = dst_r + ci * C12NUM;
|
||||
for (size_t i = 0; i < C12NUM; i++) {
|
||||
dst_c[i] = src_c[i * col];
|
||||
}
|
||||
}
|
||||
src_r += C12NUM * col;
|
||||
dst_r += C12NUM * col;
|
||||
}
|
||||
|
||||
for (; ri < row; ri++) {
|
||||
for (size_t i = 0; i < col; i++) {
|
||||
dst_r[i * C12NUM] = src_r[i];
|
||||
}
|
||||
src_r += col;
|
||||
dst_r += 1;
|
||||
}
|
||||
|
||||
for (; ri < row_up_12; ri++) {
|
||||
for (size_t i = 0; i < col; i++) {
|
||||
dst_r[i * C12NUM] = 0;
|
||||
}
|
||||
dst_r += 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
|
||||
size_t row8 = row / C8NUM * C8NUM;
|
||||
size_t col4 = col / C4NUM * C4NUM;
|
||||
|
@ -221,18 +344,18 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
|
|||
return;
|
||||
}
|
||||
|
||||
void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, bool write_nhwc) {
|
||||
if (write_nhwc) {
|
||||
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, int out_type) {
|
||||
if (out_type == OutType_Nhwc) {
|
||||
/* col8-major * row8-major => col-major */
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r8div = r / 8, r8mod = r % 8;
|
||||
int r12div = r / 12, r12mod = r % 12;
|
||||
int c8div = c / 8, c8mod = c % 8;
|
||||
size_t ci = r * stride + c;
|
||||
float value = 0;
|
||||
for (int d = 0; d < deep; d++) {
|
||||
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
|
||||
size_t ai = r12div * deep * 12 + d * 12 + r12mod;
|
||||
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
|
@ -242,22 +365,41 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
|
|||
dst[ci] = value;
|
||||
}
|
||||
}
|
||||
} else if (out_type == OutType_C8) {
|
||||
/* col8-major * row8-major => col12x8-major */
|
||||
int col_8 = UP_ROUND(col, C8NUM);
|
||||
int row_12 = UP_ROUND(row, C12NUM);
|
||||
for (int r = 0; r < row_12; r++) {
|
||||
for (int c = 0; c < col_8; c++) {
|
||||
int r12div = r / C12NUM, r12mod = r % C12NUM;
|
||||
int c8div = c / C8NUM, c8mod = c % C8NUM;
|
||||
size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
|
||||
float value = 0;
|
||||
for (int d = 0; d < deep; d++) {
|
||||
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
|
||||
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
if (bias != NULL) value += bias[c];
|
||||
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
|
||||
if (act_type != ActType_No) value = MSMAX(0.0f, value);
|
||||
dst[ci] = value;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
/* col8-major * row8-major => col8x8-major */
|
||||
int col_8 = UP_ROUND(col, C8NUM);
|
||||
int row_8 = UP_ROUND(row, C8NUM);
|
||||
for (int r = 0; r < row_8; r++) {
|
||||
for (int c = 0; c < col_8; c++) {
|
||||
int r8div = r / 8, r8mod = r % 8;
|
||||
int c8div = c / 8, c8mod = c % 8;
|
||||
size_t ci = c8div * row_8 * 8 + r * 8 + c8mod;
|
||||
for (int i = 0; i < row; ++i) {
|
||||
int src_r_offset = i;
|
||||
int dst_r_offset = i * col * stride;
|
||||
for (int j = 0; j < col; ++j) {
|
||||
int c8div = j / 8, c8mod = j % 8;
|
||||
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
|
||||
float value = 0;
|
||||
for (int d = 0; d < deep; d++) {
|
||||
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
|
||||
for (int d = 0; d < deep; ++d) {
|
||||
size_t ai = src_r_offset + d * C12NUM;
|
||||
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
if (bias != NULL) value += bias[c];
|
||||
if (bias != NULL) value += bias[j];
|
||||
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
|
||||
if (act_type != ActType_No) value = MSMAX(0.0f, value);
|
||||
dst[ci] = value;
|
||||
|
@ -267,11 +409,16 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
|
|||
return;
|
||||
}
|
||||
|
||||
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
|
||||
int stride, bool write_nhwc) {
|
||||
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, size_t stride, int out_type) {
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
|
||||
if (out_type == 2 && row <= 8) {
|
||||
MatmulFloatNeon64OptRemain(a, b, c, deep, row, col, stride);
|
||||
} else {
|
||||
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
|
||||
(int)(out_type == OutType_TileC8));
|
||||
}
|
||||
#else
|
||||
MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
|
||||
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -26,14 +26,20 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
|
||||
int stride, bool write_nhwc);
|
||||
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, size_t stride, int out_type);
|
||||
|
||||
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
|
||||
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col);
|
||||
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
|
||||
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
|
||||
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, size_t stride, bool write_nhwc);
|
||||
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, size_t stride, size_t write_nhwc, size_t write_c4);
|
||||
void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride);
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -130,8 +130,103 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
int c4 = UP_DIV(channel, C4NUM); /* oc && ic */
|
||||
|
||||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
|
||||
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
|
||||
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
|
||||
int cal_start_index = thread_id * TILE_NUM;
|
||||
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int index = cal_start_index + i;
|
||||
int out_w_index = index % output_w;
|
||||
int out_h_index = index / output_w;
|
||||
int in_w_index = out_w_index * stride_w - pad_w;
|
||||
int in_h_index = out_h_index * stride_h - pad_h;
|
||||
|
||||
const float *src_plane_ptr = src_b_ptr;
|
||||
float *dst_plane_ptr = dst_b_ptr + index * channel;
|
||||
|
||||
int real_win_h_start = MSMAX(0, -in_h_index);
|
||||
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
|
||||
int resl_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
|
||||
for (int ci = 0; ci < c4 - 1; ci++) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX);
|
||||
#else
|
||||
float tmp_max1 = -FLT_MAX;
|
||||
float tmp_max2 = -FLT_MAX;
|
||||
float tmp_max3 = -FLT_MAX;
|
||||
float tmp_max4 = -FLT_MAX;
|
||||
#endif
|
||||
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr));
|
||||
#else
|
||||
tmp_max1 = fmax(tmp_max1, src_win_ptr[0]);
|
||||
tmp_max2 = fmax(tmp_max2, src_win_ptr[1]);
|
||||
tmp_max3 = fmax(tmp_max3, src_win_ptr[2]);
|
||||
tmp_max4 = fmax(tmp_max4, src_win_ptr[3]);
|
||||
|
||||
#endif
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
#ifdef ENABLE_NEON
|
||||
vst1q_f32(dst_c_ptr, tmp_max);
|
||||
#else
|
||||
dst_c_ptr[0] = tmp_max1;
|
||||
dst_c_ptr[1] = tmp_max2;
|
||||
dst_c_ptr[2] = tmp_max3;
|
||||
dst_c_ptr[3] = tmp_max4;
|
||||
#endif
|
||||
} // ic4-1 loop
|
||||
int channel_s = (c4 - 1) * C4NUM;
|
||||
for (int ci = channel_s; ci < channel; ci++) {
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float tmp_max = -FLT_MAX;
|
||||
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
tmp_max = fmax(tmp_max, src_win_ptr[0]);
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
dst_c_ptr[0] = tmp_max;
|
||||
} // channel_res loop
|
||||
} // real_cal_num loop
|
||||
} // out_plane loop
|
||||
} // out_batch loop
|
||||
}
|
||||
|
||||
void AvgPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
|
||||
int stride_w = pooling_param->stride_w_;
|
||||
int stride_h = pooling_param->stride_h_;
|
||||
int pad_w = pooling_param->pad_l_;
|
||||
int pad_h = pooling_param->pad_u_;
|
||||
int win_w = pooling_param->window_w_;
|
||||
int win_h = pooling_param->window_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
// input channel is equal to output channel
|
||||
int in_w = pooling_param->input_w_;
|
||||
int in_h = pooling_param->input_h_;
|
||||
int output_w = pooling_param->output_w_;
|
||||
int output_h = pooling_param->output_h_;
|
||||
int output_batch = pooling_param->output_batch_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t zeros = vdupq_n_f32(0);
|
||||
#endif
|
||||
|
||||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
int in_batch_offset = batch * in_h * in_w * channel;
|
||||
|
@ -149,6 +244,121 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
for (int j = 0; j < c4 - 1; j++) {
|
||||
int in_channel_offset = in_batch_offset + j * C4NUM;
|
||||
int out_channel_offset = out_plane_offset + j * C4NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t tmp_avg = vdupq_n_f32(0);
|
||||
#else
|
||||
float tmp_avg1 = 0;
|
||||
float tmp_avg2 = 0;
|
||||
float tmp_avg3 = 0;
|
||||
float tmp_avg4 = 0;
|
||||
#endif
|
||||
int real_count = 0;
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
||||
(in_w_index + w) >= in_w) {
|
||||
continue;
|
||||
} else {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(input_ptr + in_offset));
|
||||
#else
|
||||
tmp_avg1 += *(input_ptr + in_offset);
|
||||
tmp_avg2 += *(input_ptr + in_offset + 1);
|
||||
tmp_avg3 += *(input_ptr + in_offset + 2);
|
||||
tmp_avg4 += *(input_ptr + in_offset + 3);
|
||||
#endif
|
||||
++real_count;
|
||||
}
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_avg = vmaxq_f32(tmp_avg, zeros);
|
||||
vst1q_f32(output_ptr + out_channel_offset, tmp_avg / vdupq_n_f32(real_count));
|
||||
#else
|
||||
tmp_avg1 = fmax(tmp_avg1, 0);
|
||||
tmp_avg2 = fmax(tmp_avg2, 0);
|
||||
tmp_avg3 = fmax(tmp_avg3, 0);
|
||||
tmp_avg4 = fmax(tmp_avg4, 0);
|
||||
|
||||
*(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count;
|
||||
*(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count;
|
||||
*(output_ptr + out_channel_offset + 2) = tmp_avg3 / (float)real_count;
|
||||
*(output_ptr + out_channel_offset + 3) = tmp_avg4 / (float)real_count;
|
||||
#endif
|
||||
} // ic4-1 loop
|
||||
int channel_s = (c4 - 1) * C4NUM;
|
||||
for (int k = channel_s; k < channel; k++) {
|
||||
int in_channel_offset = in_batch_offset + k;
|
||||
int out_channel_offset = out_plane_offset + k;
|
||||
float tmp_avg = 0;
|
||||
int real_count = 0;
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
||||
(in_w_index + w) >= in_w) {
|
||||
continue;
|
||||
} else {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_avg += *(input_ptr + in_offset);
|
||||
++real_count;
|
||||
}
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
tmp_avg = fmax(tmp_avg, 0);
|
||||
*(output_ptr + out_channel_offset) = tmp_avg / (float)real_count;
|
||||
} // channel_res loop
|
||||
} // real_cal_num loop
|
||||
} // out_plane loop
|
||||
} // out_batch loop
|
||||
}
|
||||
|
||||
void MaxPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
|
||||
int stride_w = pooling_param->stride_w_;
|
||||
int stride_h = pooling_param->stride_h_;
|
||||
int pad_w = pooling_param->pad_l_;
|
||||
int pad_h = pooling_param->pad_u_;
|
||||
int win_w = pooling_param->window_w_;
|
||||
int win_h = pooling_param->window_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
int in_w = pooling_param->input_w_;
|
||||
int in_h = pooling_param->input_h_;
|
||||
int output_w = pooling_param->output_w_;
|
||||
int output_h = pooling_param->output_h_;
|
||||
int output_batch = pooling_param->output_batch_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t zeros = vdupq_n_f32(0);
|
||||
#endif
|
||||
|
||||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
|
||||
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
|
||||
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
|
||||
int cal_start_index = thread_id * TILE_NUM;
|
||||
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int index = cal_start_index + i;
|
||||
int out_w_index = index % output_w;
|
||||
int out_h_index = index / output_w;
|
||||
int in_w_index = out_w_index * stride_w - pad_w;
|
||||
int in_h_index = out_h_index * stride_h - pad_h;
|
||||
|
||||
const float *src_plane_ptr = src_b_ptr;
|
||||
float *dst_plane_ptr = dst_b_ptr + index * channel;
|
||||
|
||||
int real_win_h_start = MSMAX(0, -in_h_index);
|
||||
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
|
||||
int resl_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
|
||||
for (int ci = 0; ci < c4 - 1; ci++) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX);
|
||||
#else
|
||||
|
@ -157,6 +367,105 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
float tmp_max3 = -FLT_MAX;
|
||||
float tmp_max4 = -FLT_MAX;
|
||||
#endif
|
||||
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr));
|
||||
#else
|
||||
tmp_max1 = fmax(tmp_max1, src_win_ptr[0]);
|
||||
tmp_max2 = fmax(tmp_max2, src_win_ptr[1]);
|
||||
tmp_max3 = fmax(tmp_max3, src_win_ptr[2]);
|
||||
tmp_max4 = fmax(tmp_max4, src_win_ptr[3]);
|
||||
|
||||
#endif
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_f32(tmp_max, zeros);
|
||||
vst1q_f32(dst_c_ptr, tmp_max);
|
||||
#else
|
||||
// relu:
|
||||
tmp_max1 = fmax(tmp_max1, 0);
|
||||
tmp_max2 = fmax(tmp_max2, 0);
|
||||
tmp_max3 = fmax(tmp_max3, 0);
|
||||
tmp_max4 = fmax(tmp_max4, 0);
|
||||
|
||||
dst_c_ptr[0] = tmp_max1;
|
||||
dst_c_ptr[1] = tmp_max2;
|
||||
dst_c_ptr[2] = tmp_max3;
|
||||
dst_c_ptr[3] = tmp_max4;
|
||||
#endif
|
||||
} // ic4-1 loop
|
||||
int channel_s = (c4 - 1) * C4NUM;
|
||||
for (int ci = channel_s; ci < channel; ci++) {
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float tmp_max = -FLT_MAX;
|
||||
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
tmp_max = fmax(tmp_max, src_win_ptr[0]);
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
dst_c_ptr[0] = tmp_max;
|
||||
} // channel_res loop
|
||||
} // real_cal_num loop
|
||||
} // out_plane loop
|
||||
} // out_batch loop
|
||||
}
|
||||
|
||||
void AvgPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
|
||||
int stride_w = pooling_param->stride_w_;
|
||||
int stride_h = pooling_param->stride_h_;
|
||||
int pad_w = pooling_param->pad_l_;
|
||||
int pad_h = pooling_param->pad_u_;
|
||||
int win_w = pooling_param->window_w_;
|
||||
int win_h = pooling_param->window_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
int in_w = pooling_param->input_w_;
|
||||
int in_h = pooling_param->input_h_;
|
||||
int output_w = pooling_param->output_w_;
|
||||
int output_h = pooling_param->output_h_;
|
||||
int output_batch = pooling_param->output_batch_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
// input channel is equal to output channel
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t zeros = vdupq_n_f32(0);
|
||||
float32x4_t bounds = vdupq_n_f32(6);
|
||||
#endif
|
||||
|
||||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
int in_batch_offset = batch * in_h * in_w * channel;
|
||||
int out_batch_offset = batch * output_h * output_w * channel;
|
||||
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
|
||||
int cal_start_index = thread_id * TILE_NUM;
|
||||
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int index = cal_start_index + i;
|
||||
int out_w_index = index % output_w;
|
||||
int out_h_index = index / output_w;
|
||||
int in_w_index = out_w_index * stride_w - pad_w;
|
||||
int in_h_index = out_h_index * stride_h - pad_h;
|
||||
int out_plane_offset = out_batch_offset + index * channel;
|
||||
for (int j = 0; j < c4 - 1; j++) {
|
||||
int in_channel_offset = in_batch_offset + j * C4NUM;
|
||||
int out_channel_offset = out_plane_offset + j * C4NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t tmp_avg = vdupq_n_f32(0);
|
||||
#else
|
||||
float tmp_avg1 = 0;
|
||||
float tmp_avg2 = 0;
|
||||
float tmp_avg3 = 0;
|
||||
float tmp_avg4 = 0;
|
||||
#endif
|
||||
int real_count = 0;
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
||||
|
@ -165,30 +474,48 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
} else {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(input_ptr + in_offset));
|
||||
tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(input_ptr + in_offset));
|
||||
#else
|
||||
tmp_max1 = fmax(tmp_max1, *(input_ptr + in_offset));
|
||||
tmp_max2 = fmax(tmp_max2, *(input_ptr + in_offset + 1));
|
||||
tmp_max3 = fmax(tmp_max3, *(input_ptr + in_offset + 2));
|
||||
tmp_max4 = fmax(tmp_max4, *(input_ptr + in_offset + 3));
|
||||
tmp_avg1 += *(input_ptr + in_offset);
|
||||
tmp_avg2 += *(input_ptr + in_offset + 1);
|
||||
tmp_avg3 += *(input_ptr + in_offset + 2);
|
||||
tmp_avg4 += *(input_ptr + in_offset + 3);
|
||||
#endif
|
||||
++real_count;
|
||||
}
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
#ifdef ENABLE_NEON
|
||||
vst1q_f32(output_ptr + out_channel_offset, tmp_max);
|
||||
tmp_avg = tmp_avg / vdupq_n_f32(real_count);
|
||||
tmp_avg = vmaxq_f32(tmp_avg, zeros);
|
||||
tmp_avg = vminq_f32(tmp_avg, bounds);
|
||||
vst1q_f32(output_ptr + out_channel_offset, tmp_avg);
|
||||
#else
|
||||
*(output_ptr + out_channel_offset) = tmp_max1;
|
||||
*(output_ptr + out_channel_offset + 1) = tmp_max2;
|
||||
*(output_ptr + out_channel_offset + 2) = tmp_max3;
|
||||
*(output_ptr + out_channel_offset + 3) = tmp_max4;
|
||||
tmp_avg1 /= (float)real_count;
|
||||
tmp_avg2 /= (float)real_count;
|
||||
tmp_avg3 /= (float)real_count;
|
||||
tmp_avg4 /= (float)real_count;
|
||||
tmp_avg1 = fmax(tmp_avg1, 0);
|
||||
tmp_avg2 = fmax(tmp_avg2, 0);
|
||||
tmp_avg3 = fmax(tmp_avg3, 0);
|
||||
tmp_avg4 = fmax(tmp_avg4, 0);
|
||||
tmp_avg1 = fmin(tmp_avg1, 6);
|
||||
tmp_avg2 = fmin(tmp_avg2, 6);
|
||||
tmp_avg3 = fmin(tmp_avg3, 6);
|
||||
tmp_avg4 = fmin(tmp_avg4, 6);
|
||||
|
||||
*(output_ptr + out_channel_offset) = tmp_avg1;
|
||||
*(output_ptr + out_channel_offset + 1) = tmp_avg2;
|
||||
*(output_ptr + out_channel_offset + 2) = tmp_avg3;
|
||||
*(output_ptr + out_channel_offset + 3) = tmp_avg4;
|
||||
#endif
|
||||
} // ic4-1 loop
|
||||
int channel_s = (c4 - 1) * C4NUM;
|
||||
for (int k = channel_s; k < channel; k++) {
|
||||
int in_channel_offset = in_batch_offset + k;
|
||||
int out_channel_offset = out_plane_offset + k;
|
||||
float tmp_max = -FLT_MAX;
|
||||
float tmp_avg = 0;
|
||||
int real_count = 0;
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
||||
|
@ -196,11 +523,125 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
continue;
|
||||
} else {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_max = fmax(tmp_max, *(input_ptr + in_offset));
|
||||
tmp_avg += *(input_ptr + in_offset);
|
||||
++real_count;
|
||||
}
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
*(output_ptr + out_channel_offset) = tmp_max;
|
||||
tmp_avg /= (float)real_count;
|
||||
tmp_avg = fmax(tmp_avg, 0);
|
||||
tmp_avg = fmin(tmp_avg, 6);
|
||||
*(output_ptr + out_channel_offset) = tmp_avg;
|
||||
} // channel_res loop
|
||||
} // real_cal_num loop
|
||||
} // out_plane loop
|
||||
} // out_batch loop
|
||||
}
|
||||
|
||||
void MaxPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
|
||||
int stride_w = pooling_param->stride_w_;
|
||||
int stride_h = pooling_param->stride_h_;
|
||||
int pad_w = pooling_param->pad_l_;
|
||||
int pad_h = pooling_param->pad_u_;
|
||||
int win_w = pooling_param->window_w_;
|
||||
int win_h = pooling_param->window_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
int in_w = pooling_param->input_w_;
|
||||
int in_h = pooling_param->input_h_;
|
||||
int output_w = pooling_param->output_w_;
|
||||
int output_h = pooling_param->output_h_;
|
||||
int output_batch = pooling_param->output_batch_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t zeros = vdupq_n_f32(0);
|
||||
float32x4_t bounds = vdupq_n_f32(6);
|
||||
#endif
|
||||
|
||||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
|
||||
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
|
||||
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
|
||||
int cal_start_index = thread_id * TILE_NUM;
|
||||
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int index = cal_start_index + i;
|
||||
int out_w_index = index % output_w;
|
||||
int out_h_index = index / output_w;
|
||||
int in_w_index = out_w_index * stride_w - pad_w;
|
||||
int in_h_index = out_h_index * stride_h - pad_h;
|
||||
|
||||
const float *src_plane_ptr = src_b_ptr;
|
||||
float *dst_plane_ptr = dst_b_ptr + index * channel;
|
||||
|
||||
int real_win_h_start = MSMAX(0, -in_h_index);
|
||||
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
|
||||
int resl_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
|
||||
for (int ci = 0; ci < c4 - 1; ci++) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX);
|
||||
#else
|
||||
float tmp_max1 = -FLT_MAX;
|
||||
float tmp_max2 = -FLT_MAX;
|
||||
float tmp_max3 = -FLT_MAX;
|
||||
float tmp_max4 = -FLT_MAX;
|
||||
#endif
|
||||
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr));
|
||||
#else
|
||||
tmp_max1 = fmax(tmp_max1, src_win_ptr[0]);
|
||||
tmp_max2 = fmax(tmp_max2, src_win_ptr[1]);
|
||||
tmp_max3 = fmax(tmp_max3, src_win_ptr[2]);
|
||||
tmp_max4 = fmax(tmp_max4, src_win_ptr[3]);
|
||||
|
||||
#endif
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_f32(tmp_max, zeros);
|
||||
tmp_max = vminq_f32(tmp_max, bounds);
|
||||
vst1q_f32(dst_c_ptr, tmp_max);
|
||||
#else
|
||||
// relu:
|
||||
tmp_max1 = fmax(tmp_max1, 0);
|
||||
tmp_max2 = fmax(tmp_max2, 0);
|
||||
tmp_max3 = fmax(tmp_max3, 0);
|
||||
tmp_max4 = fmax(tmp_max4, 0);
|
||||
tmp_max1 = fmin(tmp_max1, 6);
|
||||
tmp_max2 = fmin(tmp_max2, 6);
|
||||
tmp_max3 = fmin(tmp_max3, 6);
|
||||
tmp_max4 = fmin(tmp_max4, 6);
|
||||
|
||||
dst_c_ptr[0] = tmp_max1;
|
||||
dst_c_ptr[1] = tmp_max2;
|
||||
dst_c_ptr[2] = tmp_max3;
|
||||
dst_c_ptr[3] = tmp_max4;
|
||||
#endif
|
||||
} // ic4-1 loop
|
||||
int channel_s = (c4 - 1) * C4NUM;
|
||||
for (int ci = channel_s; ci < channel; ci++) {
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float tmp_max = -FLT_MAX;
|
||||
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
tmp_max = fmax(tmp_max, src_win_ptr[0]);
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
dst_c_ptr[0] = tmp_max;
|
||||
} // channel_res loop
|
||||
} // real_cal_num loop
|
||||
} // out_plane loop
|
||||
|
|
|
@ -30,6 +30,14 @@ extern "C" {
|
|||
void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
|
||||
void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
|
||||
void AvgPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
|
||||
void MaxPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
|
||||
void AvgPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
|
||||
void MaxPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -17,17 +17,15 @@
|
|||
#include "nnacl/fp32/resize.h"
|
||||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
|
||||
bool align_corners, int tid, int thread_num) {
|
||||
if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL) {
|
||||
int PrepareResizeBilinear(const int *input_shape, const int *output_shape, bool align_corners, int *y_bottoms,
|
||||
int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights, float *x_left_weights) {
|
||||
if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL ||
|
||||
x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
|
||||
int in_n = input_shape[0];
|
||||
int in_h = input_shape[1];
|
||||
int in_w = input_shape[2];
|
||||
int in_c = input_shape[3];
|
||||
|
||||
int new_height = output_shape[1];
|
||||
int new_width = output_shape[2];
|
||||
|
@ -40,65 +38,119 @@ int ResizeBilinear(const float *input_data, float *output_data, const int *input
|
|||
width_scale = (float)(in_w - 1) / (new_width - 1);
|
||||
}
|
||||
|
||||
int n, h, w, c;
|
||||
for (n = 0; n < in_n; n++) {
|
||||
for (h = tid; h < new_height; h += thread_num) {
|
||||
float actual_y = (float)h * height_scale;
|
||||
int y_bottom = (int)(floor(actual_y));
|
||||
int y_top = y_bottom + 1 < in_h ? (y_bottom + 1) : (in_h - 1);
|
||||
float y_top_weight = actual_y - (float)(y_bottom);
|
||||
const float y_bottom_weight = 1.0f - y_top_weight;
|
||||
for (w = 0; w < new_width; w++) {
|
||||
float actual_x = (float)(w)*width_scale;
|
||||
int x_left = (int)(floor(actual_x));
|
||||
int x_right = x_left + 1 < in_w ? (x_left + 1) : (in_w - 1);
|
||||
float x_right_weight = actual_x - (float)(x_left);
|
||||
const float x_left_weight = 1.0f - x_right_weight;
|
||||
c = 0;
|
||||
int h, w;
|
||||
for (h = 0; h < new_height; h++) {
|
||||
float actual_y = (float)h * height_scale;
|
||||
int y_bottom = (int)(floor(actual_y));
|
||||
int y_top = y_bottom + 1 < in_h ? (y_bottom + 1) : (in_h - 1);
|
||||
float y_top_weight = actual_y - (float)(y_bottom);
|
||||
const float y_bottom_weight = 1.0f - y_top_weight;
|
||||
|
||||
y_bottoms[h] = y_bottom;
|
||||
y_tops[h] = y_top;
|
||||
y_bottom_weights[h] = y_bottom_weight;
|
||||
}
|
||||
for (w = 0; w < new_width; w++) {
|
||||
float actual_x = (float)(w)*width_scale;
|
||||
int x_left = (int)(floor(actual_x));
|
||||
int x_right = x_left + 1 < in_w ? (x_left + 1) : (in_w - 1);
|
||||
float x_right_weight = actual_x - (float)(x_left);
|
||||
const float x_left_weight = 1.0f - x_right_weight;
|
||||
|
||||
x_lefts[w] = x_left;
|
||||
x_rights[w] = x_right;
|
||||
x_left_weights[w] = x_left_weight;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
|
||||
int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights,
|
||||
float *x_left_weights, int n_h_begin, int n_h_end) {
|
||||
if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL ||
|
||||
y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
|
||||
int in_w = input_shape[2];
|
||||
int in_c = input_shape[3];
|
||||
|
||||
int new_height = output_shape[1];
|
||||
int new_width = output_shape[2];
|
||||
|
||||
int n_h, n, h, w, c;
|
||||
n = n_h_begin / new_height;
|
||||
h = n_h_begin % new_height;
|
||||
int n_h_stride = new_width * in_c;
|
||||
int out_offset = n_h_begin * n_h_stride;
|
||||
for (n_h = n_h_begin; n_h < n_h_end; n_h++, h++) {
|
||||
if (h == new_height) {
|
||||
h = 0;
|
||||
n++;
|
||||
}
|
||||
int y_bottom = y_bottoms[h];
|
||||
int y_top = y_tops[h];
|
||||
float y_bottom_weight = y_bottom_weights[h];
|
||||
float y_top_weight = 1.0f - y_bottom_weight;
|
||||
|
||||
for (w = 0; w < new_width; w++) {
|
||||
int x_left = x_lefts[w];
|
||||
int x_right = x_rights[w];
|
||||
float x_left_weight = x_left_weights[w];
|
||||
float x_right_weight = 1.0f - x_left_weight;
|
||||
float top_left_weight = y_top_weight * x_left_weight;
|
||||
float top_right_weight = y_top_weight * x_right_weight;
|
||||
float bottom_left_weight = y_bottom_weight * x_left_weight;
|
||||
float bottom_right_weight = y_bottom_weight * x_right_weight;
|
||||
|
||||
c = 0;
|
||||
int in_bottom_left_offset = offset(input_shape, n, y_bottom, x_left, c);
|
||||
int in_bottom_right_offset = in_bottom_left_offset + (x_right - x_left) * in_c;
|
||||
int in_top_left_offset = in_bottom_left_offset + (y_top - y_bottom) * in_w * in_c;
|
||||
int in_top_right_offset = in_bottom_right_offset + (y_top - y_bottom) * in_w * in_c;
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
for (; c <= in_c - 4; c += 4) {
|
||||
float32x4_t bottom_left = vld1q_f32(input_data + offset(input_shape, n, y_bottom, x_left, c));
|
||||
float32x4_t bottom_right = vld1q_f32(input_data + offset(input_shape, n, y_bottom, x_right, c));
|
||||
float32x4_t top_left = vld1q_f32(input_data + offset(input_shape, n, y_top, x_left, c));
|
||||
float32x4_t top_right = vld1q_f32(input_data + offset(input_shape, n, y_top, x_right, c));
|
||||
float32x4_t top_left_w = vdupq_n_f32(top_left_weight);
|
||||
float32x4_t top_right_w = vdupq_n_f32(top_right_weight);
|
||||
float32x4_t bottom_left_w = vdupq_n_f32(bottom_left_weight);
|
||||
float32x4_t bottom_right_w = vdupq_n_f32(bottom_right_weight);
|
||||
|
||||
float32x4_t y_top_w = vdupq_n_f32(y_top_weight);
|
||||
float32x4_t y_bottom_w = vdupq_n_f32(y_bottom_weight);
|
||||
float32x4_t x_left_w = vdupq_n_f32(x_left_weight);
|
||||
float32x4_t x_right_w = vdupq_n_f32(x_right_weight);
|
||||
for (; c <= in_c - 4; c += 4) {
|
||||
float32x4_t bottom_left = vld1q_f32(input_data + in_bottom_left_offset + c);
|
||||
float32x4_t bottom_right = vld1q_f32(input_data + in_bottom_right_offset + c);
|
||||
float32x4_t top_left = vld1q_f32(input_data + in_top_left_offset + c);
|
||||
float32x4_t top_right = vld1q_f32(input_data + in_top_right_offset + c);
|
||||
|
||||
float32x4_t interp_value = vdupq_n_f32(0.0);
|
||||
float32x4_t tmp = vmulq_f32(bottom_left, y_bottom_w);
|
||||
tmp = vmulq_f32(tmp, x_left_w);
|
||||
interp_value = vaddq_f32(interp_value, tmp);
|
||||
float32x4_t interp_value = vdupq_n_f32(0.0);
|
||||
|
||||
tmp = vmulq_f32(bottom_right, y_bottom_w);
|
||||
tmp = vmulq_f32(tmp, x_right_w);
|
||||
interp_value = vaddq_f32(interp_value, tmp);
|
||||
float32x4_t tmp = vmulq_f32(bottom_left, bottom_left_w);
|
||||
interp_value = vaddq_f32(interp_value, tmp);
|
||||
|
||||
tmp = vmulq_f32(top_left, y_top_w);
|
||||
tmp = vmulq_f32(tmp, x_left_w);
|
||||
interp_value = vaddq_f32(interp_value, tmp);
|
||||
tmp = vmulq_f32(bottom_right, bottom_right_w);
|
||||
interp_value = vaddq_f32(interp_value, tmp);
|
||||
|
||||
tmp = vmulq_f32(top_right, y_top_w);
|
||||
tmp = vmulq_f32(tmp, x_right_w);
|
||||
interp_value = vaddq_f32(interp_value, tmp);
|
||||
vst1q_f32(output_data + offset(output_shape, n, h, w, c), interp_value);
|
||||
}
|
||||
tmp = vmulq_f32(top_left, top_left_w);
|
||||
interp_value = vaddq_f32(interp_value, tmp);
|
||||
|
||||
tmp = vmulq_f32(top_right, top_right_w);
|
||||
interp_value = vaddq_f32(interp_value, tmp);
|
||||
vst1q_f32(output_data + out_offset, interp_value);
|
||||
out_offset += 4;
|
||||
}
|
||||
#endif
|
||||
for (; c < in_c; c++) {
|
||||
float bottom_left = input_data[offset(input_shape, n, y_bottom, x_left, c)];
|
||||
float bottom_right = input_data[offset(input_shape, n, y_bottom, x_right, c)];
|
||||
float top_left = input_data[offset(input_shape, n, y_top, x_left, c)];
|
||||
float top_right = input_data[offset(input_shape, n, y_top, x_right, c)];
|
||||
float interp_value = bottom_left * y_bottom_weight * x_left_weight +
|
||||
bottom_right * y_bottom_weight * x_right_weight +
|
||||
top_left * y_top_weight * x_left_weight + top_right * y_top_weight * x_right_weight;
|
||||
output_data[offset(output_shape, n, h, w, c)] = interp_value;
|
||||
}
|
||||
for (; c < in_c; c++) {
|
||||
float bottom_left = input_data[in_bottom_left_offset + c];
|
||||
float bottom_right = input_data[in_bottom_right_offset + c];
|
||||
float top_left = input_data[in_top_left_offset + c];
|
||||
float top_right = input_data[in_top_right_offset + c];
|
||||
float interp_value = bottom_left * bottom_left_weight + bottom_right * bottom_right_weight +
|
||||
top_left * top_left_weight + top_right * top_right_weight;
|
||||
output_data[out_offset] = interp_value;
|
||||
out_offset++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -25,9 +25,12 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
|
||||
bool align_corners, int tid, int thread_num);
|
||||
|
||||
int PrepareResizeBilinear(const int *input_shape, const int *output_shape, bool align_corners, int *y_bottoms,
|
||||
int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights, float *x_left_weights);
|
||||
int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
|
||||
int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights,
|
||||
float *x_left_weights, int n_h_begin, int n_h_end);
|
||||
int ResizeNearestNeighbor(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
|
||||
int tid, int thread_num);
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/scale.h"
|
||||
#ifdef ENABLE_ARM
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
void ScaleInner(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end,
|
||||
int axis_size, int inner_size) {
|
||||
for (int out = outer_start; out < outer_end; out++) {
|
||||
int out_offset = out * axis_size * inner_size;
|
||||
for (int i = 0; i < axis_size; i++) {
|
||||
int axis_offset = out_offset + i * inner_size;
|
||||
int in_index = 0;
|
||||
#ifdef ENABLE_ARM64
|
||||
for (; in_index < inner_size - 4; in_index += 4) {
|
||||
int in_offset = axis_offset + in_index;
|
||||
float32x4_t data = vld1q_f32(in_data + in_offset);
|
||||
float32x4_t scale_4 = vdupq_n_f32(scale[i]);
|
||||
float32x4_t offset_4 = vdupq_n_f32(offset[i]);
|
||||
float32x4_t reslut = vfmaq_f32(offset_4, data, scale_4);
|
||||
vst1q_f32(out_data + in_offset, reslut);
|
||||
}
|
||||
#endif
|
||||
for (; in_index < inner_size; in_index++) {
|
||||
int in_offset = axis_offset + in_index;
|
||||
out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ScaleAxis(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end,
|
||||
int axis_size) {
|
||||
for (int out = outer_start; out < outer_end; out++) {
|
||||
int out_offset = out * axis_size;
|
||||
int index = 0;
|
||||
#ifdef ENABLE_ARM64
|
||||
for (; index < axis_size - 4; index += 4) {
|
||||
int in_offset = out_offset + index;
|
||||
float32x4_t data = vld1q_f32(in_data + in_offset);
|
||||
float32x4_t scale_4 = vld1q_f32(scale + index);
|
||||
float32x4_t offset_4 = vld1q_f32(offset + index);
|
||||
float32x4_t reslut = vfmaq_f32(offset_4, data, scale_4);
|
||||
vst1q_f32(out_data + in_offset, reslut);
|
||||
}
|
||||
#endif
|
||||
for (; index < axis_size; index++) {
|
||||
int in_offset = out_offset + index;
|
||||
out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param) {
|
||||
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
|
||||
int outer_start = task_id * outer_step;
|
||||
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
|
||||
|
||||
if (scale_param->inner_size_ == 1) {
|
||||
ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_);
|
||||
} else {
|
||||
ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
|
||||
scale_param->inner_size_);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_SCALE_FP32_H_
|
||||
#define MINDSPORE_LITE_NNACL_SCALE_FP32_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/scale.h"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_SCALE_FP32_H_
|
|
@ -16,132 +16,79 @@
|
|||
#include "nnacl/fp32/space_to_batch.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/fp32/concat.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
int EnumElement(int *shape, int n_dims) {
|
||||
int total = 1;
|
||||
for (int i = 0; i < n_dims; i++) {
|
||||
total *= shape[i];
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
|
||||
int *output_shape, int h_start, int h_end) {
|
||||
const int stride0 = strides[perm[0]];
|
||||
const int stride1 = strides[perm[1]];
|
||||
const int stride2 = strides[perm[2]];
|
||||
const int stride3 = strides[perm[3]];
|
||||
const int stride4 = strides[perm[4]];
|
||||
const int out_stride0 = out_strides[0];
|
||||
const int out_stride1 = out_strides[1];
|
||||
const int out_stride2 = out_strides[2];
|
||||
const int out_stride3 = out_strides[3];
|
||||
const int out_stride4 = out_strides[4];
|
||||
const int output0 = output_shape[0];
|
||||
const int output2 = output_shape[2];
|
||||
const int output3 = output_shape[3];
|
||||
const int output4 = output_shape[4];
|
||||
|
||||
for (int i = 0; i < output0; ++i) {
|
||||
int out_stride0_i = i * out_stride0;
|
||||
int stride0_i = i * stride0;
|
||||
for (int j = h_start; j < h_end; ++j) {
|
||||
int out_stride1_j = j * out_stride1;
|
||||
int stride1_j = j * stride1;
|
||||
for (int k = 0; k < output2; ++k) {
|
||||
int out_stride2_k = k * out_stride2;
|
||||
int stride2_k = k * stride2;
|
||||
for (int m = 0; m < output3; ++m) {
|
||||
int out_stride3_m = m * out_stride3;
|
||||
int stride3_m = m * stride3;
|
||||
for (int n = 0; n < output4; ++n) {
|
||||
int out_stride4_n = n * out_stride4;
|
||||
int stride4_n = n * stride4;
|
||||
memcpy(out_data + out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n,
|
||||
in_data + stride0_i + stride1_j + stride2_k + stride3_m + stride4_n, stride4 * sizeof(float));
|
||||
}
|
||||
}
|
||||
void DoSpaceToBatchNHWC(const float *input, float *output, SpaceToBatchParameter *param, int *in_shape,
|
||||
int *out_shape) {
|
||||
int out_dim0 = out_shape[0];
|
||||
int out_dim1 = out_shape[1];
|
||||
int out_dim2 = out_shape[2];
|
||||
int copy_num = out_shape[3];
|
||||
int block_w = param->block_sizes_[1];
|
||||
int block_h = param->block_sizes_[0];
|
||||
int in_strides[4];
|
||||
ComputeStrides(in_shape, in_strides, 4);
|
||||
int out_strides[4];
|
||||
ComputeStrides(out_shape, out_strides, 4);
|
||||
size_t copy_size = copy_num * sizeof(float);
|
||||
size_t out_offset = 0;
|
||||
for (int n = 0; n < out_dim0; ++n) {
|
||||
int in_n = n % in_shape[0];
|
||||
int32_t stride_w = (n / in_shape[0]) % block_w;
|
||||
int32_t stride_h = (n / in_shape[0]) / block_w;
|
||||
size_t in_offset0 = in_n * in_strides[0];
|
||||
for (int h = 0; h < out_dim1; ++h) {
|
||||
size_t in_offset1 = in_offset0 + (h * block_h + stride_h) * in_strides[1];
|
||||
for (int w = 0; w < out_dim2; ++w) {
|
||||
size_t in_offset2 = in_offset1 + (w * block_w + stride_w) * in_strides[2];
|
||||
memcpy(output + out_offset, input + in_offset2, copy_size);
|
||||
out_offset += copy_num;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_sizes, int h_start,
|
||||
int h_end) {
|
||||
int trans_in_shape[6] = {in_shape[0], in_shape[1] / block_sizes[0],
|
||||
block_sizes[0], in_shape[2] / block_sizes[1],
|
||||
block_sizes[1], in_shape[3]};
|
||||
int trans_out_shape[6] = {
|
||||
in_shape[0], block_sizes[0], block_sizes[1], in_shape[1] / block_sizes[0], in_shape[2] / block_sizes[1],
|
||||
in_shape[3]};
|
||||
int in_strides[C4NUM + 2];
|
||||
ComputeStrides(trans_in_shape, in_strides, shape_size + 2);
|
||||
int out_strides[C4NUM + 2];
|
||||
ComputeStrides(trans_out_shape, out_strides, shape_size + 2);
|
||||
|
||||
int perm[6] = {0, 2, 4, 1, 3, 5};
|
||||
TransposeForNHWC(input, output, in_strides, out_strides, perm, trans_out_shape, h_start, h_end);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]) {
|
||||
float *tmp = padded_input;
|
||||
(void)memcpy(tmp, input, param.num_elements_ * sizeof(float));
|
||||
float *target = tmp_space[0];
|
||||
float *tmp_zeros = tmp_space[1];
|
||||
float *tmp2 = NULL;
|
||||
int cur_shape[param.n_dims_], cur_start_shape[param.n_dims_], cur_end_shape[param.n_dims_],
|
||||
cur_target_shape[param.n_dims_];
|
||||
float *concat_inputs[3];
|
||||
int *concat_shapes[4];
|
||||
|
||||
for (int i = 0; i < param.n_dims_; i++) {
|
||||
cur_shape[i] = param.in_shape_[i];
|
||||
cur_start_shape[i] = param.in_shape_[i];
|
||||
cur_end_shape[i] = param.in_shape_[i];
|
||||
cur_target_shape[i] = param.in_shape_[i];
|
||||
}
|
||||
for (int i = 0; i < param.n_space_dims_; ++i) {
|
||||
if (param.padded_in_shape_[i + 1] > param.in_shape_[i + 1]) {
|
||||
int concat_idx = 0;
|
||||
cur_target_shape[i + 1] = 0;
|
||||
if (param.paddings_[2 * i] != 0) {
|
||||
cur_start_shape[i + 1] = param.paddings_[2 * i];
|
||||
concat_inputs[concat_idx] = tmp_zeros;
|
||||
concat_shapes[concat_idx++] = cur_start_shape;
|
||||
cur_target_shape[i + 1] += cur_start_shape[i + 1];
|
||||
void DoSpaceToBatchPaddingNHWC(const float *input, float *output, int *in_shape, int *padding, int *out_shape,
|
||||
const float *pedding_h_data, const float *pedding_w_data) {
|
||||
int in_h = in_shape[1];
|
||||
int in_w = in_shape[2];
|
||||
int in_c = in_shape[3];
|
||||
int out_w = out_shape[2];
|
||||
int out_c = out_shape[3];
|
||||
size_t ped_h_num = out_w * out_c;
|
||||
size_t ped_h_size = ped_h_num * sizeof(float);
|
||||
size_t ped_w_size = out_c * sizeof(float);
|
||||
size_t out_offset = 0;
|
||||
int in_strides[4];
|
||||
ComputeStrides(in_shape, in_strides, 4);
|
||||
int out_strides[4];
|
||||
ComputeStrides(out_shape, out_strides, 4);
|
||||
size_t copy_size = in_c * sizeof(float);
|
||||
for (int i = 0; i < in_shape[0]; ++i) {
|
||||
size_t in_offset0 = i * in_strides[0];
|
||||
for (int pad_h_top = 0; pad_h_top < padding[0]; ++pad_h_top) {
|
||||
memcpy(output + out_offset, pedding_h_data, ped_h_size);
|
||||
out_offset += ped_h_num;
|
||||
}
|
||||
for (int j = 0; j < in_h; ++j) {
|
||||
size_t in_offset1 = in_offset0 + j * in_strides[1];
|
||||
for (int pad_w_left = 0; pad_w_left < padding[2]; ++pad_w_left) {
|
||||
memcpy(output + out_offset, pedding_w_data, ped_w_size);
|
||||
out_offset += out_c;
|
||||
}
|
||||
|
||||
concat_inputs[concat_idx] = tmp;
|
||||
concat_shapes[concat_idx++] = cur_shape;
|
||||
cur_target_shape[i + 1] += cur_shape[i + 1];
|
||||
if (param.paddings_[2 * i + 1] != 0) {
|
||||
cur_end_shape[i + 1] = param.paddings_[2 * i + 1];
|
||||
concat_inputs[concat_idx] = tmp_zeros;
|
||||
concat_shapes[concat_idx++] = cur_end_shape;
|
||||
cur_target_shape[i + 1] += cur_end_shape[i + 1];
|
||||
for (int k = 0; k < in_w; ++k) {
|
||||
size_t in_offset2 = in_offset1 + k * in_strides[2];
|
||||
memcpy(output + out_offset, input + in_offset2, copy_size);
|
||||
out_offset += in_c;
|
||||
}
|
||||
concat_shapes[concat_idx] = cur_target_shape;
|
||||
Concat((void **)concat_inputs, concat_idx, i + 1, concat_shapes, param.n_dims_, target);
|
||||
|
||||
tmp2 = tmp;
|
||||
tmp = target;
|
||||
target = tmp2;
|
||||
cur_start_shape[i + 1] = cur_end_shape[i + 1] = cur_shape[i + 1] = concat_shapes[concat_idx][i + 1];
|
||||
for (int pad_w_right = 0; pad_w_right < padding[3]; ++pad_w_right) {
|
||||
memcpy(output + out_offset, pedding_w_data, ped_w_size);
|
||||
out_offset += out_c;
|
||||
}
|
||||
}
|
||||
for (int pad_h_bottom = 0; pad_h_bottom < padding[1]; ++pad_h_bottom) {
|
||||
memcpy(output + out_offset, pedding_h_data, ped_h_size);
|
||||
out_offset += ped_h_num;
|
||||
}
|
||||
}
|
||||
if (padded_input != tmp) {
|
||||
memcpy(padded_input, tmp, param.num_elements_padded_ * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end) {
|
||||
if (input == NULL || output == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
int ret =
|
||||
SpaceToBatchForNHWC(input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_, h_start, h_end);
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -22,26 +22,17 @@
|
|||
|
||||
typedef struct SpaceToBatchParameter {
|
||||
OpParameter op_parameter_;
|
||||
int block_sizes_[8];
|
||||
int paddings_[8];
|
||||
int n_dims_;
|
||||
int num_elements_;
|
||||
int num_elements_padded_;
|
||||
int n_space_dims_;
|
||||
int in_shape_[8];
|
||||
int padded_in_shape_[8];
|
||||
bool need_paddings_;
|
||||
int block_sizes_[4];
|
||||
int paddings_[4];
|
||||
} SpaceToBatchParameter;
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end);
|
||||
int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_size, int h_start,
|
||||
int h_end);
|
||||
void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
|
||||
int *output_shape, int h_start, int h_end);
|
||||
void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]);
|
||||
int EnumElement(int *shape, int n_dims);
|
||||
void DoSpaceToBatchNHWC(const float *input, float *output, SpaceToBatchParameter *param, int *in_shape,
|
||||
int *out_shape);
|
||||
void DoSpaceToBatchPaddingNHWC(const float *input, float *output, int *in_shape, int *padding, int *out_shape,
|
||||
const float *pedding_h_data, const float *pedding_w_data);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -1,204 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/strassen_matmul.h"
|
||||
|
||||
bool CheckRecursion(int row, int col, int deep, int max_recursion, int cur_recursion) {
|
||||
if (cur_recursion >= max_recursion) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (row % 2 != 0 || col % 2 != 0 || deep % 2 != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int row2 = row / 2;
|
||||
int col2 = col / 2;
|
||||
int deep2 = deep / 2;
|
||||
|
||||
float save_cost = row * col * 4 * deep * 4 * 2 + row * col * 4 -
|
||||
7 * (row2 * col2 * 4 * deep2 * 4 * 2 - row2 * col2 * 4) - 4 * (row2 * deep2 * 4 * 3) -
|
||||
4 * (deep2 * 4 * col2 * 4 * 3) - 7 * (row2 * col2 * 4 * 3);
|
||||
|
||||
return (save_cost > 0.f);
|
||||
}
|
||||
|
||||
void GemmMatMulComm(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
|
||||
int c_stride) {
|
||||
int row4mod = row % 4;
|
||||
int row4div = row / 4;
|
||||
for (int r = 0; r < row; r++) {
|
||||
int r4mod = r % 4;
|
||||
int r4div = r / 4;
|
||||
for (int c = 0; c < col * 4; c++) {
|
||||
float value = 0;
|
||||
int ic = c / 4 * c_stride + r * 4 + c % 4;
|
||||
for (int d = 0; d < deep * 4; d++) {
|
||||
int d4mod = d % 4;
|
||||
int d4div = d / 4;
|
||||
int a_stride = (r < (row4div * 4)) ? 4 : row4mod;
|
||||
int ai = r4div * 4 * deep * 4 + d4div * a_stride * 4 + r4mod * 4 + d4mod;
|
||||
int bi = c / 4 * b_stride + d * 4 + c % 4;
|
||||
value = value + a_ptr[ai] * b_ptr[bi];
|
||||
}
|
||||
dst_ptr[ic] = value;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void GemmMatMul(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
|
||||
int c_stride) {
|
||||
int row4mod = row % 4;
|
||||
int row4div = row / 4;
|
||||
|
||||
if (row4div > 0) {
|
||||
GemmMatMulComm(a_ptr, b_ptr, dst_ptr, row4div * 4, col, deep, b_stride, c_stride);
|
||||
}
|
||||
|
||||
if (row4mod != 0) {
|
||||
GemmMatMulComm(a_ptr + row4div * deep * 4 * 4, b_ptr, dst_ptr + row4div * 4 * 4, row4mod, col, deep, b_stride,
|
||||
c_stride);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
|
||||
int max_recursion, int cur_recursion, float *tmp_a_ptr) {
|
||||
size_t row2 = matmul_param->row_ / 2;
|
||||
size_t deep2 = matmul_param->deep_ / 2;
|
||||
size_t col2 = matmul_param->col_ / 2;
|
||||
size_t a_stride = matmul_param->a_stride_;
|
||||
size_t b_stride = matmul_param->b_stride_;
|
||||
size_t c_stride = matmul_param->c_stride_;
|
||||
|
||||
StrassenMatMulParameter rec_matmul;
|
||||
rec_matmul.row_ = row2;
|
||||
rec_matmul.deep_ = deep2;
|
||||
rec_matmul.col_ = col2;
|
||||
|
||||
float *x_ptr = (float *)(malloc(row2 * MSMAX(deep2, col2) * FP32_STRASSEN_UINT * sizeof(float)));
|
||||
if (x_ptr == NULL) {
|
||||
return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
|
||||
}
|
||||
float *y_ptr = (float *)(malloc(col2 * deep2 * FP32_STRASSEN_WEIGHT_UINT * sizeof(float)));
|
||||
if (y_ptr == NULL) {
|
||||
free(x_ptr);
|
||||
return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
|
||||
}
|
||||
size_t x_stride = row2 * FP32_STRASSEN_UINT;
|
||||
size_t y_stride = deep2 * FP32_STRASSEN_WEIGHT_UINT;
|
||||
|
||||
const float *a11 = a_ptr;
|
||||
const float *a12 = a_ptr + deep2 * a_stride;
|
||||
const float *a21 = a_ptr + row2 * FP32_STRASSEN_UINT;
|
||||
const float *a22 = a_ptr + deep2 * a_stride + row2 * FP32_STRASSEN_UINT;
|
||||
const float *b11 = b_ptr;
|
||||
const float *b12 = b_ptr + col2 * b_stride;
|
||||
const float *b21 = b_ptr + deep2 * FP32_STRASSEN_WEIGHT_UINT;
|
||||
const float *b22 = b_ptr + col2 * b_stride + deep2 * FP32_STRASSEN_WEIGHT_UINT;
|
||||
float *c11 = c_ptr;
|
||||
float *c12 = c_ptr + col2 * c_stride;
|
||||
float *c21 = c_ptr + row2 * FP32_STRASSEN_UINT;
|
||||
float *c22 = c_ptr + col2 * c_stride + row2 * FP32_STRASSEN_UINT;
|
||||
|
||||
/* S3 = A11 - A21 */
|
||||
MatrixSub(a11, a21, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
|
||||
|
||||
/* T3 = B22 - B12 */
|
||||
MatrixSub(b22, b12, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
|
||||
|
||||
/* P7 = S3T3 */
|
||||
rec_matmul.a_stride_ = x_stride;
|
||||
rec_matmul.b_stride_ = y_stride;
|
||||
rec_matmul.c_stride_ = c_stride;
|
||||
StrassenMatmul(x_ptr, y_ptr, c21, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
|
||||
|
||||
/* S1 = A21 + A22 */
|
||||
MatrixAdd(a21, a22, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
|
||||
|
||||
/* T1 = B12 - B11 */
|
||||
MatrixSub(b12, b11, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
|
||||
|
||||
/* P5 = S1T1 */
|
||||
StrassenMatmul(x_ptr, y_ptr, c22, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
|
||||
|
||||
/* S2 = S1 - A11 */
|
||||
MatrixSub(x_ptr, a11, x_ptr, x_stride, a_stride, x_stride, row2, deep2);
|
||||
|
||||
/* T2 = B22 - T1 */
|
||||
MatrixSub(b22, y_ptr, y_ptr, b_stride, y_stride, y_stride, deep2 * 4, col2);
|
||||
|
||||
/* P6 = S2T2 */
|
||||
StrassenMatmul(x_ptr, y_ptr, c12, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
|
||||
|
||||
/* S4 = A12 - S2 */
|
||||
MatrixSub(a12, x_ptr, x_ptr, a_stride, x_stride, x_stride, row2, deep2);
|
||||
|
||||
/* P3 = S4B22 */
|
||||
rec_matmul.b_stride_ = b_stride;
|
||||
StrassenMatmul(x_ptr, b22, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
|
||||
|
||||
/* P1 = A11B11 */
|
||||
rec_matmul.a_stride_ = a_stride;
|
||||
rec_matmul.c_stride_ = row2 * FP32_STRASSEN_UINT;
|
||||
StrassenMatmul(a11, b11, x_ptr, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
|
||||
|
||||
/* U2 = P1 + P6
|
||||
U3 = U2 + P7
|
||||
U4 = U2 + P5
|
||||
U7 = U3 + P5
|
||||
U5 = U4 + P3 */
|
||||
MatrixMultiAdd(c11, c12, c21, c22, x_ptr, row2, col2, c_stride, x_stride);
|
||||
|
||||
/* T4 = T2 - B21 */
|
||||
MatrixSub(y_ptr, b21, y_ptr, y_stride, b_stride, y_stride, deep2 * 4, col2);
|
||||
|
||||
/* P4 = A22T4 */
|
||||
rec_matmul.b_stride_ = y_stride;
|
||||
rec_matmul.c_stride_ = c_stride;
|
||||
StrassenMatmul(a22, y_ptr, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
|
||||
|
||||
/* U6 = U3 - P4 */
|
||||
MatrixSub(c21, c11, c21, c_stride, c_stride, c_stride, row2, col2);
|
||||
|
||||
/* P2 = A12B21 */
|
||||
rec_matmul.b_stride_ = b_stride;
|
||||
StrassenMatmul(a12, b21, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
|
||||
|
||||
/* U1 = P1 + P2 */
|
||||
MatrixAdd(x_ptr, c11, c11, x_stride, c_stride, c_stride, row2, col2);
|
||||
|
||||
free(x_ptr);
|
||||
free(y_ptr);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
|
||||
float *tmp_a_ptr) {
|
||||
MatrixPack(a_ptr, tmp_a_ptr, matmul_param->row_, matmul_param->deep_, matmul_param->a_stride_);
|
||||
GemmMatMul(tmp_a_ptr, b_ptr, c_ptr, matmul_param->row_, matmul_param->col_, matmul_param->deep_,
|
||||
matmul_param->b_stride_, matmul_param->c_stride_);
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
|
||||
int max_recursion, int cur_recursion, float *tmp_a_ptr) {
|
||||
if (CheckRecursion(matmul_param->row_, matmul_param->col_, matmul_param->deep_, cur_recursion, max_recursion)) {
|
||||
return RecursionMatmul(a_ptr, b_ptr, c_ptr, matmul_param, max_recursion, cur_recursion, tmp_a_ptr);
|
||||
}
|
||||
return CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_a_ptr);
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
|
||||
|
||||
#include <memory.h>
|
||||
#include "nnacl/pack.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/strassen_matmul.h"
|
||||
#include "nnacl/fp32/common_func.h"
|
||||
|
||||
#define FP32_STRASSEN_UINT C4NUM
|
||||
#define FP32_STRASSEN_WEIGHT_UINT (C4NUM * C4NUM)
|
||||
#define FP32_STRASSEN_MAX_RECURSION 5
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
|
||||
int max_recursion, int, float *tmp_a_ptr);
|
||||
int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *Matmul_param,
|
||||
float *tmp_a_ptr);
|
||||
|
||||
int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
|
||||
int max_recursion, int cur_recursion, float *tmp_a_ptr);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
|
|
@ -16,9 +16,27 @@
|
|||
|
||||
#include "nnacl/fp32/topk.h"
|
||||
|
||||
int DescendCmp(const void *a, const void *b) { return ((const TopkNode *)b)->element - ((const TopkNode *)a)->element; }
|
||||
int DescendCmp(const void *a, const void *b) {
|
||||
float sub = ((const TopkNode *)b)->element - ((const TopkNode *)a)->element;
|
||||
if (sub > 0) {
|
||||
return 1;
|
||||
} else if (sub < 0) {
|
||||
return -1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
int AscendCmp(const void *a, const void *b) { return ((const TopkNode *)a)->element - ((const TopkNode *)b)->element; }
|
||||
int AscendCmp(const void *a, const void *b) {
|
||||
float sub = ((const TopkNode *)a)->element - ((const TopkNode *)b)->element;
|
||||
if (sub > 0) {
|
||||
return 1;
|
||||
} else if (sub < 0) {
|
||||
return -1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Topk(float *input_data, float *output_data, int32_t *output_index, TopkParameter *parameter) {
|
||||
int last_dim_size = parameter->last_dim_size_;
|
||||
|
|
|
@ -20,9 +20,9 @@
|
|||
static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); }
|
||||
|
||||
void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) {
|
||||
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_;
|
||||
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
|
||||
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
|
||||
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_;
|
||||
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
|
||||
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
|
||||
|
||||
const int stride_h = conv_param->stride_h_;
|
||||
|
@ -72,9 +72,9 @@ void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param
|
|||
|
||||
// output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w)
|
||||
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param) {
|
||||
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_;
|
||||
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
|
||||
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
|
||||
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_;
|
||||
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
|
||||
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
|
||||
|
||||
const int stride_h = conv_param->stride_h_;
|
||||
|
|
|
@ -14,20 +14,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
|
||||
#define MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
/* hw*inc4 X inc4*oc4 */
|
||||
typedef struct StrassenMatMulParameter {
|
||||
OpParameter op_parameter;
|
||||
int row_; /* h * w */
|
||||
int col_; /* oc4 / 4 */
|
||||
int deep_; /* inc4 / 4 */
|
||||
int a_stride_; /* h * w * 4 */
|
||||
int b_stride_; /* inc4 * 4 */
|
||||
int c_stride_; /* h * w * 4 */
|
||||
} StrassenMatMulParameter;
|
||||
typedef struct GatherParameter {
|
||||
OpParameter op_parameter_;
|
||||
int axis_;
|
||||
int batchDims_;
|
||||
} GatherParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
|
|
@ -49,6 +49,10 @@ void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, co
|
|||
size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
|
||||
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int out_multiplier,
|
||||
int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
|
||||
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
|
||||
int output_channel, int input_step, int8_t input_zp);
|
||||
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
|
||||
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -20,6 +20,99 @@
|
|||
#include "nnacl/int8/common_func.h"
|
||||
|
||||
/*conv depthwise int8 begin*/
|
||||
// only support perlayer
|
||||
#ifndef ENABLE_ARM64
|
||||
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
|
||||
int output_channel, int input_step, int8_t input_zp) {
|
||||
for (int i = 0; i < num_pixels; i++) {
|
||||
for (int c = 0; c < output_channel; c++) {
|
||||
const int16_t input = input_ptr[c] - input_zp;
|
||||
*output_ptr++ += input * weight_ptr[c];
|
||||
}
|
||||
input_ptr += input_step;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
|
||||
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max) {
|
||||
int align_num = 0;
|
||||
#ifdef ENABLE_ARM64
|
||||
align_num = num_pixels / 4 * 4;
|
||||
ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
#endif
|
||||
for (int i = align_num; i < num_pixels; i++) {
|
||||
buffer[i] = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift);
|
||||
buffer[i] += output_zp;
|
||||
buffer[i] = MSMAX(buffer[i], acc_min);
|
||||
buffer[i] = MSMIN(buffer[i], acc_max);
|
||||
dst[i] = (buffer[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_data, const int16_t *weight_data,
|
||||
const int32_t *bias_data, const ConvParameter *conv_param, int task_id) {
|
||||
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
|
||||
int h_start = h_step * task_id;
|
||||
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
|
||||
|
||||
int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
|
||||
int left_shift = conv_param->conv_quant_arg_.left_shift_[0];
|
||||
int right_shift = conv_param->conv_quant_arg_.right_shift_[0];
|
||||
|
||||
int intput_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
|
||||
int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
|
||||
int acc_min = conv_param->conv_quant_arg_.out_act_min_[0];
|
||||
int acc_max = conv_param->conv_quant_arg_.out_act_max_[0];
|
||||
|
||||
for (int b = 0; b < conv_param->output_batch_; b++) {
|
||||
const int8_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
|
||||
int8_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
|
||||
for (int oh = h_start; oh < h_end; oh++) {
|
||||
int8_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
|
||||
|
||||
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
|
||||
|
||||
// init acc
|
||||
for (int ow = 0; ow < conv_param->output_w_; ow++) {
|
||||
memcpy(row_buffer + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(int32_t));
|
||||
}
|
||||
for (int kh = start_kh; kh < end_kh; kh++) {
|
||||
int ih = ih_origin + conv_param->dilation_w_ * kh;
|
||||
|
||||
const int8_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
|
||||
const int16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;
|
||||
|
||||
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
|
||||
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
|
||||
int out_w_start = MSMAX(
|
||||
0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
|
||||
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
|
||||
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
|
||||
conv_param->stride_w_);
|
||||
|
||||
int32_t *acc_w = row_buffer + out_w_start * conv_param->output_channel_;
|
||||
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;
|
||||
|
||||
const int8_t *src_kw = src_kh + iw_origin * conv_param->input_channel_;
|
||||
int num_pixels = out_w_end - out_w_start;
|
||||
|
||||
ConvDwInt8Row(acc_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step, intput_zp);
|
||||
weight_kh += conv_param->output_channel_;
|
||||
}
|
||||
}
|
||||
// post func, acc int32 -> dst int8
|
||||
ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_ * conv_param->output_channel_, output_zp,
|
||||
out_multiplier, left_shift, right_shift, acc_min, acc_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
/*conv depthwise int8 end*/
|
||||
|
||||
/*conv depthwise sliding window int8 begin*/
|
||||
void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
|
||||
int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier,
|
||||
int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max,
|
||||
|
@ -68,14 +161,14 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
|
|||
bool per_channel) {
|
||||
int8_t *dst_h = dst + top * sliding->out_h_step_;
|
||||
for (int oh = top; oh < bottom; oh++) {
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
|
||||
const int16_t *src_h = src + ih * sliding->in_h_step_;
|
||||
|
||||
int8_t *dst_kernel = dst_h + left * sliding->block_channel_;
|
||||
for (int ow = left; ow < right; ow++) {
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
|
||||
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
|
||||
const int16_t *src_w = src_h + iw * sliding->block_channel_;
|
||||
|
@ -153,8 +246,8 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
|
|||
}
|
||||
#endif
|
||||
|
||||
void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
|
||||
void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
|
||||
const int16_t *src = input_data;
|
||||
int8_t *dst = output_data;
|
||||
bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
|
||||
|
@ -186,8 +279,8 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
|
|||
per_channel);
|
||||
|
||||
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
|
||||
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
|
||||
int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
|
||||
#ifdef ENABLE_ARM64
|
||||
|
@ -215,7 +308,7 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
|
|||
} // batch loop
|
||||
// output nhwc4
|
||||
}
|
||||
/*conv depthwise int8 end*/
|
||||
/*conv depthwise sliding window int8 end*/
|
||||
|
||||
/*deconv depthwise int8 begin*/
|
||||
void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width,
|
||||
|
@ -241,14 +334,14 @@ void DeconvDepthwiseBorderInt8(int32_t *dst, const int16_t *src, const int16_t *
|
|||
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
|
||||
const int16_t *src_h = src + top * sliding->out_h_step_;
|
||||
for (int ih = top; ih < bottom; ih++) {
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
|
||||
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
|
||||
int32_t *dst_h = dst + oh * sliding->in_h_step_;
|
||||
|
||||
const int16_t *src_kernel = src_h + left * sliding->block_channel_;
|
||||
for (int iw = left; iw < right; iw++) {
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
|
||||
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
|
||||
int32_t *dst_w = dst_h + ow * C4NUM;
|
||||
|
@ -341,8 +434,8 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in
|
|||
conv_param->input_w_, conv_param, sliding);
|
||||
|
||||
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
|
||||
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
|
||||
const int16_t *in_t =
|
||||
src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
|
||||
|
|
|
@ -23,8 +23,12 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
|
||||
|
||||
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
|
||||
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
|
||||
|
||||
void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
|
||||
|
||||
void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data,
|
||||
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
|
||||
|
|
|
@ -28,7 +28,7 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
|
|||
int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
|
||||
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
|
||||
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
|
||||
int oc4 = UP_DIV(output_channel, C4NUM);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
|
||||
size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
|
||||
|
@ -36,6 +36,7 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
|
|||
output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
|
||||
shift_before, shift_after, asymmetric, per_channel);
|
||||
#else
|
||||
int oc4 = UP_DIV(output_channel, C4NUM);
|
||||
int tile_num = conv_param->tile_num_;
|
||||
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
|
||||
for (int oc = 0; oc < output_channel; oc++) {
|
||||
|
@ -198,7 +199,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
|
|||
}
|
||||
}
|
||||
|
||||
void Conv3x3Uint8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
|
||||
void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
|
||||
int oc4 = UP_DIV(oc, C4NUM);
|
||||
#ifdef ENABLE_ARM64
|
||||
IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t));
|
||||
|
@ -263,7 +264,8 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
|
|||
int output_tile_count = UP_DIV(output_count, tile_n);
|
||||
int ic4 = UP_DIV(in_channel, C4NUM);
|
||||
int kernel_plane = kernel_h * kernel_w;
|
||||
int unit_size = kernel_plane * ic4 * C4NUM;
|
||||
int plane_block = UP_DIV(kernel_plane, C4NUM);
|
||||
int unit_size = plane_block * C4NUM * ic4 * C4NUM;
|
||||
int packed_input_size = output_tile_count * tile_n * unit_size;
|
||||
int input_sum_offset;
|
||||
if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
|
||||
|
@ -297,9 +299,10 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
|
|||
out_channel, tmp_input_sum, conv_param);
|
||||
} else {
|
||||
// res part
|
||||
IndirectGemmInt8(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
|
||||
int8_t *tmp_out_ptr = tmp_out + task_id * tile_n * out_channel;
|
||||
IndirectGemmInt8(tmp_out_ptr, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
|
||||
out_channel, tmp_input_sum, conv_param);
|
||||
memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
|
||||
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -359,14 +362,274 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
|
|||
kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
|
||||
} else {
|
||||
// res part
|
||||
IndirectGemmInt8Opt(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
|
||||
out_channel, tmp_input_sum, conv_param, gemm_func);
|
||||
memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
|
||||
int8_t *tmp_out_ptr = tmp_out + task_id * tile_n * out_channel;
|
||||
IndirectGemmInt8Opt(tmp_out_ptr, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
|
||||
kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
|
||||
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
|
||||
size_t output_channel, size_t plane_size, ConvParameter *conv_param) {
|
||||
int ic4 = UP_ROUND(input_channel, C4NUM);
|
||||
size_t hw_8div = plane_size / C8NUM * C8NUM;
|
||||
size_t hw_8res = plane_size - hw_8div;
|
||||
size_t ic_4div = input_channel / C4NUM * C4NUM;
|
||||
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
|
||||
|
||||
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
|
||||
const int8_t *src_r = src_input;
|
||||
int8_t *pack_r = packed_input;
|
||||
/* per layer */
|
||||
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
|
||||
const int8_t *src_ic = src_r;
|
||||
int8_t *pack_ic = pack_r;
|
||||
int32_t *input_sum_r = input_sum + hwi;
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t src_stride = input_channel;
|
||||
size_t ic_4res = input_channel - ic_4div;
|
||||
asm volatile(
|
||||
"dup v10.4s, wzr \n"
|
||||
"dup v11.4s, wzr \n"
|
||||
"mov x20, %[input_sum_r] \n"
|
||||
"dup v20.4s, %w[filter_zp] \n"
|
||||
|
||||
"mov x10, %[src_ic] \n"
|
||||
"mov x11, %[pack_ic] \n"
|
||||
|
||||
"mov x0, #0 \n"
|
||||
"1: \n"
|
||||
"cmp x0, %[ic_4div] \n"
|
||||
"add x0, x0, #4\n"
|
||||
"mov x12, x10 \n"
|
||||
"add x10, x10, #4\n"
|
||||
"blt 2f \n"
|
||||
"cmp %[ic_4res], #0\n"
|
||||
"beq 6f \n"
|
||||
"cmp %[ic_4res], #1\n"
|
||||
"beq 3f \n"
|
||||
"cmp %[ic_4res], #2\n"
|
||||
"beq 4f \n"
|
||||
"cmp %[ic_4res], #3\n"
|
||||
"beq 5f \n"
|
||||
|
||||
"2: \n"
|
||||
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"add v11.4s, v11.4s, v1.4s \n"
|
||||
"b 1b \n"
|
||||
|
||||
"3: \n" /* col res 1 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
|
||||
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"add v11.4s, v11.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"4: \n" /* col res 2 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
|
||||
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"add v11.4s, v11.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"5: \n" /* col res 3 */
|
||||
"dup v0.4s, wzr \n"
|
||||
"dup v1.4s, wzr \n"
|
||||
"add x13, x12, #2 \n"
|
||||
|
||||
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
|
||||
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
|
||||
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
|
||||
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v0.4s, v4.8h \n"
|
||||
"saddlp v1.4s, v5.8h \n"
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"add v11.4s, v11.4s, v1.4s \n"
|
||||
"b 6f \n"
|
||||
|
||||
"6: \n"
|
||||
"mul v10.4s, v10.4s, v20.4s \n"
|
||||
"mul v11.4s, v11.4s, v20.4s \n"
|
||||
|
||||
"st1 {v10.4s}, [x20], #16 \n"
|
||||
"st1 {v11.4s}, [x20], #16 \n"
|
||||
|
||||
:
|
||||
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
|
||||
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res),
|
||||
[ filter_zp ] "r"(filter_zp)
|
||||
: "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11",
|
||||
"v20");
|
||||
#else
|
||||
int32_t tmp_sum_value[8] = {0};
|
||||
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
tmp_sum_value[i] += src_ic[0 + i * input_channel];
|
||||
tmp_sum_value[i] += src_ic[1 + i * input_channel];
|
||||
tmp_sum_value[i] += src_ic[2 + i * input_channel];
|
||||
tmp_sum_value[i] += src_ic[3 + i * input_channel];
|
||||
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
|
||||
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
|
||||
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
|
||||
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
|
||||
}
|
||||
src_ic += C4NUM;
|
||||
pack_ic += C4NUM * C8NUM;
|
||||
}
|
||||
for (int ici = ic_4div; ici < input_channel; ici += 1) {
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
tmp_sum_value[i] += src_ic[i * input_channel];
|
||||
pack_ic[i * C4NUM] = src_ic[i * input_channel];
|
||||
}
|
||||
src_ic += 1;
|
||||
pack_ic += 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
|
||||
}
|
||||
#endif
|
||||
src_r += input_channel * C8NUM;
|
||||
pack_r += ic4 * C8NUM;
|
||||
}
|
||||
|
||||
if (hw_8div != plane_size) {
|
||||
memset(pack_r, 0, C8NUM * ic4);
|
||||
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
|
||||
int32_t tmp_sum_value = 0;
|
||||
const int8_t *src_ic = src_r;
|
||||
int8_t *pack_ic = pack_r;
|
||||
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
|
||||
tmp_sum_value += src_ic[0];
|
||||
tmp_sum_value += src_ic[1];
|
||||
tmp_sum_value += src_ic[2];
|
||||
tmp_sum_value += src_ic[3];
|
||||
pack_ic[0] = src_ic[0];
|
||||
pack_ic[1] = src_ic[1];
|
||||
pack_ic[2] = src_ic[2];
|
||||
pack_ic[3] = src_ic[3];
|
||||
src_ic += C4NUM;
|
||||
pack_ic += C4NUM * C8NUM;
|
||||
}
|
||||
for (int ici = ic_4div; ici < input_channel; ici += 1) {
|
||||
tmp_sum_value += src_ic[0];
|
||||
pack_ic[0] = src_ic[0];
|
||||
src_ic += 1;
|
||||
pack_ic += 1;
|
||||
}
|
||||
input_sum[hwi] = tmp_sum_value * filter_zp;
|
||||
src_r += input_channel;
|
||||
pack_r += C4NUM;
|
||||
}
|
||||
for (int hwi = plane_size; hwi < plane_size + hw_8res; hwi++) {
|
||||
input_sum[hwi] = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
/* per channel */
|
||||
RowMajor2Row4x8MajorInt8(src_input, packed_input, plane_size, input_channel);
|
||||
PackInputSum8x4Int8(packed_input, input_sum, input_channel, output_channel, plane_size, conv_param);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
|
||||
MATMUL_OPT_R_FUNC matmul_func) {
|
||||
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
|
||||
conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
|
||||
conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], false);
|
||||
return;
|
||||
}
|
||||
|
||||
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param) {
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
|
||||
bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
|
||||
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
|
||||
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
|
||||
conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_);
|
||||
#else
|
||||
MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
|
||||
conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
|
||||
conv_param->conv_quant_arg_.quant_multiplier_,
|
||||
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
|
||||
conv_param->conv_quant_arg_.out_act_max_[0], false);
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
// int8 convolution 3x3
|
||||
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
|
||||
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
|
||||
|
@ -391,15 +654,15 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
|
|||
int start_index = thread_id * TILE_NUM;
|
||||
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
|
||||
|
||||
Conv3x3Uint8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
|
||||
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
|
||||
out_w_block, conv_param);
|
||||
Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
|
||||
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
|
||||
out_w_block, conv_param);
|
||||
|
||||
Conv3x3Uint8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
|
||||
transed_weight, output_channel, ic8, real_cal_num);
|
||||
Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
|
||||
transed_weight, output_channel, ic8, real_cal_num);
|
||||
|
||||
Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
|
||||
bias_data, start_index, real_cal_num, out_w_block, conv_param);
|
||||
Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
|
||||
bias_data, start_index, real_cal_num, out_w_block, conv_param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,8 @@
|
|||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/winograd_utils.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "nnacl/int8/matmul_int8.h"
|
||||
|
||||
typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize,
|
||||
size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min,
|
||||
|
@ -51,6 +53,15 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
|
|||
int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id,
|
||||
ConvParameter *conv_param, GEMM_FUNC gemm_func);
|
||||
|
||||
// int8 convolution 1x1
|
||||
void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
|
||||
size_t output_channel, size_t plane_size, ConvParameter *conv_param);
|
||||
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param);
|
||||
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
|
||||
const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
|
||||
MATMUL_OPT_R_FUNC matmul_func);
|
||||
|
||||
// int8 convolution 3x3
|
||||
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
|
||||
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
|
||||
|
|
|
@ -33,8 +33,8 @@ int DeConvPostInt8C8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8
|
|||
|
||||
for (int ih = 0; ih < conv_param->input_h_; ih++) {
|
||||
for (int iw = 0; iw < conv_param->input_w_; iw++) {
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
|
||||
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
|
||||
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
|
||||
|
@ -88,8 +88,8 @@ int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8
|
|||
|
||||
for (int ih = 0; ih < conv_param->input_h_; ih++) {
|
||||
for (int iw = 0; iw < conv_param->input_w_; iw++) {
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
|
||||
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
|
||||
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
|
||||
|
@ -172,73 +172,7 @@ void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp,
|
|||
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
|
||||
bool suppport_opt) {
|
||||
/* optimize normal -> same layout */
|
||||
#ifdef ENABLE_ARM64
|
||||
asm volatile(
|
||||
"mov x10, %[src] \n"
|
||||
"mov x11, %[dst] \n"
|
||||
"dup v15.4s, %w[filter_zp] \n"
|
||||
|
||||
"mov x0, #0 \n"
|
||||
"1: \n"
|
||||
"cmp x0, %[row4] \n"
|
||||
"beq 4f \n"
|
||||
"add x0, x0, #4\n"
|
||||
"dup v10.4s, wzr \n"
|
||||
"mov x2, #0 \n"
|
||||
|
||||
"2: \n"
|
||||
"cmp x2, %[col16] \n"
|
||||
"beq 3f \n"
|
||||
"add x2, x2, #16\n"
|
||||
|
||||
"ld1 {v0.16b}, [x10], #16\n"
|
||||
"ld1 {v1.16b}, [x10], #16\n"
|
||||
"ld1 {v2.16b}, [x10], #16\n"
|
||||
"ld1 {v3.16b}, [x10], #16\n"
|
||||
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v6.8h, v2.16b \n"
|
||||
"saddlp v7.8h, v3.16b \n"
|
||||
|
||||
"saddlp v0.4S, v4.8h \n"
|
||||
"saddlp v1.4S, v5.8h \n"
|
||||
"saddlp v2.4S, v6.8h \n"
|
||||
"saddlp v3.4S, v7.8h \n"
|
||||
|
||||
"addv s4, v0.4S \n"
|
||||
"addv s5, v1.4S \n"
|
||||
"addv s6, v2.4S \n"
|
||||
"addv s7, v3.4S \n"
|
||||
|
||||
"mov v0.s[0], v4.s[0] \n"
|
||||
"mov v0.s[1], v5.s[0] \n"
|
||||
"mov v0.s[2], v6.s[0] \n"
|
||||
"mov v0.s[3], v7.s[0] \n"
|
||||
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"b 2b\n"
|
||||
|
||||
"3: \n"
|
||||
"mul v10.4s, v10.4s, v15.4s \n"
|
||||
"st1 {v10.4s}, [x11], #16 \n"
|
||||
"beq 1b \n"
|
||||
|
||||
"4: \n"
|
||||
|
||||
:
|
||||
: [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp)
|
||||
: "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15");
|
||||
#else
|
||||
for (int r = 0; r < row4; r++) {
|
||||
int32_t tmp_value = 0;
|
||||
for (int c = 0; c < col16; c++) {
|
||||
int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM;
|
||||
int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
|
||||
tmp_value += src[src_index];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
PackInputSum16x4PerLayer(src, dst, filter_zp, row4, col16);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,8 +29,8 @@ int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64
|
|||
}
|
||||
|
||||
int recip_shift;
|
||||
const int32_t input1_inv = (input1_val > 0) ? ComputerReciproal(input1_val, 31, &recip_shift)
|
||||
: -ComputerReciproal(-input1_val, 31, &recip_shift);
|
||||
const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift)
|
||||
: -ComputerReciprocal(-input1_val, 31, &recip_shift);
|
||||
const int leading_bits = CountLeadingSignBits(input0_val);
|
||||
const int32_t raw_data =
|
||||
SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv);
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/int8/gatherNd_int8.h"
|
||||
#include <string.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int GatherNdInt8(int8_t *input, int8_t *output, int *in_offset, int area, int count, GatherQuantArg param) {
|
||||
double alpha = param.alpha_;
|
||||
int z1 = param.zp_in_;
|
||||
int z2 = param.zp_out_;
|
||||
for (int i = 0; i < count; ++i) {
|
||||
for (int j = 0; j < area; ++j) {
|
||||
int32_t tmp = round(alpha * (input[in_offset[i] + j] - z1)) + z2;
|
||||
tmp = tmp > 127 ? 127 : tmp;
|
||||
tmp = tmp < -128 ? -128 : tmp;
|
||||
output[area * i + j] = (int8_t)tmp;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int GatherNdInt8(int8_t *in_data, int8_t *out_data, int *in_offset, int area, int count, GatherQuantArg param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
#include "nnacl/int8/gather_int8.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, int *indices,
|
||||
int indices_element_size, GatherQuantArg para) {
|
||||
double alpha = para.alpha_;
|
||||
int z1 = para.zp_in_;
|
||||
int z2 = para.zp_out_;
|
||||
int i, m, j;
|
||||
for (m = 0; m < outer_size; ++m) {
|
||||
const int8_t *inputm = in_data + inner_size * m * limit;
|
||||
int8_t *outputm = out_data + inner_size * m * indices_element_size;
|
||||
for (i = 0; i < indices_element_size; ++i) {
|
||||
if (indices[i] < 0 || indices[i] > limit) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
for (j = 0; j < inner_size; ++j) {
|
||||
int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2;
|
||||
tmp = tmp > 127 ? 127 : tmp;
|
||||
tmp = tmp < -128 ? -128 : tmp;
|
||||
outputm[i * inner_size + j] = (int8_t)tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, int *indices,
|
||||
int indices_element_size, GatherQuantArg para);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
|
|
@ -28,6 +28,36 @@ void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col)
|
|||
}
|
||||
}
|
||||
|
||||
void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
|
||||
int col16 = UP_ROUND(col, C16NUM);
|
||||
for (int r = 0; r < row; r++) {
|
||||
int rd4 = r / C4NUM;
|
||||
int rm4 = r % C4NUM;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd16 = c / C16NUM;
|
||||
int cm16 = c % C16NUM;
|
||||
int dst_index = rd4 * col16 * C4NUM + cd16 * C4NUM * C16NUM + rm4 * C16NUM + cm16;
|
||||
int src_index = r * col + c;
|
||||
dst_ptr[dst_index] = src_ptr[src_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
|
||||
int col4 = UP_ROUND(col, C4NUM);
|
||||
for (int r = 0; r < row; r++) {
|
||||
int rd8 = r / C8NUM;
|
||||
int rm8 = r % C8NUM;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd4 = c / C4NUM;
|
||||
int cm4 = c % C4NUM;
|
||||
int dst_index = rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4;
|
||||
int src_index = r * col + c;
|
||||
dst_ptr[dst_index] = src_ptr[src_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MatrixPack4x16UnitInt8(int8_t *src, int8_t *dst, int row, int col, int stride) {
|
||||
for (int r = 0; r < row; r++) {
|
||||
int8_t *src_r = src + r * stride;
|
||||
|
@ -37,6 +67,29 @@ void MatrixPack4x16UnitInt8(int8_t *src, int8_t *dst, int row, int col, int stri
|
|||
return;
|
||||
}
|
||||
|
||||
void MatrixEmptyInt8(int8_t *dst, int row, int col) {
|
||||
for (int r = 0; r < row; r++) {
|
||||
int8_t *dst_r = dst + r * C16NUM;
|
||||
memset(dst_r, 0, col * sizeof(int8_t));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
|
||||
/* Row-major to row16x4-major (block row-major) */
|
||||
int col4 = UP_ROUND(col, C4NUM);
|
||||
for (int r = 0; r < row; r++) {
|
||||
int rd8 = r / C8NUM, rm8 = r % C8NUM;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd4 = c / C4NUM, cm4 = c % C4NUM;
|
||||
int src_index = r * col + c;
|
||||
int dst_index = rd8 * col4 * C8NUM + cd4 * C4NUM * C8NUM + rm8 * C4NUM + cm4;
|
||||
dst_ptr[dst_index] = src_ptr[src_index];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
|
||||
/* Row-major to row16x4-major (block row-major) */
|
||||
int col16 = UP_ROUND(col, C16NUM);
|
||||
|
@ -50,16 +103,17 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
|
|||
for (int ri = 0; ri < row_4div; ri += C4NUM) {
|
||||
for (int ci = 0; ci < col_16div; ci += C16NUM) {
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t col_offset = col;
|
||||
int8_t *src_c = src_r + ci;
|
||||
int8_t *dst_c = dst_r + ci * C4NUM;
|
||||
asm volatile(
|
||||
"mov x10, %[src_c] \n"
|
||||
"mov x11, %[dst_c] \n"
|
||||
|
||||
"ld1 {v0.16b}, [x10], %[col]\n"
|
||||
"ld1 {v1.16b}, [x10], %[col]\n"
|
||||
"ld1 {v2.16b}, [x10], %[col]\n"
|
||||
"ld1 {v3.16b}, [x10], %[col]\n"
|
||||
"ld1 {v0.16b}, [x10], %[col_offset]\n"
|
||||
"ld1 {v1.16b}, [x10], %[col_offset]\n"
|
||||
"ld1 {v2.16b}, [x10], %[col_offset]\n"
|
||||
"ld1 {v3.16b}, [x10], %[col_offset]\n"
|
||||
|
||||
"st1 {v0.16b}, [x11], #16\n"
|
||||
"st1 {v1.16b}, [x11], #16\n"
|
||||
|
@ -67,7 +121,7 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
|
|||
"st1 {v3.16b}, [x11], #16\n"
|
||||
|
||||
:
|
||||
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col ] "r"(col)
|
||||
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3");
|
||||
#else
|
||||
MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, C4NUM, C16NUM, col);
|
||||
|
@ -76,12 +130,15 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
|
|||
|
||||
if (col != col_16div) {
|
||||
MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col);
|
||||
MatrixEmptyInt8(dst_r + col_16div * C4NUM + col_16res, C4NUM, C16NUM - col_16res);
|
||||
}
|
||||
src_r += C4NUM * col;
|
||||
dst_r += C4NUM * col16;
|
||||
}
|
||||
|
||||
if (row != row_4div) {
|
||||
memset(dst_r, 0, C4NUM * col16);
|
||||
|
||||
for (int ci = 0; ci < col_16div; ci += C16NUM) {
|
||||
MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col);
|
||||
}
|
||||
|
@ -103,25 +160,6 @@ void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col)
|
|||
}
|
||||
}
|
||||
|
||||
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
|
||||
const int32_t a_zp, const int32_t b_zp) {
|
||||
/* col8-major * row8-major => row8x8-major */
|
||||
for (int row = 0; row < row8; row++) {
|
||||
for (int col = 0; col < col8; col++) {
|
||||
int r8div = row / 8, r8mod = row % 8;
|
||||
int c8div = col / 8, c8mod = col % 8;
|
||||
size_t ci = c8div * row8 * 8 + row * 8 + c8mod;
|
||||
int32_t value = 0;
|
||||
for (int d = 0; d < deep; d++) {
|
||||
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
|
||||
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
|
||||
value = value + ((int32_t)a[ai] - a_zp) * ((int32_t)b[bi] - b_zp);
|
||||
}
|
||||
c[ci] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
|
||||
const int *input_sum, const int *bias) {
|
||||
/* row4x16-major * row16x4-major => row4x4-major */
|
||||
|
@ -145,7 +183,100 @@ void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int
|
|||
return;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
|
||||
bool per_channel) {
|
||||
/* row4x16-major * row16x4-major => (int8)row-major : per-channel */
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r4div = r / C4NUM, r4mod = r % C4NUM;
|
||||
int c4div = c / C4NUM, c4mod = c % C4NUM;
|
||||
size_t ci = r * stride + c;
|
||||
int32_t value = 0;
|
||||
for (int d = 0; d < deep_16; d++) {
|
||||
int d16div = d / C16NUM, d16mod = d % C16NUM;
|
||||
size_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
|
||||
size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
int32_t cur_input_sum = per_channel ? input_sum[c4div * UP_ROUND(row, C4NUM) + r * C4NUM + c4mod] : input_sum[r];
|
||||
value -= cur_input_sum;
|
||||
value += bias[c];
|
||||
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
|
||||
int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
|
||||
int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
|
||||
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
|
||||
value = MSMIN(maxi, value);
|
||||
value = MSMAX(mini, value);
|
||||
dst[ci] = (int8_t)value;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
|
||||
bool per_channel) {
|
||||
/* row8x4-major * row4x8-major => (int8)row-major */
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r8div = r / C8NUM, r8mod = r % C8NUM;
|
||||
int c8div = c / C8NUM, c8mod = c % C8NUM;
|
||||
size_t ci = r * stride + c;
|
||||
int32_t value = 0;
|
||||
for (int d = 0; d < deep_4; d++) {
|
||||
int d4div = d / C4NUM, d4mod = d % C4NUM;
|
||||
size_t ai = r8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + r8mod * C4NUM + d4mod;
|
||||
size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
int32_t cur_input_sum = per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) + r * C8NUM + c8mod] : input_sum[r];
|
||||
value -= cur_input_sum;
|
||||
value += bias[c];
|
||||
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
|
||||
int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
|
||||
int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
|
||||
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
|
||||
value = MSMIN(maxi, value);
|
||||
value = MSMAX(mini, value);
|
||||
dst[ci] = (int8_t)value;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/* row4x16-major * col16x4-major => row4x4-major */
|
||||
void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
|
||||
int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
|
||||
int stride) {
|
||||
int8_t *output = dst;
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r4div = r / C4NUM;
|
||||
int r4mod = r % C4NUM;
|
||||
int c4div = c / C4NUM;
|
||||
int c4mod = c % C4NUM;
|
||||
int value = 0;
|
||||
for (int d = 0; d < deep16; d++) {
|
||||
int d16div = d / C16NUM;
|
||||
int d16mod = d % C16NUM;
|
||||
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
|
||||
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
|
||||
value += a[ai] * b[bi];
|
||||
}
|
||||
value -= a_sums[r];
|
||||
value += bias[c];
|
||||
value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + out_zp;
|
||||
value = MSMIN(INT8_MAX, value);
|
||||
value = MSMAX(INT8_MIN, value);
|
||||
output[c] = (int8_t)value;
|
||||
}
|
||||
output += stride;
|
||||
}
|
||||
}
|
||||
|
||||
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) {
|
||||
int stride = sizeof(int8_t) * 16 * 4;
|
||||
for (int r = 0; r < row; ++r) {
|
||||
|
@ -168,23 +299,35 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_1
|
|||
}
|
||||
}
|
||||
|
||||
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst) {
|
||||
// dst: weight_zp * input_row_sums
|
||||
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
|
||||
for (int r = 0; r < row; ++r) {
|
||||
int sum = 0;
|
||||
for (int c = 0; c < col; ++c) {
|
||||
int src_idx = r * col + c;
|
||||
dst[r] += a[src_idx];
|
||||
if (order == RowMajor) {
|
||||
sum += input[r * col + c];
|
||||
} else {
|
||||
sum += input[c * row + r];
|
||||
}
|
||||
}
|
||||
dst[r] *= b_zp;
|
||||
sum *= weight_zp;
|
||||
dst[r] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst) {
|
||||
// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
|
||||
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
|
||||
DataOrder order) {
|
||||
for (int c = 0; c < col; ++c) {
|
||||
int sum = 0;
|
||||
for (int r = 0; r < row; ++r) {
|
||||
int src_idx = r * col + c;
|
||||
dst[c] += b[src_idx];
|
||||
if (order == RowMajor) {
|
||||
sum += weight[r * col + c];
|
||||
} else {
|
||||
sum += weight[c * row + r];
|
||||
}
|
||||
}
|
||||
dst[c] = row * a_zp * b_zp - a_zp * dst[c];
|
||||
dst[c] = row * input_zp * weight_zp - input_zp * sum;
|
||||
if (bias) {
|
||||
dst[c] += bias[c];
|
||||
}
|
||||
|
@ -201,4 +344,3 @@ void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow)
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -24,25 +24,37 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep,
|
||||
const int a_zp, const int b_zp);
|
||||
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
|
||||
const int *input_sum, const int *bias);
|
||||
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
|
||||
bool per_channel);
|
||||
void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||
void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||
void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
|
||||
bool per_channel);
|
||||
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||
void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||
|
||||
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
|
||||
void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
|
||||
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst);
|
||||
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
|
||||
void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow);
|
||||
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
|
||||
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
|
||||
DataOrder order);
|
||||
void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
|
||||
int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
|
||||
int stride);
|
||||
|
||||
// bias = bias + depth * a_zp * b_zp - a_zp * b_sums
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
|
||||
const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
|
||||
int right_shift);
|
||||
int right_shift, int row, int col, int stride);
|
||||
|
||||
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
|
||||
const int *input_sum, const int *bias);
|
||||
|
|
|
@ -89,8 +89,13 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
|
|||
int output_batch = pooling_param->output_batch_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
int c8 = UP_DIV(channel, C8NUM);
|
||||
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
|
||||
float input_scale = pooling_param->quant_args_[0][0].scale_;
|
||||
int input_zp = pooling_param->quant_args_[0][0].zp_;
|
||||
float output_scale = pooling_param->quant_args_[1][0].scale_;
|
||||
int output_zp = pooling_param->quant_args_[1][0].zp_;
|
||||
double real_multiplier = input_scale / output_scale;
|
||||
int c16 = channel / C16NUM;
|
||||
const int8_t out_min = INT8_MIN;
|
||||
const int8_t out_max = INT8_MAX;
|
||||
|
||||
|
@ -107,89 +112,159 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
|
|||
int in_w_index = out_w_index * stride_w - pad_w;
|
||||
int in_h_index = out_h_index * stride_h - pad_h;
|
||||
int out_plane_offset = out_batch_offset + index * channel;
|
||||
for (int j = 0; j < c8 - 1; j++) {
|
||||
int in_channel_offset = in_batch_offset + j * C8NUM;
|
||||
int out_channel_offset = out_plane_offset + j * C8NUM;
|
||||
int16_t tmp_avg1 = 0;
|
||||
int16_t tmp_avg2 = 0;
|
||||
int16_t tmp_avg3 = 0;
|
||||
int16_t tmp_avg4 = 0;
|
||||
int16_t tmp_avg5 = 0;
|
||||
int16_t tmp_avg6 = 0;
|
||||
int16_t tmp_avg7 = 0;
|
||||
int16_t tmp_avg8 = 0;
|
||||
int real_count = 0;
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
||||
(in_w_index + w) >= in_w) {
|
||||
continue;
|
||||
} else {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_avg1 += *(input_ptr + in_offset);
|
||||
tmp_avg2 += *(input_ptr + in_offset + 1);
|
||||
tmp_avg3 += *(input_ptr + in_offset + 2);
|
||||
tmp_avg4 += *(input_ptr + in_offset + 3);
|
||||
tmp_avg5 += *(input_ptr + in_offset + 4);
|
||||
tmp_avg6 += *(input_ptr + in_offset + 5);
|
||||
tmp_avg7 += *(input_ptr + in_offset + 6);
|
||||
tmp_avg8 += *(input_ptr + in_offset + 7);
|
||||
++real_count;
|
||||
int input_stride = (in_h_index * in_w + in_w_index) * channel;
|
||||
int kw_s = MSMAX(0, -in_w_index);
|
||||
int kw_e = MSMIN(win_w, in_w - in_w_index);
|
||||
int kh_s = MSMAX(0, -in_h_index);
|
||||
int kh_e = MSMIN(win_h, in_h - in_h_index);
|
||||
int real_count = (kw_e - kw_s) * (kh_e - kh_s);
|
||||
|
||||
// 16 channels
|
||||
for (int j = 0; j < c16; j++) {
|
||||
#ifdef ENABLE_NEON
|
||||
int16x8_t tmp_avg[2];
|
||||
tmp_avg[0] = vmovq_n_s16(0);
|
||||
tmp_avg[1] = vmovq_n_s16(0);
|
||||
#else
|
||||
int16_t tmp_avg[16];
|
||||
int16_t real_out[16];
|
||||
for (int m = 0; m < C16NUM; ++m) {
|
||||
tmp_avg[m] = 0;
|
||||
}
|
||||
#endif
|
||||
int in_channel_offset = in_batch_offset + j * C16NUM;
|
||||
int out_channel_offset = out_plane_offset + j * C16NUM;
|
||||
|
||||
for (int h = kh_s; h < kh_e; h++) {
|
||||
for (int w = kw_s; w < kw_e; w++) {
|
||||
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
int8x16_t in_ptr = vld1q_s8(input_ptr + in_offset);
|
||||
int8x8_t in_data1 = vget_low_s8(in_ptr);
|
||||
int8x8_t in_data2 = vget_high_s8(in_ptr);
|
||||
int16x8_t data1 = vmovl_s8(in_data1);
|
||||
int16x8_t data2 = vmovl_s8(in_data2);
|
||||
tmp_avg[0] = vaddq_s16(tmp_avg[0], data1);
|
||||
tmp_avg[1] = vaddq_s16(tmp_avg[1], data2);
|
||||
#else
|
||||
for (int k = 0; k < C16NUM; ++k) {
|
||||
tmp_avg[k] += input_ptr[in_offset + k];
|
||||
}
|
||||
#endif
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
int16_t tmp_out1 = round((float)tmp_avg1 / (float)real_count);
|
||||
int16_t tmp_out2 = round((float)tmp_avg2 / (float)real_count);
|
||||
int16_t tmp_out3 = round((float)tmp_avg3 / (float)real_count);
|
||||
int16_t tmp_out4 = round((float)tmp_avg4 / (float)real_count);
|
||||
int16_t tmp_out5 = round((float)tmp_avg5 / (float)real_count);
|
||||
int16_t tmp_out6 = round((float)tmp_avg6 / (float)real_count);
|
||||
int16_t tmp_out7 = round((float)tmp_avg7 / (float)real_count);
|
||||
int16_t tmp_out8 = round((float)tmp_avg8 / (float)real_count);
|
||||
int16_t real_out1 = tmp_out1 < out_min ? out_min : tmp_out1;
|
||||
int16_t real_out2 = tmp_out2 < out_min ? out_min : tmp_out2;
|
||||
int16_t real_out3 = tmp_out3 < out_min ? out_min : tmp_out3;
|
||||
int16_t real_out4 = tmp_out4 < out_min ? out_min : tmp_out4;
|
||||
int16_t real_out5 = tmp_out5 < out_min ? out_min : tmp_out5;
|
||||
int16_t real_out6 = tmp_out6 < out_min ? out_min : tmp_out6;
|
||||
int16_t real_out7 = tmp_out7 < out_min ? out_min : tmp_out7;
|
||||
int16_t real_out8 = tmp_out8 < out_min ? out_min : tmp_out8;
|
||||
real_out1 = real_out1 > out_max ? out_max : real_out1;
|
||||
real_out2 = real_out2 > out_max ? out_max : real_out2;
|
||||
real_out3 = real_out3 > out_max ? out_max : real_out3;
|
||||
real_out4 = real_out4 > out_max ? out_max : real_out4;
|
||||
real_out5 = real_out5 > out_max ? out_max : real_out5;
|
||||
real_out6 = real_out6 > out_max ? out_max : real_out6;
|
||||
real_out7 = real_out7 > out_max ? out_max : real_out7;
|
||||
real_out8 = real_out8 > out_max ? out_max : real_out8;
|
||||
*(output_ptr + out_channel_offset) = (int8_t)real_out1;
|
||||
*(output_ptr + out_channel_offset + 1) = (int8_t)real_out2;
|
||||
*(output_ptr + out_channel_offset + 2) = (int8_t)real_out3;
|
||||
*(output_ptr + out_channel_offset + 3) = (int8_t)real_out4;
|
||||
*(output_ptr + out_channel_offset + 4) = (int8_t)real_out5;
|
||||
*(output_ptr + out_channel_offset + 5) = (int8_t)real_out6;
|
||||
*(output_ptr + out_channel_offset + 6) = (int8_t)real_out7;
|
||||
*(output_ptr + out_channel_offset + 7) = (int8_t)real_out8;
|
||||
} // in_channel loop
|
||||
int channel_s = (c8 - 1) * C8NUM;
|
||||
for (int k = channel_s; k < channel; k++) {
|
||||
int in_channel_offset = in_batch_offset + k;
|
||||
int out_channel_offset = out_plane_offset + k;
|
||||
#ifdef ENABLE_NEON
|
||||
int16_t tmp_data[8];
|
||||
int16_t tmp_out[8];
|
||||
int16_t tmp_data1[8];
|
||||
int16_t tmp_out1[8];
|
||||
for (int l = 0; l < C8NUM; l++) {
|
||||
tmp_data[l] = tmp_avg[0][l] + 128 * real_count;
|
||||
tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count;
|
||||
tmp_out[l] -= 128;
|
||||
tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp;
|
||||
}
|
||||
for (int l = 0; l < C8NUM; l++) {
|
||||
tmp_data1[l] = tmp_avg[1][l] + 128 * real_count;
|
||||
tmp_out1[l] = (tmp_data1[l] + real_count / 2) / real_count;
|
||||
tmp_out1[l] -= 128;
|
||||
tmp_out1[l] = round((tmp_out1[l] - input_zp) * real_multiplier) + output_zp;
|
||||
}
|
||||
int8x8_t real_out[2];
|
||||
int8x8_t output_min = vdup_n_s8(out_min);
|
||||
int8x8_t output_max = vdup_n_s8(out_max);
|
||||
real_out[0] = vqmovn_s16(vld1q_s16(tmp_out));
|
||||
real_out[0] = vmin_s8(real_out[0], output_max);
|
||||
real_out[0] = vmax_s8(real_out[0], output_min);
|
||||
vst1_s8(output_ptr + out_channel_offset, real_out[0]);
|
||||
real_out[1] = vqmovn_s16(vld1q_s16(tmp_out1));
|
||||
real_out[1] = vmin_s8(real_out[1], output_max);
|
||||
real_out[1] = vmax_s8(real_out[1], output_min);
|
||||
vst1_s8(output_ptr + out_channel_offset + 8, real_out[1]);
|
||||
#else
|
||||
for (int l = 0; l < C16NUM; ++l) {
|
||||
int16_t tmp_data = tmp_avg[l] + 128 * real_count;
|
||||
real_out[l] = (tmp_data + real_count / 2) / real_count - 128;
|
||||
real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp);
|
||||
real_out[l] = real_out[l] < out_min ? out_min : real_out[l];
|
||||
real_out[l] = real_out[l] > out_max ? out_max : real_out[l];
|
||||
*(output_ptr + out_channel_offset + l) = (int8_t)real_out[l];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// 8 channels
|
||||
int channel_16_res = channel - c16 * C16NUM;
|
||||
int c8 = channel_16_res / C8NUM;
|
||||
int in_c16_offset = in_batch_offset + c16 * C16NUM;
|
||||
int out_c16_offset = out_plane_offset + c16 * C16NUM;
|
||||
for (int j = 0; j < c8; j++) {
|
||||
#ifdef ENABLE_NEON
|
||||
int16x8_t tmp_avg = vmovq_n_s16(0);
|
||||
#else
|
||||
int16_t tmp_avg[8] = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
int16_t real_out[8];
|
||||
#endif
|
||||
int in_channel_offset = in_c16_offset + j * C8NUM;
|
||||
int out_channel_offset = out_c16_offset + j * C8NUM;
|
||||
for (int h = kh_s; h < kh_e; h++) {
|
||||
for (int w = kw_s; w < kw_e; w++) {
|
||||
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
int8x8_t in_ptr = vld1_s8(input_ptr + in_offset);
|
||||
int16x8_t data = vmovl_s8(in_ptr);
|
||||
tmp_avg = vaddq_s16(tmp_avg, data);
|
||||
#else
|
||||
for (int k = 0; k < C8NUM; ++k) {
|
||||
tmp_avg[k] += input_ptr[in_offset + k];
|
||||
}
|
||||
#endif
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
#ifdef ENABLE_NEON
|
||||
int16_t tmp_data[8];
|
||||
int16_t tmp_out[8];
|
||||
for (int l = 0; l < C8NUM; l++) {
|
||||
tmp_data[l] = tmp_avg[l] + 128 * real_count;
|
||||
tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count;
|
||||
tmp_out[l] -= 128;
|
||||
tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp;
|
||||
}
|
||||
int8x8_t real_out;
|
||||
int8x8_t output_min = vdup_n_s8(out_min);
|
||||
int8x8_t output_max = vdup_n_s8(out_max);
|
||||
real_out = vqmovn_s16(vld1q_s16(tmp_out));
|
||||
real_out = vmin_s8(real_out, output_max);
|
||||
real_out = vmax_s8(real_out, output_min);
|
||||
vst1_s8(output_ptr + out_channel_offset, real_out);
|
||||
#else
|
||||
for (int l = 0; l < C8NUM; ++l) {
|
||||
int16_t tmp_data = tmp_avg[l] + 128 * real_count;
|
||||
real_out[l] = (tmp_data + real_count / 2) / real_count - 128;
|
||||
real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp);
|
||||
real_out[l] = real_out[l] < out_min ? out_min : real_out[l];
|
||||
real_out[l] = real_out[l] > out_max ? out_max : real_out[l];
|
||||
*(output_ptr + out_channel_offset + l) = (int8_t)real_out[l];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// less than 8 channel
|
||||
int channel_8_res = channel_16_res - c8 * C8NUM;
|
||||
int in_c8_offset = in_c16_offset + c8 * C8NUM;
|
||||
int out_c8_offset = out_c16_offset + c8 * C8NUM;
|
||||
for (int k = 0; k < channel_8_res; k++) {
|
||||
int in_channel_offset = in_c8_offset + k;
|
||||
int out_channel_offset = out_c8_offset + k;
|
||||
int16_t tmp_avg = 0;
|
||||
int real_count = 0;
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
||||
(in_w_index + w) >= in_w) {
|
||||
continue;
|
||||
} else {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_avg += *(input_ptr + in_offset);
|
||||
++real_count;
|
||||
}
|
||||
for (int h = kh_s; h < kh_e; h++) {
|
||||
for (int w = kw_s; w < kw_e; w++) {
|
||||
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
|
||||
tmp_avg += input_ptr[in_offset];
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
int16_t tmp_out = round((float)tmp_avg / (float)real_count);
|
||||
int16_t tmp_out = round((float)tmp_avg / (float)real_count + 128) - 128;
|
||||
tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp);
|
||||
int16_t real_out = tmp_out < out_min ? out_min : tmp_out;
|
||||
real_out = real_out > out_max ? out_max : real_out;
|
||||
*(output_ptr + out_channel_offset) = (int8_t)real_out;
|
||||
|
@ -249,6 +324,109 @@ void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete
|
|||
} // out_batch loop
|
||||
}
|
||||
|
||||
void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param,
|
||||
int task_id) {
|
||||
int stride_w = pooling_param->stride_w_;
|
||||
int stride_h = pooling_param->stride_h_;
|
||||
int pad_w = pooling_param->pad_l_;
|
||||
int pad_h = pooling_param->pad_u_;
|
||||
int win_w = pooling_param->window_w_;
|
||||
int win_h = pooling_param->window_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
int in_w = pooling_param->input_w_;
|
||||
int in_h = pooling_param->input_h_;
|
||||
int output_w = pooling_param->output_w_;
|
||||
int output_h = pooling_param->output_h_;
|
||||
int output_batch = pooling_param->output_batch_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
|
||||
int c16 = UP_DIV(channel, 16);
|
||||
// input channel is equal to output channel
|
||||
float input_scale = pooling_param->quant_args_[0][0].scale_;
|
||||
int input_zp = pooling_param->quant_args_[0][0].zp_;
|
||||
float output_scale = pooling_param->quant_args_[1][0].scale_;
|
||||
int output_zp = pooling_param->quant_args_[1][0].zp_;
|
||||
double real_multiplier = input_scale / output_scale;
|
||||
|
||||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
int in_batch_offset = batch * in_h * in_w * channel;
|
||||
int out_batch_offset = batch * output_h * output_w * channel;
|
||||
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
|
||||
int cal_start_index = thread_id * TILE_NUM;
|
||||
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int index = cal_start_index + i;
|
||||
int out_w_index = index % output_w;
|
||||
int out_h_index = index / output_w;
|
||||
int in_w_index = out_w_index * stride_w - pad_w;
|
||||
int in_h_index = out_h_index * stride_h - pad_h;
|
||||
int out_plane_offset = out_batch_offset + index * channel;
|
||||
for (int j = 0; j < c16 - 1; j++) {
|
||||
int in_channel_offset = in_batch_offset + j * 16;
|
||||
int out_channel_offset = out_plane_offset + j * 16;
|
||||
#ifdef ENABLE_NEON
|
||||
int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
|
||||
#else
|
||||
int8_t tmp_max[16];
|
||||
for (int m = 0; m < C16NUM; ++m) {
|
||||
tmp_max[m] = INT8_MIN;
|
||||
}
|
||||
#endif
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
||||
(in_w_index + w) >= in_w) {
|
||||
continue;
|
||||
} else {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
|
||||
#else
|
||||
for (int k = 0; k < C16NUM; ++k) {
|
||||
tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
#ifdef ENABLE_NEON
|
||||
for (int l = 0; l < C16NUM; ++l) {
|
||||
tmp_max[l] = (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp);
|
||||
}
|
||||
vst1q_s8(output_ptr + out_channel_offset, tmp_max);
|
||||
#else
|
||||
for (int l = 0; l < C16NUM; ++l) {
|
||||
*(output_ptr + out_channel_offset + l) =
|
||||
(int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp);
|
||||
}
|
||||
#endif
|
||||
} // in_channel loop
|
||||
|
||||
// res channel
|
||||
int channel_s = (c16 - 1) * 16;
|
||||
for (int k = channel_s; k < channel; k++) {
|
||||
int in_channel_offset = in_batch_offset + k;
|
||||
int out_channel_offset = out_plane_offset + k;
|
||||
int8_t tmp_max = INT8_MIN;
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
||||
(in_w_index + w) >= in_w) {
|
||||
continue;
|
||||
} else {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset));
|
||||
}
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
*(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp);
|
||||
} // channel_res loop
|
||||
} // out_plane loop
|
||||
} // out_batch loop
|
||||
}
|
||||
}
|
||||
|
||||
void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
|
||||
int stride_w = pooling_param->stride_w_;
|
||||
int stride_h = pooling_param->stride_h_;
|
||||
|
@ -264,7 +442,7 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
|
|||
int output_batch = pooling_param->output_batch_;
|
||||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
|
||||
int c16 = UP_DIV(channel, 16);
|
||||
|
||||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
|
@ -286,22 +464,10 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
|
|||
#ifdef ENABLE_NEON
|
||||
int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
|
||||
#else
|
||||
int8_t tmp_max1 = INT8_MIN;
|
||||
int8_t tmp_max2 = INT8_MIN;
|
||||
int8_t tmp_max3 = INT8_MIN;
|
||||
int8_t tmp_max4 = INT8_MIN;
|
||||
int8_t tmp_max5 = INT8_MIN;
|
||||
int8_t tmp_max6 = INT8_MIN;
|
||||
int8_t tmp_max7 = INT8_MIN;
|
||||
int8_t tmp_max8 = INT8_MIN;
|
||||
int8_t tmp_max9 = INT8_MIN;
|
||||
int8_t tmp_max10 = INT8_MIN;
|
||||
int8_t tmp_max11 = INT8_MIN;
|
||||
int8_t tmp_max12 = INT8_MIN;
|
||||
int8_t tmp_max13 = INT8_MIN;
|
||||
int8_t tmp_max14 = INT8_MIN;
|
||||
int8_t tmp_max15 = INT8_MIN;
|
||||
int8_t tmp_max16 = INT8_MIN;
|
||||
int8_t tmp_max[16];
|
||||
for (int m = 0; m < C16NUM; ++m) {
|
||||
tmp_max[m] = INT8_MIN;
|
||||
}
|
||||
#endif
|
||||
for (int h = 0; h < win_h; h++) {
|
||||
for (int w = 0; w < win_w; w++) {
|
||||
|
@ -313,22 +479,9 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
|
|||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
|
||||
#else
|
||||
tmp_max1 = MaxInt8(tmp_max1, *(input_ptr + in_offset));
|
||||
tmp_max2 = MaxInt8(tmp_max2, *(input_ptr + in_offset + 1));
|
||||
tmp_max3 = MaxInt8(tmp_max3, *(input_ptr + in_offset + 2));
|
||||
tmp_max4 = MaxInt8(tmp_max4, *(input_ptr + in_offset + 3));
|
||||
tmp_max5 = MaxInt8(tmp_max5, *(input_ptr + in_offset + 4));
|
||||
tmp_max6 = MaxInt8(tmp_max6, *(input_ptr + in_offset + 5));
|
||||
tmp_max7 = MaxInt8(tmp_max7, *(input_ptr + in_offset + 6));
|
||||
tmp_max8 = MaxInt8(tmp_max8, *(input_ptr + in_offset + 7));
|
||||
tmp_max9 = MaxInt8(tmp_max9, *(input_ptr + in_offset + 8));
|
||||
tmp_max10 = MaxInt8(tmp_max10, *(input_ptr + in_offset + 9));
|
||||
tmp_max11 = MaxInt8(tmp_max11, *(input_ptr + in_offset + 10));
|
||||
tmp_max12 = MaxInt8(tmp_max12, *(input_ptr + in_offset + 11));
|
||||
tmp_max13 = MaxInt8(tmp_max13, *(input_ptr + in_offset + 12));
|
||||
tmp_max14 = MaxInt8(tmp_max14, *(input_ptr + in_offset + 13));
|
||||
tmp_max15 = MaxInt8(tmp_max15, *(input_ptr + in_offset + 14));
|
||||
tmp_max16 = MaxInt8(tmp_max16, *(input_ptr + in_offset + 15));
|
||||
for (int k = 0; k < C16NUM; ++k) {
|
||||
tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
} // win_w loop
|
||||
|
@ -336,24 +489,13 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
|
|||
#ifdef ENABLE_NEON
|
||||
vst1q_s8(output_ptr + out_channel_offset, tmp_max);
|
||||
#else
|
||||
*(output_ptr + out_channel_offset) = tmp_max1;
|
||||
*(output_ptr + out_channel_offset + 1) = tmp_max2;
|
||||
*(output_ptr + out_channel_offset + 2) = tmp_max3;
|
||||
*(output_ptr + out_channel_offset + 3) = tmp_max4;
|
||||
*(output_ptr + out_channel_offset + 4) = tmp_max5;
|
||||
*(output_ptr + out_channel_offset + 5) = tmp_max6;
|
||||
*(output_ptr + out_channel_offset + 6) = tmp_max7;
|
||||
*(output_ptr + out_channel_offset + 7) = tmp_max8;
|
||||
*(output_ptr + out_channel_offset + 8) = tmp_max9;
|
||||
*(output_ptr + out_channel_offset + 9) = tmp_max10;
|
||||
*(output_ptr + out_channel_offset + 10) = tmp_max11;
|
||||
*(output_ptr + out_channel_offset + 11) = tmp_max12;
|
||||
*(output_ptr + out_channel_offset + 12) = tmp_max13;
|
||||
*(output_ptr + out_channel_offset + 13) = tmp_max14;
|
||||
*(output_ptr + out_channel_offset + 14) = tmp_max15;
|
||||
*(output_ptr + out_channel_offset + 15) = tmp_max16;
|
||||
for (int l = 0; l < C16NUM; ++l) {
|
||||
*(output_ptr + out_channel_offset + l) = tmp_max[l];
|
||||
}
|
||||
#endif
|
||||
} // in_channel loop
|
||||
|
||||
// res channel
|
||||
int channel_s = (c16 - 1) * 16;
|
||||
for (int k = channel_s; k < channel; k++) {
|
||||
int in_channel_offset = in_batch_offset + k;
|
||||
|
|
|
@ -32,6 +32,8 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
|
|||
|
||||
void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
|
||||
void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
|
||||
void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -86,6 +86,62 @@ int ResizeBilinearInt8(const int8_t *input_data, int8_t *output_data, const int
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ResizeBilinearInt8WithFloatWeight(const int8_t *input_data, int8_t *output_data, const int *input_shape,
|
||||
const int *output_shape, const bool align_corners, QuantArg *quant_in,
|
||||
QuantArg *quant_out, const QuantMulArg *mul_arg, int tid, int thread_num) {
|
||||
if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
|
||||
int32_t in_n = input_shape[0];
|
||||
int32_t in_h = input_shape[1];
|
||||
int32_t in_w = input_shape[2];
|
||||
int32_t in_c = input_shape[3];
|
||||
|
||||
int32_t new_height = output_shape[1];
|
||||
int32_t new_width = output_shape[2];
|
||||
float height_scale, width_scale;
|
||||
ComputeScaleFloat(in_h, new_height, align_corners, &height_scale);
|
||||
ComputeScaleFloat(in_w, new_width, align_corners, &width_scale);
|
||||
|
||||
int n, h, w, c;
|
||||
for (n = 0; n < in_n; n++) {
|
||||
for (h = tid; h < new_height; h += thread_num) {
|
||||
float actual_y;
|
||||
int bottom, top;
|
||||
float bottom_weight, top_weight;
|
||||
ComputeInterpolationArgsFloatWeight(h, height_scale, in_h, &actual_y, &bottom, &bottom_weight, &top, &top_weight);
|
||||
for (w = 0; w < new_width; w++) {
|
||||
float actual_x;
|
||||
int left, right;
|
||||
float left_weight, right_weight;
|
||||
ComputeInterpolationArgsFloatWeight(w, width_scale, in_w, &actual_x, &left, &left_weight, &right,
|
||||
&right_weight);
|
||||
for (c = 0; c < in_c; c++) {
|
||||
float bottom_left_value = ((int32_t)input_data[offset(input_shape, n, bottom, left, c)] - quant_in->zp_) *
|
||||
bottom_weight * left_weight;
|
||||
float bottom_right_value = ((int32_t)input_data[offset(input_shape, n, bottom, right, c)] - quant_in->zp_) *
|
||||
bottom_weight * right_weight;
|
||||
float top_left_value =
|
||||
((int32_t)input_data[offset(input_shape, n, top, left, c)] - quant_in->zp_) * top_weight * left_weight;
|
||||
float top_right_value =
|
||||
((int32_t)input_data[offset(input_shape, n, top, right, c)] - quant_in->zp_) * top_weight * right_weight;
|
||||
float interp_value = bottom_left_value + bottom_right_value + top_left_value + top_right_value;
|
||||
|
||||
const int out_interp_value = MultiplyByQuantizedMultiplier((int32_t)interp_value, mul_arg->multiplier_,
|
||||
mul_arg->left_shift_, mul_arg->right_shift_) +
|
||||
quant_out->zp_;
|
||||
int8_t out_value;
|
||||
out_value = out_interp_value > INT8_MAX ? INT8_MAX : out_interp_value;
|
||||
out_value = out_value < INT8_MIN ? INT8_MIN : out_value;
|
||||
output_data[offset(output_shape, n, h, w, c)] = out_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int *input_shape,
|
||||
const int *output_shape, const bool align_corners, int tid, int thread_num) {
|
||||
int batch, y, x, c;
|
||||
|
@ -133,6 +189,22 @@ void ComputeInterpolationArgs(const int32_t pos, const int32_t scale, const int3
|
|||
*scaled_high_weight = *scaled_pos - (1 << 10) * (*low);
|
||||
}
|
||||
|
||||
void ComputeScaleFloat(const int32_t in_value, const int32_t out_value, const bool align_corners, float *scale) {
|
||||
*scale = (float)in_value / out_value;
|
||||
if (align_corners && out_value > 1) {
|
||||
*scale = (float)(in_value - 1) / (out_value - 1);
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeInterpolationArgsFloatWeight(const int32_t pos, const float scale, const int32_t size, float *actual_pos,
|
||||
int32_t *low, float *low_weight, int32_t *high, float *high_weight) {
|
||||
*actual_pos = pos * scale;
|
||||
*low = *actual_pos > 0 ? floor(*actual_pos) : 0;
|
||||
*low_weight = 1.0 - (*actual_pos - *low);
|
||||
*high = *low + 1 < size ? *low + 1 : size - 1;
|
||||
*high_weight = *actual_pos - (*low);
|
||||
}
|
||||
|
||||
void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners,
|
||||
int32_t *nearest) {
|
||||
if (new_size == 0) {
|
||||
|
|
|
@ -31,6 +31,20 @@ int ResizeBilinearInt8(const int8_t *input_data, int8_t *output_data, const int
|
|||
const bool align_corners, QuantArg *quant_in, QuantArg *quant_out, const QuantMulArg *mul_arg,
|
||||
int tid, int thread_num);
|
||||
|
||||
int ResizeBilinearInt8WithFloatWeight(const int8_t *input_data, int8_t *output_data, const int *input_shape,
|
||||
const int *output_shape, const bool align_corners, QuantArg *quant_in,
|
||||
QuantArg *quant_out, const QuantMulArg *mul_arg, int tid, int thread_num);
|
||||
|
||||
void ComputeScale(const int32_t in_value, const int32_t out_value, const bool align_corners, int32_t *scale);
|
||||
|
||||
void ComputeInterpolationArgs(const int32_t pos, const int32_t scale, const int32_t size, int32_t *scaled_pos,
|
||||
int32_t *low, int32_t *scaled_low_weight, int32_t *high, int32_t *scaled_high_weight);
|
||||
|
||||
void ComputeScaleFloat(const int32_t in_value, const int32_t out_value, const bool align_corners, float *scale);
|
||||
|
||||
void ComputeInterpolationArgsFloatWeight(const int32_t pos, const float scale, const int32_t size, float *actual_pos,
|
||||
int32_t *low, float *low_weight, int32_t *high, float *high_weight);
|
||||
|
||||
int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int *input_shape,
|
||||
const int *output_shape, const bool align_corners, int tid, int thread_num);
|
||||
|
||||
|
@ -38,11 +52,6 @@ int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, con
|
|||
const int *output_shape, const bool align_corners, const QuantMulArg *multiplier,
|
||||
QuantArg *quant_in, QuantArg *quant_out, int tid, int thread_num);
|
||||
|
||||
void ComputeScale(const int32_t in_value, const int32_t out_value, const bool align_corners, int32_t *scale);
|
||||
|
||||
void ComputeInterpolationArgs(const int32_t pos, const int32_t scale, const int32_t size, int32_t *scaled_pos,
|
||||
int32_t *low, int32_t *scaled_low_weight, int32_t *high, int32_t *scaled_high_weight);
|
||||
|
||||
void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners,
|
||||
int32_t *nearest);
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -58,7 +58,7 @@ int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp
|
|||
int axis_offset = outter_offset + i * inner_size;
|
||||
for (int c = 0; c < inner_size; ++c) {
|
||||
int num_bits_over_unit;
|
||||
int shifted_scale = ComputerReciproal(sum_data[c], 12, &num_bits_over_unit);
|
||||
int shifted_scale = ComputerReciprocal(sum_data[c], 12, &num_bits_over_unit);
|
||||
int unsat_output = RoundingDivideByPOT(
|
||||
SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8);
|
||||
|
||||
|
|
|
@ -22,18 +22,28 @@
|
|||
typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
|
||||
const int *input_sum, const int *bias);
|
||||
|
||||
typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
|
||||
int32_t maxi, bool per_channel);
|
||||
|
||||
typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
|
||||
|
||||
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
|
||||
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType;
|
||||
|
||||
typedef struct MatMulParameter {
|
||||
OpParameter op_parameter_;
|
||||
int row_;
|
||||
int col_;
|
||||
int row_4_;
|
||||
int row_8_;
|
||||
int row_12_;
|
||||
int row_16_;
|
||||
int col_4_;
|
||||
int col_8_;
|
||||
int deep_;
|
||||
int deep_4_;
|
||||
int deep_16_;
|
||||
bool has_bias_;
|
||||
int batch;
|
||||
bool a_transpose_; /* false : row-major */
|
||||
|
|
|
@ -23,8 +23,8 @@
|
|||
|
||||
#define C4NUM 4
|
||||
#define C8NUM 8
|
||||
#define C12NUM 12
|
||||
#define C16NUM 16
|
||||
#define BLOCK 4
|
||||
#define TILE_NUM 8
|
||||
|
||||
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
@ -55,10 +55,17 @@ typedef enum LiteDataType {
|
|||
kDataTypeInt8,
|
||||
} LiteDataType;
|
||||
|
||||
typedef enum DataOrder {
|
||||
RowMajor,
|
||||
ColMajor,
|
||||
} DataOrder;
|
||||
|
||||
typedef struct OpParameter {
|
||||
char name_[100];
|
||||
int type_;
|
||||
int thread_num_;
|
||||
} OpParameter;
|
||||
|
||||
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6, ActType_Prelu } ActType;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_OP_BASE_H_
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdbool.h>
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -27,6 +29,10 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_
|
|||
|
||||
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
|
||||
const int *input_sum, const int *bias);
|
||||
extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
|
||||
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier,
|
||||
int left_shift, int right_shift, int row, int col, int stride);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -36,7 +42,7 @@ void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int
|
|||
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
|
||||
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
|
||||
int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
|
||||
size_t asymmetric, size_t per_channel) {
|
||||
size_t asymmetric, size_t per_channel) {
|
||||
return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
|
||||
act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
|
||||
}
|
||||
|
@ -45,4 +51,12 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i
|
|||
const int *input_sum, const int *bias) {
|
||||
return MatMulOptR4Int8Neon64(a, b, dst, row4, col4, deep16, input_sum, bias);
|
||||
}
|
||||
|
||||
void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
|
||||
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
|
||||
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
|
||||
int32_t maxi, bool per_channel) {
|
||||
return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi,
|
||||
output_zp, multiplier[0], left_shift[0], right_shift[0], row, col, col);
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -62,6 +62,10 @@ void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed
|
|||
} // kernel plane loop
|
||||
}
|
||||
|
||||
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
|
||||
return PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
|
||||
}
|
||||
|
||||
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) {
|
||||
// original weight format : ohwi
|
||||
int kernel_h = conv_param->kernel_h_;
|
||||
|
@ -153,22 +157,24 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p
|
|||
} // kernel plane loop
|
||||
}
|
||||
|
||||
void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param) {
|
||||
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) {
|
||||
/* support nhwc */
|
||||
char *src = (char *)src_ptr;
|
||||
char *dst = (char *)dst_ptr;
|
||||
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
|
||||
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
|
||||
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
if (src_h < 0 || src_h >= conv_param->input_h_) {
|
||||
continue;
|
||||
}
|
||||
const float *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_;
|
||||
float *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_;
|
||||
const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size;
|
||||
char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size;
|
||||
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
|
||||
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
|
||||
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
if (src_w < 0 || src_w >= conv_param->input_w_) {
|
||||
continue;
|
||||
}
|
||||
memcpy(dst_h_ptr + dst_w * conv_param->input_channel_, src_h_ptr + src_w * conv_param->input_channel_,
|
||||
conv_param->input_channel_ * sizeof(float));
|
||||
memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size,
|
||||
src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size);
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
@ -188,6 +194,139 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam
|
|||
return;
|
||||
}
|
||||
|
||||
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
|
||||
/* optimize normal -> same layout */
|
||||
#ifdef ENABLE_ARM64
|
||||
asm volatile(
|
||||
"mov x10, %[src] \n"
|
||||
"mov x11, %[dst] \n"
|
||||
"dup v15.4s, %w[filter_zp] \n"
|
||||
|
||||
"mov x0, #0 \n"
|
||||
"1: \n"
|
||||
"cmp x0, %[row4] \n"
|
||||
"beq 4f \n"
|
||||
"add x0, x0, #4\n"
|
||||
"dup v10.4s, wzr \n"
|
||||
"mov x2, #0 \n"
|
||||
|
||||
"2: \n"
|
||||
"cmp x2, %[col16] \n"
|
||||
"beq 3f \n"
|
||||
"add x2, x2, #16\n"
|
||||
|
||||
"ld1 {v0.16b}, [x10], #16\n"
|
||||
"ld1 {v1.16b}, [x10], #16\n"
|
||||
"ld1 {v2.16b}, [x10], #16\n"
|
||||
"ld1 {v3.16b}, [x10], #16\n"
|
||||
|
||||
"saddlp v4.8h, v0.16b \n"
|
||||
"saddlp v5.8h, v1.16b \n"
|
||||
"saddlp v6.8h, v2.16b \n"
|
||||
"saddlp v7.8h, v3.16b \n"
|
||||
|
||||
"saddlp v0.4S, v4.8h \n"
|
||||
"saddlp v1.4S, v5.8h \n"
|
||||
"saddlp v2.4S, v6.8h \n"
|
||||
"saddlp v3.4S, v7.8h \n"
|
||||
|
||||
"addv s4, v0.4S \n"
|
||||
"addv s5, v1.4S \n"
|
||||
"addv s6, v2.4S \n"
|
||||
"addv s7, v3.4S \n"
|
||||
|
||||
"mov v0.s[0], v4.s[0] \n"
|
||||
"mov v0.s[1], v5.s[0] \n"
|
||||
"mov v0.s[2], v6.s[0] \n"
|
||||
"mov v0.s[3], v7.s[0] \n"
|
||||
|
||||
"add v10.4s, v10.4s, v0.4s \n"
|
||||
"b 2b\n"
|
||||
|
||||
"3: \n"
|
||||
"mul v10.4s, v10.4s, v15.4s \n"
|
||||
"st1 {v10.4s}, [x11], #16 \n"
|
||||
"beq 1b \n"
|
||||
|
||||
"4: \n"
|
||||
|
||||
:
|
||||
: [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp)
|
||||
: "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15");
|
||||
#else
|
||||
for (int r = 0; r < row4; r++) {
|
||||
int32_t tmp_value = 0;
|
||||
for (int c = 0; c < col16; c++) {
|
||||
int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM;
|
||||
int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
|
||||
tmp_value += src[src_index];
|
||||
}
|
||||
dst[r] = tmp_value * filter_zp;
|
||||
}
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
|
||||
size_t plane_size, ConvParameter *conv_param) {
|
||||
size_t hw4 = UP_ROUND(plane_size, C4NUM);
|
||||
size_t ic16 = UP_ROUND(input_channel, C16NUM);
|
||||
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
|
||||
PackInputSum16x4PerLayer(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
|
||||
} else {
|
||||
for (int ri = 0; ri < plane_size; ri++) {
|
||||
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
|
||||
for (int ci = 0; ci < output_channel; ci++) {
|
||||
int32_t tmp_sum_value = 0;
|
||||
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
|
||||
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
|
||||
for (int di = 0; di < input_channel; di++) {
|
||||
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
|
||||
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
|
||||
tmp_sum_value += input_value[src_index];
|
||||
}
|
||||
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
|
||||
input_sum[dst_index] = tmp_sum_value * filter_zp;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
|
||||
size_t plane_size, ConvParameter *conv_param) {
|
||||
size_t hw8 = UP_ROUND(plane_size, C8NUM);
|
||||
size_t ic4 = UP_ROUND(input_channel, C4NUM);
|
||||
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
|
||||
for (int r = 0; r < hw8; r++) {
|
||||
int32_t tmp_value = 0;
|
||||
for (int c = 0; c < ic4; c++) {
|
||||
int r8div = r / C8NUM, r8mod = r % C8NUM, c4div = c / C4NUM, c4mod = c % C4NUM;
|
||||
int src_index = r8div * C8NUM * ic4 + c4div * C8NUM * C4NUM + r8mod * C4NUM + c4mod;
|
||||
tmp_value += input_value[src_index];
|
||||
}
|
||||
input_sum[r] = tmp_value * conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
|
||||
}
|
||||
} else {
|
||||
for (int ri = 0; ri < plane_size; ri++) {
|
||||
int ri8div = ri / C8NUM, ri8mod = ri % C8NUM;
|
||||
for (int ci = 0; ci < output_channel; ci++) {
|
||||
int32_t tmp_sum_value = 0;
|
||||
int ci8div = ci / C8NUM, ci8mod = ci % C8NUM;
|
||||
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
|
||||
for (int di = 0; di < input_channel; di++) {
|
||||
size_t di4div = di / C4NUM, di4mod = di % C4NUM;
|
||||
int src_index = ri8div * C8NUM * ic4 + di4div * C8NUM * C4NUM + ri8mod * C4NUM + di4mod;
|
||||
tmp_sum_value += input_value[src_index];
|
||||
}
|
||||
int dst_index = ci8div * C8NUM * hw8 + ri * C8NUM + ci8mod;
|
||||
input_sum[dst_index] = tmp_sum_value * filter_zp;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num,
|
||||
int block_index) {
|
||||
// input format : nhwc
|
||||
|
@ -195,8 +334,8 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
|
|||
int kernel_w = conv_param->kernel_w_;
|
||||
int stride_h = conv_param->stride_h_;
|
||||
int stride_w = conv_param->stride_w_;
|
||||
int pad_h = conv_param->pad_h_;
|
||||
int pad_w = conv_param->pad_w_;
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int dilation_h = conv_param->dilation_h_;
|
||||
int dilation_w = conv_param->dilation_w_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
|
@ -204,23 +343,21 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
|
|||
int in_w = conv_param->input_w_;
|
||||
int out_w = conv_param->output_w_;
|
||||
int ic4 = UP_DIV(in_channel, C4NUM);
|
||||
memset(packed_input, 0, kernel_h * kernel_w * ic4 * C4NUM * TILE_NUM * sizeof(float));
|
||||
|
||||
for (int i = 0; i < real_cal_num; i++) {
|
||||
int block_start = block_index + i;
|
||||
int input_h = block_start / out_w * stride_h - pad_h;
|
||||
int input_w = block_start % out_w * stride_w - pad_w;
|
||||
for (int j = 0; j < kernel_h; j++) {
|
||||
int input_y = input_h + j * dilation_h;
|
||||
if (input_y < 0 || input_y >= in_h) {
|
||||
continue;
|
||||
}
|
||||
int input_y_stride = input_y * in_w * ic4 * C4NUM;
|
||||
for (int n = 0; n < kernel_w; n++) {
|
||||
int input_x = input_w + n * dilation_w;
|
||||
if (input_x < 0 || input_x >= in_w) {
|
||||
continue;
|
||||
}
|
||||
int input_x_stride = input_y_stride + input_x * ic4 * C4NUM;
|
||||
int input_stride = input_h * in_w * ic4 * C4NUM + input_w * ic4 * C4NUM;
|
||||
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
|
||||
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
|
||||
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
|
||||
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
|
||||
for (int j = kh_s; j < kh_e; j++) {
|
||||
int input_y_stride = j * dilation_h * in_w * ic4 * C4NUM + input_stride;
|
||||
for (int n = kw_s; n < kw_e; n++) {
|
||||
int input_x_stride = input_y_stride + n * dilation_w * ic4 * C4NUM;
|
||||
int input_plane_offset = (j * kernel_w + n) * C8NUM * C4NUM * ic4 + i * C4NUM;
|
||||
for (int m = 0; m < ic4; m++) {
|
||||
int channel_block_stride = input_x_stride + m * C4NUM;
|
||||
|
@ -247,8 +384,8 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
|
|||
int kernel_w = conv_param->kernel_w_;
|
||||
int stride_h = conv_param->stride_h_;
|
||||
int stride_w = conv_param->stride_w_;
|
||||
int pad_h = conv_param->pad_h_;
|
||||
int pad_w = conv_param->pad_w_;
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int dilation_h = conv_param->dilation_h_;
|
||||
int dilation_w = conv_param->dilation_w_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
|
@ -318,8 +455,8 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
|
|||
int kernel_w = conv_param->kernel_w_;
|
||||
int stride_h = conv_param->stride_h_;
|
||||
int stride_w = conv_param->stride_w_;
|
||||
int pad_h = conv_param->pad_h_;
|
||||
int pad_w = conv_param->pad_w_;
|
||||
int pad_h = conv_param->pad_u_;
|
||||
int pad_w = conv_param->pad_l_;
|
||||
int dilation_h = conv_param->dilation_h_;
|
||||
int dilation_w = conv_param->dilation_w_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
|
@ -350,9 +487,7 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
|
|||
for (int m = 0; m < ic4; m++) {
|
||||
int channel_block_stride = input_x_stride + m * C4NUM;
|
||||
int channel_block_offset = input_plane_offset + m * tile_num * C4NUM;
|
||||
for (int k = 0; k < C4NUM; k++) {
|
||||
(packed_input + channel_block_offset)[k] = (input_data + channel_block_stride)[k];
|
||||
}
|
||||
memcpy(packed_input + channel_block_offset, input_data + channel_block_stride, 4);
|
||||
} // channel_block loop
|
||||
} // kernel_w loop
|
||||
} // kernel_h loop
|
||||
|
@ -660,6 +795,8 @@ void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane,
|
|||
}
|
||||
}
|
||||
|
||||
void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel) {}
|
||||
|
||||
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int c4 = UP_DIV(channel, C4NUM);
|
||||
for (int b = 0; b < batch; b++) {
|
||||
|
|
|
@ -35,10 +35,18 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
|
|||
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
|
||||
int32_t *input_sum, ConvParameter *conv_param);
|
||||
|
||||
void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param);
|
||||
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
|
||||
|
||||
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);
|
||||
|
||||
void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);
|
||||
|
||||
void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
|
||||
size_t plane_size, ConvParameter *conv_param);
|
||||
|
||||
void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
|
||||
size_t plane_size, ConvParameter *conv_param);
|
||||
|
||||
void MatrixPack(const float *src, float *dst, int row, int ic4, int stride);
|
||||
|
||||
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
|
||||
|
@ -46,6 +54,8 @@ void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvPara
|
|||
void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight, int oc_block,
|
||||
int oc_block_num);
|
||||
|
||||
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
|
||||
|
||||
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
|
||||
|
||||
void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
|
||||
|
@ -76,6 +86,8 @@ void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane,
|
|||
|
||||
void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel);
|
||||
|
||||
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
|
||||
|
|
|
@ -19,14 +19,16 @@
|
|||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/quantization/quantize.h"
|
||||
|
||||
typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode;
|
||||
|
||||
typedef enum RoundMode { RoundMode_No, RoundMode_Ceil, RoundMode_Floor } RoundMode;
|
||||
|
||||
typedef struct PoolingParameter {
|
||||
OpParameter op_parameter_;
|
||||
PoolMode pool_mode_;
|
||||
RoundMode round_mode_;
|
||||
ActType act_type_;
|
||||
QuantArg **quant_args_;
|
||||
bool global_;
|
||||
bool max_pooling_;
|
||||
bool avg_pooling_;
|
||||
bool round_ceil_;
|
||||
bool round_floor_;
|
||||
int window_w_;
|
||||
int window_h_;
|
||||
int input_w_;
|
||||
|
@ -44,6 +46,8 @@ typedef struct PoolingParameter {
|
|||
int stride_w_;
|
||||
int stride_h_;
|
||||
int thread_num_;
|
||||
bool global_;
|
||||
bool quantize_;
|
||||
} PoolingParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue