sync lite to r0.7

This commit is contained in:
xuanyue 2020-08-28 22:11:31 +08:00
parent b5393e6628
commit 0ce8708dee
716 changed files with 21079 additions and 10992 deletions

View File

@ -16,20 +16,20 @@
@title mindspore_build
SET BASEPATH=%CD%
IF NOT EXIST %BASEPATH%/build (
IF NOT EXIST "%BASEPATH%/build" (
md "build"
)
cd %BASEPATH%/build
cd "%BASEPATH%/build"
set BUILD_PATH=%CD%
IF NOT EXIST %BUILD_PATH%/mindspore (
IF NOT EXIST "%BUILD_PATH%/mindspore" (
md "mindspore"
)
cd %CD%/mindspore
cd "%CD%/mindspore"
IF "%2%" == "lite" (
IF "%1%" == "lite" (
call :gene_gtest
call :run_cmake
IF errorlevel 1 (
@ -47,14 +47,17 @@ IF "%2%" == "lite" (
)
cd %BUILD_PATH%/mindspore
IF "%1%" == "" (
cmake --build . -- -j6
IF "%2%" == "" (
cmake --build . --target package -- -j6
) ELSE (
cmake --build . -- -j%1%
cmake --build . --target package -- -j%2%
)
IF errorlevel 1 (
echo "build fail."
goto run_fail
) ELSE (
cd "%BASEPATH%/output"
rd /s /q _CPack_Packages
)
) ELSE (
cmake -DCMAKE_BUILD_TYPE=Release -DENABLE_CPU=ON -DENABLE_MINDDATA=ON -DUSE_GLOG=ON ^
@ -75,40 +78,40 @@ IF "%2%" == "lite" (
)
)
cd %BASEPATH%
cd "%BASEPATH%"
goto run_eof
:run_cmake
cd %BUILD_PATH%/mindspore
cd "%BUILD_PATH%/mindspore"
cmake -DBUILD_DEVICE=on -DBUILD_CONVERTER=on -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=off ^
-DCMAKE_BUILD_TYPE=Release -DSUPPORT_GPU=off -DBUILD_MINDDATA=off -DOFFLINE_COMPILE=off ^
-G "CodeBlocks - MinGW Makefiles" %BASEPATH%/mindspore/lite
-G "CodeBlocks - MinGW Makefiles" "%BASEPATH%/mindspore/lite"
GOTO:EOF
:gene_gtest
cd %BASEPATH%/third_party
cd "%BASEPATH%/third_party"
IF EXIST googletest rd /s /q googletest
git submodule update --init --recursive googletest
cd %BUILD_PATH%/mindspore
cd "%BUILD_PATH%/mindspore"
GOTO:EOF
:gene_protobuf
SET PROTOC=%BASEPATH%/build/mindspore/_deps/protobuf-src/_build/protoc
SET PROTOC="%BASEPATH%/build/mindspore/_deps/protobuf-src/_build/protoc"
SET PROTO_SRC_DIR=%BASEPATH%/mindspore/lite/tools/converter/parser/caffe
SET PROTO_SRC_DIR="%BASEPATH%/mindspore/lite/tools/converter/parser/caffe"
cd %PROTO_SRC_DIR%
%PROTOC% *.proto --proto_path=%PROTO_SRC_DIR% --cpp_out=%PROTO_SRC_DIR%
SET PROTO_SRC_DIR=%BASEPATH%/mindspore/lite/tools/converter/parser/onnx
SET PROTO_SRC_DIR="%BASEPATH%/mindspore/lite/tools/converter/parser/onnx"
cd %PROTO_SRC_DIR%
%PROTOC% *.proto --proto_path=%PROTO_SRC_DIR% --cpp_out=%PROTO_SRC_DIR%
cd %BUILD_PATH%/mindspore
GOTO:EOF
:gene_flatbuffer
SET FLATC=%BASEPATH%/build/mindspore/_deps/flatbuffers-src/_build/flatc
SET FLAT_DIR=%BASEPATH%/mindspore/lite/schema
SET FLATC="%BASEPATH%/build/mindspore/_deps/flatbuffers-src/_build/flatc"
SET FLAT_DIR="%BASEPATH%/mindspore/lite/schema"
cd %FLAT_DIR%
IF EXIST inner rd /s /q inner
md inner
@ -116,14 +119,14 @@ GOTO:EOF
%FLATC% -c -b *.fbs
%FLATC% -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o %FLAT_DIR%/inner *.fbs
SET FLAT_DIR=%BASEPATH%/mindspore/lite/tools/converter/parser/tflite
SET FLAT_DIR="%BASEPATH%/mindspore/lite/tools/converter/parser/tflite"
cd %FLAT_DIR%
%FLATC% -c -b --reflect-types --gen-mutable --reflect-names --gen-object-api -o %FLAT_DIR% *.fbs
cd %BUILD_PATH%/mindspore
cd "%BUILD_PATH%/mindspore"
GOTO:EOF
:run_fail
cd %BASEPATH%
cd "%BASEPATH%"
set errorlevel=1
:run_eof

View File

@ -393,7 +393,7 @@ build_mindspore()
CMAKE_VERBOSE="--verbose"
fi
cmake --build . --target package ${CMAKE_VERBOSE} -j$THREAD_NUM
echo "success to build mindspore project!"
echo "success building mindspore project!"
}
checkndk() {
@ -618,10 +618,12 @@ build_lite()
if [[ "${COMPILE_RET}" -ne 0 ]]; then
echo "---------------- mindspore lite: build failed ----------------"
exit 1
else
mv ${BASEPATH}/output/tmp/*.tar.gz* ${BASEPATH}/output/
rm -rf ${BASEPATH}/output/tmp/
echo "---------------- mindspore lite: build success ----------------"
exit 0
fi
}

View File

@ -1,12 +1,18 @@
include(CMakePackageConfigHelpers)
set(LIB_DIR ${MAIN_DIR}/lib)
set(INC_DIR ${MAIN_DIR}/include)
set(TURBO_DIR ${MAIN_DIR}/third_party/libjpeg-turbo)
set(OPENCV_DIR ${MAIN_DIR}/third_party/opencv)
set(PROTOBF_DIR ${MAIN_DIR}/third_party/protobuf)
set(FLATBF_DIR ${MAIN_DIR}/third_party/flatbuffers)
set(LIB_DIR ${MAIN_DIR}-${COMPONENT_NAME}/lib)
set(INC_DIR ${MAIN_DIR}-${COMPONENT_NAME}/include)
set(TURBO_DIR ${MAIN_DIR}-${COMPONENT_NAME}/third_party/libjpeg-turbo)
set(OPENCV_DIR ${MAIN_DIR}-${COMPONENT_NAME}/third_party/opencv)
set(PROTOBF_DIR ${MAIN_DIR}-${COMPONENT_NAME}/third_party/protobuf)
set(FLATBF_DIR ${MAIN_DIR}-${COMPONENT_NAME}/third_party/flatbuffers)
set(LIB_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/lib)
set(INC_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/include)
set(TURBO_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/third_party/libjpeg-turbo)
set(OPENCV_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/third_party/opencv)
set(PROTOBF_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/third_party/protobuf)
set(FLATBF_DIR_RUN_X86 ${MAIN_DIR}-${RUN_X86_COMPONENT_NAME}/third_party/flatbuffers)
if (BUILD_MINDDATA)
install(DIRECTORY ${TOP_DIR}/mindspore/ccsrc/minddata/dataset/include/ DESTINATION ${INC_DIR} COMPONENT ${COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ${LIB_DIR} COMPONENT ${COMPONENT_NAME})
@ -41,19 +47,40 @@ elseif (PLATFORM_ARM32)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${INC_DIR} COMPONENT ${COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/lite/schema/ DESTINATION ${INC_DIR}/schema COMPONENT ${COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "inner" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/third_party/flatbuffers/include DESTINATION ${FLATBF_DIR} COMPONENT ${COMPONENT_NAME})
elseif (CMAKE_SYSTEM_NAME MATCHES "Windows")
get_filename_component(CXX_DIR ${CMAKE_CXX_COMPILER} PATH)
file(GLOB LIB_LIST ${CXX_DIR}/libstdc++-6.dll ${CXX_DIR}/libwinpthread-1.dll ${CXX_DIR}/libssp-0.dll ${CXX_DIR}/libgcc_s_seh-1.dll)
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/converter_lite.exe DESTINATION ${TOP_DIR}/build/mindspore/package COMPONENT ${COMPONENT_NAME})
install(FILES ${LIB_LIST} DESTINATION ${TOP_DIR}/build/mindspore/package COMPONENT ${COMPONENT_NAME})
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/libconverter_parser.a DESTINATION ${TOP_DIR}/build/mindspore/package COMPONENT ${PARSER_NAME})
else ()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${LIB_DIR} COMPONENT ${RUN_X86_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${INC_DIR_RUN_X86} COMPONENT ${RUN_X86_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/lite/schema/ DESTINATION ${INC_DIR_RUN_X86}/schema COMPONENT ${RUN_X86_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "inner" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${INC_DIR_RUN_X86}/ir/dtype COMPONENT ${RUN_X86_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/third_party/flatbuffers/include DESTINATION ${FLATBF_DIR_RUN_X86} COMPONENT ${RUN_X86_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${LIB_DIR_RUN_X86} COMPONENT ${RUN_X86_COMPONENT_NAME})
install(FILES ${TOP_DIR}/third_party/protobuf/build/lib/libprotobuf.so.19.0.0 DESTINATION ${PROTOBF_DIR}/lib RENAME libprotobuf.so.19 COMPONENT ${COMPONENT_NAME})
endif ()
set(CPACK_GENERATOR TGZ)
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
set(CPACK_GENERATOR ZIP)
else ()
set(CPACK_GENERATOR TGZ)
endif ()
set(CPACK_ARCHIVE_COMPONENT_INSTALL ON)
if (PLATFORM_ARM64 OR PLATFORM_ARM32)
set(CPACK_COMPONENTS_ALL ${COMPONENT_NAME})
elseif (WIN32)
set(CPACK_COMPONENTS_ALL ${COMPONENT_NAME})
else ()
set(CPACK_COMPONENTS_ALL ${COMPONENT_NAME} ${RUN_X86_COMPONENT_NAME})
endif ()
set(CPACK_PACKAGE_FILE_NAME ${MAIN_DIR})
set(CPACK_PACKAGE_DIRECTORY ${TOP_DIR}/output/tmp)
if (WIN32)
set(CPACK_PACKAGE_DIRECTORY ${TOP_DIR}/output)
else ()
set(CPACK_PACKAGE_DIRECTORY ${TOP_DIR}/output/tmp)
endif()
set(CPACK_PACKAGE_CHECKSUM SHA256)
include(CPack)

View File

@ -5,15 +5,15 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_
message(FATAL_ERROR "GCC vesion ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
endif ()
set(MS_VERSION_MAJOY 0)
set(MS_VERSION_MAJOR 0)
set(MS_VERSION_MINOR 7)
set(MS_VERSION_REVISION 0)
set(DIR_PREFIX mindspore-lite)
set(MS_VERSION ${MS_VERSION_MAJOY}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
set(MAIN_DIR ${DIR_PREFIX}-${MS_VERSION})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
if (SUPPORT_GPU)
set(PROCESS_UNIT gpu)
@ -25,13 +25,16 @@ if (PLATFORM_ARM64)
set(COMPONENT_NAME runtime-arm64-${PROCESS_UNIT})
elseif (PLATFORM_ARM32)
set(COMPONENT_NAME runtime-arm32-${PROCESS_UNIT})
elseif (WIN32)
set(PARSER_NAME libconverter-parser-win-${PROCESS_UNIT})
set(COMPONENT_NAME converter-win-${PROCESS_UNIT})
else ()
set(COMPONENT_NAME convert-ubuntu)
endif()
set(RUN_X86_COMPONENT_NAME runtime-x86-${PROCESS_UNIT})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../..)
string(REPLACE "/mindspore/lite" "" TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(CORE_DIR ${TOP_DIR}/mindspore/core)
set(CCSRC_DIR ${TOP_DIR}/mindspore/ccsrc)
include_directories(${TOP_DIR})
@ -65,20 +68,20 @@ set(CMAKE_VERBOSE_MAKEFILE on)
add_compile_definitions(USE_ANDROID_LOG)
add_compile_definitions(NO_DLIB)
add_compile_options(-fPIC)
if (NOT PLATFORM_ARM64 AND NOT PLATFORM_ARM32)
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DDebug -g")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDebug -g")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=default")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
else ()
## enable for binscope for release
set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes -Wno-deprecated-declarations ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes -Wno-deprecated-declarations ${CMAKE_CXX_FLAGS}")
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DDebug -g")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDebug -g")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=default")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
else ()
## enable for binscope for release
set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes -Wno-deprecated-declarations -Wno-missing-braces ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes -Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}")
if (NOT WIN32)
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,-z,relro,-z,now -Wl,-z,noexecstack ${CMAKE_SHARED_LINKER_FLAGS}")
set(CMAKE_EXE_LINKER_FLAGS "-Wl,-z,relro,-z,now -Wl,-z,noexecstack ${CMAKE_EXE_LINKER_FLAGS}")
string(REPLACE " -g " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
endif ()
endif()
string(REPLACE " -g " " " CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
endif ()
if (BUILD_DEVICE)
@ -110,12 +113,11 @@ if (WIN32)
add_compile_definitions(BUILDING_DLL)
endif ()
set(ANF_SRC
${CMAKE_CURRENT_SOURCE_DIR}/../core/ir/meta_tensor.cc
set(CORE_SRC
${CORE_DIR}/ir/meta_tensor.cc
${CORE_DIR}/gvar/logging_level.cc
${CORE_DIR}/gvar/typeid_manager.cc
${CMAKE_CURRENT_SOURCE_DIR}/../core/base/base.cc
${CMAKE_CURRENT_SOURCE_DIR}/src/common/log_adapter.cc
${CORE_DIR}/base/base.cc
)
if (BUILD_CONVERTER)
if (PLATFORM_ARM64 OR PLATFORM_ARM32)
@ -163,7 +165,6 @@ if (BUILD_DEVICE)
add_compile_definitions(ENABLE_ARM32)
endif ()
if (PLATFORM_ARM64)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
add_compile_definitions(ENABLE_ARM64)
if (ENABLE_FP16)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
@ -207,4 +208,4 @@ if (BUILD_DEVICE)
endif ()
endif ()
include(${TOP_DIR}/cmake/package_lite.cmake)
include(${TOP_DIR}/cmake/package_lite.cmake)

mindspore/lite/README.md Normal file
View File

@ -0,0 +1,56 @@
[View Chinese](./README_CN.md)
## What Is MindSpore Lite
MindSpore Lite is a high-performance, lightweight, open-source inference framework that can be used to meet the needs of AI applications on mobile devices. MindSpore Lite focuses on how to deploy AI technology on devices more effectively. It has been integrated into HMS (Huawei Mobile Services) to provide inference for applications such as image classification, object detection, and OCR. MindSpore Lite will promote the development and enrichment of the AI software/hardware application ecosystem.
<img src="../../docs/MindSpore-Lite-architecture.png" alt="MindSpore Lite Architecture" width="600"/>
For more details please check out our [MindSpore Lite Architecture Guide](https://www.mindspore.cn/lite/docs/en/master/architecture.html).
### MindSpore Lite features
1. Cooperative work with MindSpore training
- Provides training, optimization, and deployment.
- The unified IR realizes the device-cloud AI application integration.
2. Lightweight
- Provides model compression, which can also help to improve performance.
- Provides MindSpore Micro, an ultra-lightweight inference solution that meets deployment requirements in extreme environments such as smart watches and headphones.
3. High-performance
- The built-in high-performance kernel computing library NNACL supports multiple convolution optimization algorithms such as sliding window, Im2Col+GEMM, and Winograd.
- Assembly code to improve performance of kernel operators. Supports CPU, GPU, and NPU.
4. Versatility
- Supports iOS and Android.
- Supports LiteOS.
- Supports mobile devices, smart screens, tablets, and IoT devices.
- Supports third-party models such as TFLite, Caffe, and ONNX.
## MindSpore Lite AI deployment procedure
1. Model selection and personalized training
Select a new model or use an existing model for incremental training with labeled data. When designing a model for mobile devices, it is necessary to consider model size, accuracy, and computation cost.
The MindSpore team provides a series of pre-trained models for image classification and object detection. You can use these pre-trained models in your application.
The pre-trained models provided by MindSpore include [Image Classification](https://download.mindspore.cn/model_zoo/official/lite/) and [Object Detection](https://download.mindspore.cn/model_zoo/official/lite/). More models will be provided in the future.
MindSpore allows you to retrain pre-trained models to perform other tasks; for example, a pre-trained image classification model can be retrained to recognize new image types. See [Retraining](https://www.mindspore.cn/lite/tutorial/zh-CN/master/advanced_use/retraining_of_quantized_network.html).
2. Model converter and optimization
If you use a MindSpore or third-party model, you need to convert it into a MindSpore Lite model with the [MindSpore Lite Model Converter Tool](https://www.mindspore.cn/lite/tutorial/zh-CN/master/use/converter_tool.html). The tool converts TensorFlow Lite, Caffe, and ONNX models into the MindSpore Lite format, and operator fusion and quantization can be applied during the conversion procedure.
MindSpore also provides a code-generation tool that converts models running on IoT devices into C code.
3. Model deployment
This stage mainly realizes model deployment, including model management, deployment, operation and maintenance monitoring, etc.
4. Inference
Load the model and perform inference. [Inference](https://www.mindspore.cn/lite/tutorial/zh-CN/master/use/runtime.html) is the process of running input data through the model to get output.
MindSpore provides a series of pre-trained models that can be deployed on mobile devices; see the [example](#TODO).
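A minimal C++ inference sketch illustrating steps 2-4, using the r0.7 interfaces that appear elsewhere in this commit (`Model::Import`, `LiteSession`, `GetOutputMapByTensor`); `CreateSession` and `CompileGraph` are assumed from the same API, mirroring the Java bindings below:

```cpp
#include "include/context.h"
#include "include/lite_session.h"
#include "include/model.h"

// Sketch only: import a converted model from a buffer and run one inference.
int RunLiteInference(const char *model_buf, size_t size) {
  auto *model = mindspore::lite::Model::Import(model_buf, size);
  if (model == nullptr) return -1;
  mindspore::lite::Context context;
  auto *session = mindspore::session::LiteSession::CreateSession(&context);  // assumed factory
  if (session == nullptr) return -1;
  if (session->CompileGraph(model) != 0) return -1;  // assumed, mirrors Java compileGraph
  // ... copy input data into session->GetInputs() tensors here ...
  if (session->RunGraph() != 0) return -1;
  auto outputs = session->GetOutputMapByTensor();  // new accessor added in this commit
  return outputs.empty() ? -1 : 0;
}
```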

View File

@ -0,0 +1,66 @@

[View English](./README.md)
## About MindSpore Lite
MindSpore Lite is a lightweight, high-performance, device-cloud collaborative AI inference framework from MindSpore, built to serve the growing demand for on-device AI applications. MindSpore Lite focuses on deploying and running AI technology on device-side hardware; it is already widely used in Huawei HMS and smart devices for applications such as image classification, object detection, face recognition, and text recognition. Going forward, MindSpore Lite will work with the MindSpore AI community to enrich the AI software/hardware application ecosystem.
<img src="../../docs/MindSpore-Lite-architecture.png" alt="MindSpore Lite Architecture" width="600"/>
For more details, see the [MindSpore Lite Architecture Guide](https://www.mindspore.cn/lite/docs/zh-CN/master/architecture.html).
## MindSpore Lite features
1. Device-cloud collaboration provides one-stop training and inference
- Provides an end-to-end workflow covering model training, conversion and optimization, deployment, and inference.
- A unified IR integrates device and cloud AI applications.
2. Ultra-lightweight
- Supports model quantization and compression: smaller models run faster.
- Provides MindSpore Micro, an ultra-lightweight inference solution that meets deployment requirements in extreme environments such as smart watches and earphones.
3. High performance
- The built-in high-performance kernel computing library NNACL supports multiple convolution optimization algorithms such as sliding window, Im2Col+GEMM, and Winograd.
- Assembly-level optimization and heterogeneous CPU/GPU/NPU scheduling maximize hardware compute power while minimizing inference latency and power consumption.
4. Broad coverage
- Supports mobile operating systems such as iOS and Android.
- Supports the LiteOS embedded operating system.
- Supports AI applications on all kinds of smart devices, including phones, large screens, tablets, and IoT devices.
- Supports MindSpore, TensorFlow Lite, Caffe, and ONNX models for quick deployment.
## MindSpore Lite AI deployment procedure
1. Model selection and personalized training
Select a new model, or take an existing model and perform incremental training with labeled data. When designing a model for device-side use, consider model size, accuracy, and computation cost.
The MindSpore team provides a series of pre-trained models for scenarios such as image classification and object detection; the corresponding device-side models can be used directly in your application.
The pre-trained models provided by MindSpore include [Image Classification](https://download.mindspore.cn/model_zoo/official/lite/) and [Object Detection](https://download.mindspore.cn/model_zoo/official/lite/). The MindSpore team will add more preset models over time.
MindSpore allows you to retrain pre-trained models to perform other tasks; for example, a pre-trained image classification model can be retrained to recognize new image categories. See [Retraining](https://www.mindspore.cn/lite/tutorial/zh-CN/master/advanced_use/retraining_of_quantized_network.html).
2. Model conversion and optimization
If you use a MindSpore or third-party model, convert it into the MindSpore Lite format with the [MindSpore Lite Model Converter Tool](https://www.mindspore.cn/lite/tutorial/zh-CN/master/use/converter_tool.html). Besides converting TensorFlow Lite, Caffe, and ONNX model formats into the MindSpore Lite format, the tool also provides operator fusion, quantization, and other optimizations.
MindSpore also provides a code-generation tool that converts models running on IoT devices into .C code.
After the two steps above, you have a model that can be deployed on the device side.
3. Model deployment
This stage covers model deployment, including model management, deployment, and operations monitoring.
4. Model inference
This stage performs the inference itself: the model is loaded and all model-related computation is carried out. [Inference](https://www.mindspore.cn/lite/tutorial/zh-CN/master/use/runtime.html) is the process of running input data through the model to obtain predictions.
MindSpore provides [examples](#TODO) of deploying the pre-trained models on smart devices.

View File

@ -28,17 +28,17 @@ namespace mindspore::lite {
class Allocator;
/// \brief CpuBindMode defined for holding bind cpu strategy argument.
enum CpuBindMode {
typedef enum {
MID_CPU = -1, /**< bind middle cpu first */
HIGHER_CPU = 1, /**< bind higher cpu first */
NO_BIND = 0 /**< no bind */
};
} CpuBindMode;
/// \brief DeviceType defined for holding user's preferred backend.
typedef enum {
DT_CPU, /**< CPU device type */
DT_GPU, /**< GPU device type */
DT_NPU /**< NPU device type */
DT_NPU /**< NPU device type, not supported yet */
} DeviceType;
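Since `CpuBindMode` is now a plain C-style enum, its values are referenced unscoped, as the JNI code later in this commit does (`context->cpu_bind_mode_ = MID_CPU;`). A minimal sketch (the `cpu_bind_mode_` member name is taken from that JNI code):

```cpp
mindspore::lite::Context context;
context.cpu_bind_mode_ = MID_CPU;  // bind big cores first; NO_BIND and HIGHER_CPU are the other options
```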
/// \brief DeviceContext defined for holding DeviceType.

View File

@ -86,17 +86,34 @@ class MS_API LiteSession {
/// \return STATUS as an error code of running graph, STATUS is defined in errorcode.h.
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) = 0;
/// \brief Get output MindSpore Lite MSTensors of model.
/// \brief Get output MindSpore Lite MSTensors of model mapped by node name.
///
/// \return The map of output node name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const = 0;
virtual std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputMapByNode() const = 0;
/// \brief Get output MindSpore Lite MSTensors of model by node name.
///
/// \param[in] node_name Define node name.
///
/// \return The vector of MindSpore Lite MSTensor.
virtual std::vector<tensor::MSTensor *> GetOutputsByName(const std::string &node_name) const = 0;
virtual std::vector<tensor::MSTensor *> GetOutputsByNodeName(const std::string &node_name) const = 0;
/// \brief Get output MindSpore Lite MSTensors of model mapped by tensor name.
///
/// \return The map of output tensor name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetOutputMapByTensor() const = 0;
/// \brief Get name of output tensors of model compiled by this session.
///
/// \return The vector of string as output tensor names in order.
virtual std::vector<std::string> GetOutputTensorNames() const = 0;
/// \brief Get output MindSpore Lite MSTensors of model by tensor name.
///
/// \param[in] tensor_name Define tensor name.
///
/// \return Pointer of MindSpore Lite MSTensor.
virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const std::string &tensor_name) const = 0;
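A sketch of how the new name-based getters compose (assuming `session` is an already-compiled `LiteSession *`; `ElementsNum` and `MutableData` appear on `MSTensor` elsewhere in this commit):

```cpp
for (const auto &name : session->GetOutputTensorNames()) {
  mindspore::tensor::MSTensor *out = session->GetOutputByTensorName(name);
  if (out != nullptr) {
    // e.g. read out->ElementsNum() elements starting at out->MutableData()
  }
}
```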
/// \brief Resize inputs shape.
///

View File

@ -24,8 +24,17 @@ namespace lite {
/// \brief Global method to get a version string.
///
/// \return The version string of MindSpore Lite.
#ifndef MS_VERSION_MAJOR
#define MS_VERSION_MAJOR 0
#endif
#ifndef MS_VERSION_MINOR
#define MS_VERSION_MINOR 7
#endif
#ifndef MS_VERSION_REVISION
#define MS_VERSION_REVISION 0
#endif
std::string Version() {
return "MindSpore Lite " + std::to_string(MS_VERSION_MAJOY) + "." + std::to_string(MS_VERSION_MINOR) + "." +
return "MindSpore Lite " + std::to_string(MS_VERSION_MAJOR) + "." + std::to_string(MS_VERSION_MINOR) + "." +
std::to_string(MS_VERSION_REVISION);
}
} // namespace lite
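With the fallback defines above, `Version()` still yields a sensible string when the build system does not pass the version macros; for example (a sketch, assuming the header is `include/version.h`):

```cpp
#include <iostream>
#include "include/version.h"  // assumed header name for Version()

int main() {
  std::cout << mindspore::lite::Version() << std::endl;  // e.g. "MindSpore Lite 0.7.0"
  return 0;
}
```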

View File

@ -15,10 +15,10 @@ fi
# copy arm64 so
cd ${TOP_PATH}/output/
rm -rf mindspore-lite-0.6.0
tar -zxvf mindspore-lite-0.6.0-runtime-arm64-cpu.tar.gz
rm -rf mindspore-lite-0.7.0
tar -zxvf mindspore-lite-0.7.0-runtime-arm64-cpu.tar.gz
mkdir -p ${BASE_PATH}/lib/
cp ${TOP_PATH}/output/mindspore-lite-0.6.0/lib/libmindspore-lite.so ${BASE_PATH}/lib/
cp ${TOP_PATH}/output/mindspore-lite-0.7.0-runtime-arm64-cpu/lib/libmindspore-lite.so ${BASE_PATH}/lib/
cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${BASE_PATH}/lib/
# build jni so

View File

@ -76,8 +76,8 @@ public class LiteSession {
return tensors;
}
public Map<String, List<MSTensor>> getOutputs() {
Map<String, List<Long>> ret = this.getOutputs(this.sessionPtr);
public Map<String, List<MSTensor>> getOutputMapByNode() {
Map<String, List<Long>> ret = this.getOutputMapByNode(this.sessionPtr);
Map<String, List<MSTensor>> tensorMap = new HashMap<>();
Set<Map.Entry<String, List<Long>>> entrySet = ret.entrySet();
for (Map.Entry<String, List<Long>> entry : entrySet) {
@ -93,8 +93,8 @@ public class LiteSession {
return tensorMap;
}
public List<MSTensor> getOutputsByName(String nodeName) {
List<Long> ret = this.getOutputsByName(this.sessionPtr, nodeName);
public List<MSTensor> getOutputsByNodeName(String nodeName) {
List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
MSTensor msTensor = new MSTensor(msTensorAddr);
@ -103,6 +103,27 @@ public class LiteSession {
return tensors;
}
public Map<String, MSTensor> getOutputMapByTensor() {
Map<String, Long> ret = this.getOutputMapByTensor(this.sessionPtr);
Map<String, MSTensor> tensorMap = new HashMap<>();
Set<Map.Entry<String, Long>> entrySet = ret.entrySet();
for (Map.Entry<String, Long> entry : entrySet) {
String name = entry.getKey();
Long msTensorAddr = entry.getValue();
tensorMap.put(name, new MSTensor(msTensorAddr));
}
return tensorMap;
}
public List<String> getOutputTensorNames() {
return getOutputTensorNames(this.sessionPtr);
}
public MSTensor getOutputByTensorName(String tensorName) {
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
return new MSTensor(tensor_addr);
}
public void free() {
this.free(this.sessionPtr);
this.sessionPtr = 0;
@ -120,9 +141,15 @@ public class LiteSession {
private native List<Long> getInputsByName(long sessionPtr, String nodeName);
private native Map<String, List<Long>> getOutputs(long sessionPtr);
private native Map<String, List<Long>> getOutputMapByNode(long sessionPtr);
private native List<Long> getOutputsByName(long sessionPtr, String nodeName);
private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);
private native Map<String, Long> getOutputMapByTensor(long sessionPtr);
private native List<String> getOutputTensorNames(long sessionPtr);
private native Long getOutputByTensorName(long sessionPtr, String tensorName);
private native void free(long sessionPtr);
}

View File

@ -16,6 +16,10 @@
package com.mindspore.lite;
import android.util.Log;
import java.nio.ByteBuffer;
public class MSTensor {
private long tensorPtr;
@ -27,7 +31,7 @@ public class MSTensor {
this.tensorPtr = tensorPtr;
}
public boolean init (int dataType, int[] shape) {
public boolean init(int dataType, int[] shape) {
this.tensorPtr = createMSTensor(dataType, shape, shape.length);
return this.tensorPtr != 0;
}
@ -48,14 +52,30 @@ public class MSTensor {
this.setDataType(this.tensorPtr, dataType);
}
public byte[] getData() {
return this.getData(this.tensorPtr);
public byte[] getByteData() {
return this.getByteData(this.tensorPtr);
}
public float[] getFloatData() {
return this.getFloatData(this.tensorPtr);
}
public int[] getIntData() {
return this.getIntData(this.tensorPtr);
}
public long[] getLongData() {
return this.getLongData(this.tensorPtr);
}
public void setData(byte[] data) {
this.setData(this.tensorPtr, data, data.length);
}
public void setData(ByteBuffer data) {
this.setByteBufferData(this.tensorPtr, data);
}
public long size() {
return this.size(this.tensorPtr);
}
@ -69,6 +89,24 @@ public class MSTensor {
this.tensorPtr = 0;
}
private float[] decodeBytes(byte[] bytes) {
if (bytes.length % 4 != 0) {
Log.e("MS_LITE", "Length of bytes should be a multiple of 4");
return null;
}
int size = bytes.length / 4;
float[] ret = new float[size];
// Assemble each group of 4 little-endian bytes into one float.
for (int i = 0; i < bytes.length; i += 4) {
int accNum = 0;
accNum = accNum | (bytes[i] & 0xff) << 0;
accNum = accNum | (bytes[i + 1] & 0xff) << 8;
accNum = accNum | (bytes[i + 2] & 0xff) << 16;
accNum = accNum | (bytes[i + 3] & 0xff) << 24;
ret[i / 4] = Float.intBitsToFloat(accNum);
}
return ret;
}
private native long createMSTensor(int dataType, int[] shape, int shapeLen);
private native int[] getShape(long tensorPtr);
@ -79,10 +117,18 @@ public class MSTensor {
private native boolean setDataType(long tensorPtr, int dataType);
private native byte[] getData(long tensorPtr);
private native byte[] getByteData(long tensorPtr);
private native long[] getLongData(long tensorPtr);
private native int[] getIntData(long tensorPtr);
private native float[] getFloatData(long tensorPtr);
private native boolean setData(long tensorPtr, byte[] data, long dataLen);
private native boolean setByteBufferData(long tensorPtr, ByteBuffer buffer);
private native long size(long tensorPtr);
private native int elementsNum(long tensorPtr);

View File

@ -80,6 +80,11 @@ public class Model {
return ret;
}
public boolean loadModel(String modelPath) {
this.modelPtr = loadModelByPath(modelPath);
return this.modelPtr != 0;
}
public void free() {
this.free(this.modelPtr);
this.modelPtr = 0;
@ -87,5 +92,7 @@ public class Model {
private native long loadModel(MappedByteBuffer buffer);
private native long loadModelByPath(String modelPath);
private native void free(long modelPtr);
}

View File

@ -1,11 +1,11 @@
cmake_minimum_required(VERSION 3.14)
project (Lite-java)
set(MS_VERSION_MAJOY 0)
set(MS_VERSION_MAJOR 0)
set(MS_VERSION_MINOR 7)
set(MS_VERSION_REVISION 0)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOY=${MS_VERSION_MAJOY} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../lite/)

View File

@ -14,12 +14,11 @@
* limitations under the License.
*/
#include "common/jni_utils.h"
#include <cstring>
char *JstringToChar(JNIEnv *env, jstring jstr) {
char *rtn = NULL;
char *rtn = nullptr;
jclass clsstring = env->FindClass("java/lang/String");
jstring strencode = env->NewStringUTF("GB2312");
jmethodID mid = env->GetMethodID(clsstring, "getBytes", "(Ljava/lang/String;)[B");

View File

@ -18,6 +18,7 @@
#include <jni.h>
#include "common/ms_log.h"
#include "include/context.h"
#include "include/thread_pool_config.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_createContext(JNIEnv *env, jobject thiz,
jint device_type,
@ -44,13 +45,13 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_context_Context_creat
}
switch (cpu_bind_mode) {
case -1:
context->cpu_bind_mode_ = mindspore::lite::MID_CPU;
context->cpu_bind_mode_ = MID_CPU;
break;
case 0:
context->cpu_bind_mode_ = mindspore::lite::NO_BIND;
context->cpu_bind_mode_ = NO_BIND;
break;
case 1:
context->cpu_bind_mode_ = mindspore::lite::HIGHER_CPU;
context->cpu_bind_mode_ = HIGHER_CPU;
break;
default:
MS_LOGE("Invalid cpu_bind_mode : %d", cpu_bind_mode);

View File

@ -14,7 +14,6 @@
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "common/jni_utils.h"
@ -22,7 +21,7 @@
#include "include/errorcode.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSession(JNIEnv *env, jobject thiz,
jlong context_ptr) {
jlong context_ptr) {
auto *pointer = reinterpret_cast<void *>(context_ptr);
if (pointer == nullptr) {
MS_LOGE("Context pointer from java is nullptr");
@ -38,8 +37,8 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_createSes
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compileGraph(JNIEnv *env, jobject thiz,
jlong session_ptr,
jlong model_ptr) {
jlong session_ptr,
jlong model_ptr) {
auto *session_pointer = reinterpret_cast<void *>(session_ptr);
if (session_pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
@ -58,7 +57,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_compil
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread(JNIEnv *env, jobject thiz,
jlong session_ptr, jboolean if_bind) {
jlong session_ptr, jboolean if_bind) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
@ -69,7 +68,7 @@ extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_bindThread
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGraph(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jlong session_ptr) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
@ -81,7 +80,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_LiteSession_runGra
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputs(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jlong session_ptr) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
@ -104,8 +103,8 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInputsByName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
jlong session_ptr,
jstring node_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
@ -127,8 +126,8 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getInpu
return ret;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputs(JNIEnv *env, jobject thiz,
jlong session_ptr) {
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByNode(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
@ -140,7 +139,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputs();
auto outputs = lite_session_ptr->GetOutputMapByNode();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
jclass array_list = env->FindClass("java/util/ArrayList");
@ -159,9 +158,9 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return hash_map;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputsByNodeName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring node_name) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
@ -175,7 +174,7 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto inputs = lite_session_ptr->GetOutputsByName(JstringToChar(env, node_name));
auto inputs = lite_session_ptr->GetOutputsByNodeName(JstringToChar(env, node_name));
for (auto input : inputs) {
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input));
env->CallBooleanMethod(ret, array_list_add, tensor_addr);
@ -183,8 +182,66 @@ extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutp
return ret;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputMapByTensor(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass hash_map_clazz = env->FindClass("java/util/HashMap");
jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "<init>", "()V");
jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct);
jmethodID hash_map_put =
env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return hash_map;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto outputs = lite_session_ptr->GetOutputMapByTensor();
jclass long_object = env->FindClass("java/lang/Long");
jmethodID long_object_construct = env->GetMethodID(long_object, "<init>", "(J)V");
for (auto output_iter : outputs) {
auto node_name = output_iter.first;
auto ms_tensor = output_iter.second;
jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor));
env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr);
}
return hash_map;
}
extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_LiteSession_getOutputTensorNames(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jclass array_list = env->FindClass("java/util/ArrayList");
jmethodID array_list_construct = env->GetMethodID(array_list, "<init>", "()V");
jobject ret = env->NewObject(array_list, array_list_construct);
jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z");
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return ret;
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto output_names = lite_session_ptr->GetOutputTensorNames();
for (auto output_name : output_names) {
env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str()));
}
return ret;
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_LiteSession_getOutputByTensorName(JNIEnv *env, jobject thiz,
jlong session_ptr,
jstring tensor_name) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");
return jlong(nullptr);
}
auto *lite_session_ptr = static_cast<mindspore::session::LiteSession *>(pointer);
auto output = lite_session_ptr->GetOutputByTensorName(JstringToChar(env, tensor_name));
return jlong(output);
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_LiteSession_free(JNIEnv *env, jobject thiz,
jlong session_ptr) {
jlong session_ptr) {
auto *pointer = reinterpret_cast<void *>(session_ptr);
if (pointer == nullptr) {
MS_LOGE("Session pointer from java is nullptr");

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include <jni.h>
#include <fstream>
#include "common/ms_log.h"
#include "common/jni_utils.h"
#include "include/model.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEnv *env, jobject thiz, jobject buffer) {
@ -38,6 +39,46 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModel(JNIEn
return reinterpret_cast<jlong>(model);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_Model_loadModelByPath(JNIEnv *env, jobject thiz,
jstring model_path) {
auto model_path_char = JstringToChar(env, model_path);
if (nullptr == model_path_char) {
MS_LOGE("model_path_char is nullptr");
return reinterpret_cast<jlong>(nullptr);
}
std::ifstream ifs(model_path_char);
if (!ifs.good()) {
MS_LOGE("file: %s is not exist", model_path_char);
return reinterpret_cast<jlong>(nullptr);
}
if (!ifs.is_open()) {
MS_LOGE("file: %s open failed", model_path_char);
return reinterpret_cast<jlong>(nullptr);
}
ifs.seekg(0, std::ios::end);
auto size = ifs.tellg();
std::unique_ptr<char[]> buf(new (std::nothrow) char[size]);
if (buf == nullptr) {
MS_LOGE("malloc buf failed, file: %s", model_path_char);
ifs.close();
return reinterpret_cast<jlong>(nullptr);
}
ifs.seekg(0, std::ios::beg);
ifs.read(buf.get(), size);
ifs.close();
delete[](model_path_char);
MS_LOGD("Start Loading model");
auto model = mindspore::lite::Model::Import(buf.get(), size);
if (model == nullptr) {
MS_LOGE("Import model failed");
return reinterpret_cast<jlong>(nullptr);
}
return reinterpret_cast<jlong>(model);
}
extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_Model_free(JNIEnv *env, jobject thiz, jlong model_ptr) {
auto *pointer = reinterpret_cast<void *>(model_ptr);
if (pointer == nullptr) {

View File

@ -14,15 +14,14 @@
* limitations under the License.
*/
#include <jni.h>
#include "common/ms_log.h"
#include "include/ms_tensor.h"
#include "ir/dtype/type_id.h"
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createMSTensor(JNIEnv *env, jobject thiz,
jint data_type, jintArray shape,
jint shape_len) {
jint data_type, jintArray shape,
jint shape_len) {
jboolean is_copy = false;
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
std::vector<int> local_shape(shape_len);
@ -39,7 +38,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_createMSTens
}
extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
@ -59,8 +58,8 @@ extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getShape
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setShape(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jintArray shape,
jint shape_len) {
jlong tensor_ptr, jintArray shape,
jint shape_len) {
jboolean is_copy = false;
jint *local_shape_arr = env->GetIntArrayElements(shape, &is_copy);
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
@ -78,7 +77,7 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setShape(
}
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
@ -89,7 +88,7 @@ extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_getDataType(J
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataType(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jint data_type) {
jlong tensor_ptr, jint data_type) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
@ -100,8 +99,8 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setDataTy
return ret == data_type;
}
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getByteData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
@ -113,15 +112,99 @@ extern "C" JNIEXPORT jbyteArray JNICALL Java_com_mindspore_lite_MSTensor_getData
MS_LOGD("Tensor has no data");
return env->NewByteArray(0);
}
auto local_data_size = ms_tensor_ptr->Size();
auto ret = env->NewByteArray(local_data_size);
env->SetByteArrayRegion(ret, 0, local_data_size, local_data);
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeUInt8) {
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewByteArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewByteArray(local_element_num);
env->SetByteArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
extern "C" JNIEXPORT jlongArray JNICALL Java_com_mindspore_lite_MSTensor_getLongData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewLongArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto *local_data = static_cast<jlong *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewLongArray(0);
}
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeInt64) {
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewLongArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewLongArray(local_element_num);
env->SetLongArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
extern "C" JNIEXPORT jintArray JNICALL Java_com_mindspore_lite_MSTensor_getIntData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewIntArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto *local_data = static_cast<jint *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewIntArray(0);
}
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeInt32) {
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewIntArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewIntArray(local_element_num);
env->SetIntArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
extern "C" JNIEXPORT jfloatArray JNICALL Java_com_mindspore_lite_MSTensor_getFloatData(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return env->NewFloatArray(0);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
auto *local_data = static_cast<jfloat *>(ms_tensor_ptr->MutableData());
if (local_data == nullptr) {
MS_LOGD("Tensor has no data");
return env->NewFloatArray(0);
}
if (ms_tensor_ptr->data_type() != mindspore::kNumberTypeFloat32) {
MS_LOGE("data type is error : %d", ms_tensor_ptr->data_type());
return env->NewFloatArray(0);
}
auto local_element_num = ms_tensor_ptr->ElementsNum();
auto ret = env->NewFloatArray(local_element_num);
env->SetFloatArrayRegion(ret, 0, local_element_num, local_data);
return ret;
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(JNIEnv *env, jobject thiz,
jlong tensor_ptr, jbyteArray data,
jlong data_len) {
jlong tensor_ptr, jbyteArray data,
jlong data_len) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
@ -139,6 +222,36 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setData(J
return static_cast<jboolean>(true);
}
extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_MSTensor_setByteBufferData(JNIEnv *env, jobject thiz,
jlong tensor_ptr,
jobject buffer) {
jbyte *p_data = reinterpret_cast<jbyte *>(env->GetDirectBufferAddress(buffer));  // get buffer pointer
jlong data_len = env->GetDirectBufferCapacity(buffer);  // get buffer capacity
if (!p_data) {
MS_LOGE("GetDirectBufferAddress returned null");
return static_cast<jboolean>(false);
}
jbyteArray data = env->NewByteArray(data_len); // create byte[]
env->SetByteArrayRegion(data, 0, data_len, p_data); // copy data to byte[]
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");
return static_cast<jboolean>(false);
}
auto *ms_tensor_ptr = static_cast<mindspore::tensor::MSTensor *>(pointer);
if (data_len != ms_tensor_ptr->Size()) {
MS_LOGE("data_len(%ld) not equal to Size of ms_tensor(%zu)", data_len, ms_tensor_ptr->Size());
return static_cast<jboolean>(false);
}
jboolean is_copy = false;
auto *data_arr = env->GetByteArrayElements(data, &is_copy);
auto *local_data = ms_tensor_ptr->MutableData();
memcpy(local_data, data_arr, data_len);
return static_cast<jboolean>(true);
}
extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_size(JNIEnv *env, jobject thiz, jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
@ -150,7 +263,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_MSTensor_size(JNIEnv
}
extern "C" JNIEXPORT jint JNICALL Java_com_mindspore_lite_MSTensor_elementsNum(JNIEnv *env, jobject thiz,
jlong tensor_ptr) {
jlong tensor_ptr) {
auto *pointer = reinterpret_cast<void *>(tensor_ptr);
if (pointer == nullptr) {
MS_LOGE("Tensor pointer from java is nullptr");

View File

@ -32,9 +32,11 @@ if (PLATFORM_ARM64)
)
set_target_properties(optimize PROPERTIES CLEAN_DIRECT_OUTPUT 1)
add_custom_command(TARGET optimize POST_BUILD
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
add_custom_command(TARGET optimize POST_BUILD
COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip
${TOP_DIR}/build/nnacl/liboptimize.so)
endif ()
add_custom_command(TARGET optimize POST_BUILD
COMMAND rm -rf ${TOP_DIR}/output/lib/liboptimize.so

View File

@ -51,6 +51,8 @@ void TileOneDimension(float *inData, float *outData, int dim, size_t ndim, int *
int *outStrides, int *multiple);
void ComputeStrides(int *shape, int *strides, int ndim);
void CalcMultiplesAndStrides(ArithmeticParameter *param);
void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param);
void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1,
ArithmeticParameter *param);

View File

@ -29,7 +29,7 @@ mov x6, x1
mov x7, x2
mov x8, x4
LoopInputDepth16In:
LoopDepth16In:
cmp x8, #16
blt L4
sub x8, x8, #16
@ -39,8 +39,8 @@ mov x8, x4
ld1 {v16.4s, v17.4s}, [x0], #32
cmp x8, #16
blt LoopInputDepth16Out
LoopInputDepth16:
blt LoopDepth16Out
LoopDepth16:
fmla v16.4s, v0.4s, v2.4s
fmla v17.4s, v1.4s, v3.4s
@ -61,9 +61,9 @@ mov x8, x4
sub x8, x8, #16
cmp x8, #16
bge LoopInputDepth16
bge LoopDepth16
LoopInputDepth16Out:
LoopDepth16Out:
fmla v16.4s, v0.4s, v2.4s
fmla v17.4s, v1.4s, v3.4s
st1 {v16.4s, v17.4s}, [x9], #32
@ -81,7 +81,7 @@ mov x8, x4
cmp x8, #4
blt L0
LoopInputDepth4:
LoopDepth4:
ld1 {v0.4s}, [x6], #16
ld1 {v2.4s}, [x7], #16
ld1 {v16.4s}, [x0], #16
@ -89,13 +89,13 @@ mov x8, x4
st1 {v16.4s}, [x9], #16
sub x8, x8, #4
cmp x8, #4
bge LoopInputDepth4
bge LoopDepth4
L0:
cmp x8, #0
beq Loop16LineEnd
LoopInputDepth0:
LoopDepth0:
ldr s0, [x6], #4
ldr s1, [x7], #4
ldr s2, [x0], #4
@ -103,7 +103,7 @@ mov x8, x4
fadd s2, s2, s0
str s2, [x9], #4
subs x8, x8, #1
bne LoopInputDepth0
bne LoopDepth0
Loop16LineEnd:

View File

@ -90,36 +90,36 @@ ConvDwInt8Center:
LoopKw16:
mov x22, x21
ld1 {v25.4h}, [x17], #8
ld1 {v16.4h}, [x22], x13
ld1 {v17.4h}, [x22], x13
ld1 {v16.4h}, [x22], x11
ld1 {v17.4h}, [x22], x11
smlal v0.4s, v16.4h, v25.4h
smlal v1.4s, v17.4h, v25.4h
ld1 {v18.4h}, [x22], x13
ld1 {v19.4h}, [x22], x13
ld1 {v18.4h}, [x22], x11
ld1 {v19.4h}, [x22], x11
smlal v2.4s, v18.4h, v25.4h
smlal v3.4s, v19.4h, v25.4h
ld1 {v20.4h}, [x22], x13
ld1 {v21.4h}, [x22], x13
ld1 {v20.4h}, [x22], x11
ld1 {v21.4h}, [x22], x11
smlal v4.4s, v20.4h, v25.4h
smlal v5.4s, v21.4h, v25.4h
ld1 {v22.4h}, [x22], x13
ld1 {v23.4h}, [x22], x13
ld1 {v22.4h}, [x22], x11
ld1 {v23.4h}, [x22], x11
smlal v6.4s, v22.4h, v25.4h
smlal v7.4s, v23.4h, v25.4h
ld1 {v16.4h}, [x22], x13
ld1 {v17.4h}, [x22], x13
ld1 {v16.4h}, [x22], x11
ld1 {v17.4h}, [x22], x11
smlal v8.4s, v16.4h, v25.4h
smlal v9.4s, v17.4h, v25.4h
ld1 {v18.4h}, [x22], x13
ld1 {v19.4h}, [x22], x13
ld1 {v18.4h}, [x22], x11
ld1 {v19.4h}, [x22], x11
smlal v10.4s, v18.4h, v25.4h
smlal v11.4s, v19.4h, v25.4h
ld1 {v20.4h}, [x22], x13
ld1 {v21.4h}, [x22], x13
ld1 {v20.4h}, [x22], x11
ld1 {v21.4h}, [x22], x11
smlal v12.4s, v20.4h, v25.4h
smlal v13.4s, v21.4h, v25.4h
ld1 {v22.4h}, [x22], x13
ld1 {v23.4h}, [x22], x13
ld1 {v22.4h}, [x22], x11
ld1 {v23.4h}, [x22], x11
smlal v14.4s, v22.4h, v25.4h
smlal v15.4s, v23.4h, v25.4h
subs x18, x18, #1
@ -420,20 +420,20 @@ ConvDwInt8Center:
LoopKw8:
mov x22, x21
ld1 {v25.4h}, [x17], #8
ld1 {v16.4h}, [x22], x13
ld1 {v17.4h}, [x22], x13
ld1 {v16.4h}, [x22], x11
ld1 {v17.4h}, [x22], x11
smlal v0.4s, v16.4h, v25.4h
smlal v1.4s, v17.4h, v25.4h
ld1 {v18.4h}, [x22], x13
ld1 {v19.4h}, [x22], x13
ld1 {v18.4h}, [x22], x11
ld1 {v19.4h}, [x22], x11
smlal v2.4s, v18.4h, v25.4h
smlal v3.4s, v19.4h, v25.4h
ld1 {v20.4h}, [x22], x13
ld1 {v21.4h}, [x22], x13
ld1 {v20.4h}, [x22], x11
ld1 {v21.4h}, [x22], x11
smlal v4.4s, v20.4h, v25.4h
smlal v5.4s, v21.4h, v25.4h
ld1 {v22.4h}, [x22], x13
ld1 {v23.4h}, [x22], x13
ld1 {v22.4h}, [x22], x11
ld1 {v23.4h}, [x22], x11
smlal v6.4s, v22.4h, v25.4h
smlal v7.4s, v23.4h, v25.4h
subs x18, x18, #1

View File

@ -0,0 +1,169 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDwInt8PostAlign4
#ifndef __APPLE__
.type ConvDwInt8PostAlign4, %function
#endif
// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier,
// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max
ConvDwInt8PostAlign4:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved,
// but our coding style does not permit that many parameters
ldr x8, [sp]
dup v26.4s, w5
dup v27.4s, w4
dup v28.4s, w6
dup v29.4s, w3
dup v30.4s, w7
dup v31.4s, w8
cmp x2, 16
blt LoopDepth8
LoopDepth16:
ld1 {v0.4s}, [x1], #16
ld1 {v1.4s}, [x1], #16
ld1 {v2.4s}, [x1], #16
ld1 {v3.4s}, [x1], #16
sqshl v0.4s, v0.4s, v26.4s
sqshl v1.4s, v1.4s, v26.4s
sqshl v2.4s, v2.4s, v26.4s
sqshl v3.4s, v3.4s, v26.4s
sqrdmulh v0.4s, v0.4s, v27.4s
sqrdmulh v1.4s, v1.4s, v27.4s
sqrdmulh v2.4s, v2.4s, v27.4s
sqrdmulh v3.4s, v3.4s, v27.4s
and v16.16b, v28.16b, v0.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v28.4s
and v17.16b, v28.16b, v1.16b
sshr v17.4s, v17.4s, #31
sqadd v1.4s, v1.4s, v17.4s
srshl v1.4s, v1.4s, v28.4s
and v18.16b, v28.16b, v2.16b
sshr v18.4s, v18.4s, #31
sqadd v2.4s, v2.4s, v18.4s
srshl v2.4s, v2.4s, v28.4s
and v19.16b, v28.16b, v3.16b
sshr v19.4s, v19.4s, #31
sqadd v3.4s, v3.4s, v19.4s
srshl v3.4s, v3.4s, v28.4s
add v0.4s, v0.4s, v29.4s
add v1.4s, v1.4s, v29.4s
add v2.4s, v2.4s, v29.4s
add v3.4s, v3.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smax v1.4s, v1.4s, v30.4s
smax v2.4s, v2.4s, v30.4s
smax v3.4s, v3.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
smin v1.4s, v1.4s, v31.4s
smin v2.4s, v2.4s, v31.4s
smin v3.4s, v3.4s, v31.4s
sqxtn v0.4h, v0.4s
sqxtn v1.4h, v1.4s
sqxtn v2.4h, v2.4s
sqxtn v3.4h, v3.4s
sqxtn v0.8b, v0.8h
sqxtn v1.8b, v1.8h
sqxtn v2.8b, v2.8h
sqxtn v3.8b, v3.8h
st1 {v0.s}[0], [x0], #4
st1 {v1.s}[0], [x0], #4
st1 {v2.s}[0], [x0], #4
st1 {v3.s}[0], [x0], #4
sub x2, x2, #16
cmp x2, #16
bge LoopDepth16
LoopDepth8:
cmp x2, #8
blt LoopDepth4
ld1 {v0.4s}, [x1], #16
ld1 {v1.4s}, [x1], #16
sqshl v0.4s, v0.4s, v26.4s
sqshl v1.4s, v1.4s, v26.4s
sqrdmulh v0.4s, v0.4s, v27.4s
sqrdmulh v1.4s, v1.4s, v27.4s
and v16.16b, v28.16b, v0.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v28.4s
and v17.16b, v28.16b, v1.16b
sshr v17.4s, v17.4s, #31
sqadd v1.4s, v1.4s, v17.4s
srshl v1.4s, v1.4s, v28.4s
add v0.4s, v0.4s, v29.4s
add v1.4s, v1.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smax v1.4s, v1.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
smin v1.4s, v1.4s, v31.4s
sqxtn v0.4h, v0.4s
sqxtn v1.4h, v1.4s
sqxtn v0.8b, v0.8h
sqxtn v1.8b, v1.8h
st1 {v0.s}[0], [x0], #4
st1 {v1.s}[0], [x0], #4
sub x2, x2, #8
cmp x2, #8
bge LoopDepth8
LoopDepth4:
cmp x2, #4
blt End
ld1 {v0.4s}, [x1], #16
sqshl v0.4s, v0.4s, v26.4s
sqrdmulh v0.4s, v0.4s, v27.4s
and v16.16b, v28.16b, v0.16b
sshr v16.4s, v16.4s, #31
sqadd v0.4s, v0.4s, v16.4s
srshl v0.4s, v0.4s, v28.4s
add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
sqxtn v0.4h, v0.4s
sqxtn v0.8b, v0.8h
st1 {v0.s}[0], [x0], #4
sub x2, x2, #4
bge LoopDepth4
End:
ret
#endif
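For readability, an approximate scalar equivalent of the per-lane post-processing in `ConvDwInt8PostAlign4` (saturating left shift, rounding-doubling high multiply, rounding right shift, zero-point add, clamp, narrow); a sketch, not bit-exact with the saturating NEON ops:

```cpp
#include <cstdint>

// right_shift is passed to srshl as a negative shift in the asm; here it is
// taken as a non-negative amount for clarity.
static inline int8_t PostAlignOne(int32_t acc, int32_t out_multiplier, int32_t left_shift,
                                  int32_t right_shift, int32_t output_zp,
                                  int32_t acc_min, int32_t acc_max) {
  int64_t v = (int64_t)acc << left_shift;                             // sqshl
  int32_t x = (int32_t)((v * out_multiplier + (1LL << 30)) >> 31);    // sqrdmulh
  if (right_shift > 0) {
    x = (x + (1 << (right_shift - 1))) >> right_shift;                // srshl (rounding right shift)
  }
  x += output_zp;                                                     // add v29
  x = x < acc_min ? acc_min : (x > acc_max ? acc_max : x);            // smax/smin clamp
  return (int8_t)x;                                                   // sqxtn + sqxtn narrowing
}
```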

View File

@ -0,0 +1,122 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDwInt8Row
#ifndef __APPLE__
.type ConvDwInt8Row, %function
#endif
// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
// int output_channel, int input_step, int8_t input_zp)
// x0: output_ptr, x1: input_ptr, x2: weight_ptr, x3: num_pixels,
// x4: output_channel, x5: input_step, x6: input_zp
//
ConvDwInt8Row:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved,
// but our coding style does not permit that many parameters
cmp x3, #0
beq End
mov x10, x0
dup v31.8b, w6
LoopOutPixel:
mov x7, x1
mov x8, x2
mov x9, x4
LoopDepth16In:
cmp x9, #16
blt L8
sub x9, x9, #16
ld1 {v0.8b, v1.8b}, [x7], #16
ld1 {v2.8h, v3.8h}, [x8], #32
ld1 {v16.4s, v17.4s}, [x0], #32
ssubl v20.8h, v0.8b, v31.8b
smlal v16.4s, v20.4h, v2.4h
smlal2 v17.4s, v20.8h, v2.8h
cmp x9, #16
blt LoopDepth16Out
LoopDepth16:
st1 {v16.4s, v17.4s}, [x10], #32
ld1 {v18.4s, v19.4s}, [x0], #32
ssubl v21.8h, v1.8b, v31.8b
smlal v18.4s, v21.4h, v3.4h
smlal2 v19.4s, v21.8h, v3.8h
st1 {v18.4s, v19.4s}, [x10], #32
ld1 {v0.8b, v1.8b}, [x7], #16
ld1 {v2.8h, v3.8h}, [x8], #32
ld1 {v16.4s, v17.4s}, [x0], #32
ssubl v20.8h, v0.8b, v31.8b
smlal v16.4s, v20.4h, v2.4h
smlal2 v17.4s, v20.8h, v2.8h
sub x9, x9, #16
cmp x9, #16
bge LoopDepth16
LoopDepth16Out:
st1 {v16.4s, v17.4s}, [x10], #32
ld1 {v18.4s, v19.4s}, [x0], #32
ssubl v21.8h, v1.8b, v31.8b
smlal v18.4s, v21.4h, v3.4h
smlal2 v19.4s, v21.8h, v3.8h
st1 {v18.4s, v19.4s}, [x10], #32
L8:
cmp x9, #8
blt L0
LoopDepth8:
ld1 {v0.8b}, [x7], #8
ld1 {v2.8h}, [x8], #16
ld1 {v16.4s, v17.4s}, [x0], #32
ssubl v20.8h, v0.8b, v31.8b
smlal v16.4s, v20.4h, v2.4h
smlal2 v17.4s, v20.8h, v2.8h
st1 {v16.4s, v17.4s}, [x10], #32
sub x9, x9, #8
cmp x9, #8
bge LoopDepth8
L0:
cmp x9, #0
beq Loop16LineEnd
LoopDepth0:
ldrsb w14, [x7], #1
ldrsh w15, [x8], #2
ldr w16, [x0], #4
add w14, w14, w6
sxth w14, w14
madd w14, w14, w15, w16
str w14, [x10], #4
subs x9, x9, #1
bne LoopDepth0
Loop16LineEnd:
subs x3, x3, #1
add x1, x1, x5
bne LoopOutPixel
End:
ret
#endif
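For reference, a minimal C sketch of what ConvDwInt8Row computes (illustrative only: the name ConvDwInt8RowRef and the flat indexing are ours, and the zero-point handling follows the vector path above, which subtracts input_zp before widening):

#include <stdint.h>

// Scalar reference: for each output pixel, accumulate (input - input_zp) * weight
// into the running int32 output; the weight pointer restarts every pixel and the
// input pointer advances by input_step per pixel, as in the assembly.
void ConvDwInt8RowRef(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr,
                      int num_pixels, int output_channel, int input_step, int8_t input_zp) {
  for (int p = 0; p < num_pixels; p++) {
    for (int c = 0; c < output_channel; c++) {
      output_ptr[p * output_channel + c] += (int32_t)(input_ptr[c] - input_zp) * weight_ptr[c];
    }
    input_ptr += input_step;  // input_step is counted in elements here
  }
}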

View File

@ -0,0 +1,812 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulFloatNeon64Opt
#ifndef __APPLE__
.type MatmulFloatNeon64Opt, %function
#endif
// A: LM [row_8 * depth] col_8_major
// B: RM [depth * col_8] row_8_major
// C: A*B [row_8 * col_8] col_8x8_major
// A * B -> [8 * depth] * [depth * 8] -> [8 * 4] * [4 * 8] or [8 * 1] * [1 * 8]
///////////////////////////////////////////////////////////////////////////////
//CommLoopMul RM 1x8 block
// /-----------------------------------------\
// |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]|
// \-----------------------------------------/
// LM 8x1 block
// /---------------------\ /-----------------------------------------\
// | v0.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3]|
// | ... | | ... ... |
// | v0.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3]|
// | v1.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3]|
// | ... | | ... ... |
// | v1.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3]|
// \---------------------/ \-----------------------------------------/
// accumulators 8x8 block
//
///////////////////////////////////////////////////////////////////////////////
//OptLoopMul4 RM 4x8 block
// /--------------------------------------------\
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|
// |v12.s[0] ... v12.s[3] v13.s[0] ... v13.s[3]|
// |v14.s[0] ... v14.s[3] v15.s[0] ... v15.s[3]|
// \--------------------------------------------/
// LM 8x4 block
// /---------------------------------\ /--------------------------------------------\
// | v0.s[0] v2.s[0] v4.s[0] v6.s[0] | |v16.s[0]...v16.s[3] v17.s[0]...v17.s[3] |
// | ... ... ... ... | | ... ... |
// | v0.s[3] v2.s[3] v4.s[3] v6.s[3] | |v22.s[0]...v22.s[3] v23.s[0]...v23.s[3] |
// | v1.s[0] v3.s[0] v5.s[0] v7.s[0] | |v24.s[0]...v24.s[3] v25.s[0]...v25.s[3] |
// | ... ... ... ... | | ... ... |
// | v1.s[3] v3.s[3] v5.s[3] v7.s[3] | |v30.s[0]...v30.s[3] v31.s[0]...v31.s[3] |
// \---------------------------------/ \--------------------------------------------/
// accumulators 8x8 block
/////////////////////////////////////////////////////////////////////////////////
//
// void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino)
// x0: a
// x1: b
// x2: c
// x3: bias
// w4: act_type
// w5: depth
// w6: row
// w7: col
// x17: stride
// x9: writeNhwc
// x14: WriteWino
MatmulFloatNeon64Opt:
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldr x9, [sp, #8]
ldr x14, [sp, #16]
mov w18, #32 // sizeof(float) * 8
mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x18, #4
ldr x17, [sp]
cbz x14, NoWinoSteps
mul x8, x7, x17
mov x11, #8
mul x11, x11, x17
mul x8, x8, x18
mul x11, x11, x18
NoWinoSteps:
mul x17, x17, x18
L1:
mov w10, w6 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
L2:
mov x16, x1 // reload rhs ptr
mov w13, w5 // reload depth
dup v8.4s, wzr
dup v9.4s, wzr
dup v10.4s, wzr
dup v11.4s, wzr
dup v12.4s, wzr
dup v13.4s, wzr
dup v14.4s, wzr
dup v15.4s, wzr
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
LoopStart:
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v8.4s, v3.4s, v0.s[0]
fmla v10.4s, v3.4s, v0.s[1]
fmla v12.4s, v3.4s, v0.s[2]
fmla v14.4s, v3.4s, v0.s[3]
fmla v9.4s, v4.4s, v0.s[0]
fmla v11.4s, v4.4s, v0.s[1]
fmla v13.4s, v4.4s, v0.s[2]
fmla v15.4s, v4.4s, v0.s[3]
subs w13, w13, #1
beq LoopEnd
Loop:
ld1 {v0.4s}, [x12], #16
fmla v16.4s, v3.4s, v1.s[0]
fmla v18.4s, v3.4s, v1.s[1]
fmla v20.4s, v3.4s, v1.s[2]
fmla v22.4s, v3.4s, v1.s[3]
fmla v17.4s, v4.4s, v1.s[0]
fmla v19.4s, v4.4s, v1.s[1]
fmla v21.4s, v4.4s, v1.s[2]
fmla v23.4s, v4.4s, v1.s[3]
ld1 {v1.4s}, [x12], #16
fmla v24.4s, v3.4s, v2.s[0]
fmla v26.4s, v3.4s, v2.s[1]
fmla v28.4s, v3.4s, v2.s[2]
fmla v30.4s, v3.4s, v2.s[3]
ld1 {v3.4s}, [x16], #16
fmla v25.4s, v4.4s, v2.s[0]
fmla v27.4s, v4.4s, v2.s[1]
fmla v29.4s, v4.4s, v2.s[2]
fmla v31.4s, v4.4s, v2.s[3]
ld1 {v4.4s}, [x16], #16
fmla v8.4s, v3.4s, v0.s[0]
fmla v10.4s, v3.4s, v0.s[1]
fmla v12.4s, v3.4s, v0.s[2]
fmla v14.4s, v3.4s, v0.s[3]
ld1 {v2.4s}, [x12], #16
fmla v9.4s, v4.4s, v0.s[0]
fmla v11.4s, v4.4s, v0.s[1]
fmla v13.4s, v4.4s, v0.s[2]
fmla v15.4s, v4.4s, v0.s[3]
subs w13, w13, #1
bgt Loop
LoopEnd:
fmla v16.4s, v3.4s, v1.s[0]
fmla v18.4s, v3.4s, v1.s[1]
fmla v20.4s, v3.4s, v1.s[2]
fmla v22.4s, v3.4s, v1.s[3]
fmla v17.4s, v4.4s, v1.s[0]
fmla v19.4s, v4.4s, v1.s[1]
fmla v21.4s, v4.4s, v1.s[2]
fmla v23.4s, v4.4s, v1.s[3]
fmla v24.4s, v3.4s, v2.s[0]
fmla v26.4s, v3.4s, v2.s[1]
fmla v28.4s, v3.4s, v2.s[2]
fmla v30.4s, v3.4s, v2.s[3]
fmla v25.4s, v4.4s, v2.s[0]
fmla v27.4s, v4.4s, v2.s[1]
fmla v29.4s, v4.4s, v2.s[2]
fmla v31.4s, v4.4s, v2.s[3]
Bias:
cbz x3, Activation
ld1 {v0.4s}, [x3], #16
ld1 {v1.4s}, [x3]
sub x3, x3, #16
fadd v8.4s, v8.4s, v0.4s
fadd v9.4s, v9.4s, v1.4s
fadd v10.4s, v10.4s, v0.4s
fadd v11.4s, v11.4s, v1.4s
fadd v12.4s, v12.4s, v0.4s
fadd v13.4s, v13.4s, v1.4s
fadd v14.4s, v14.4s, v0.4s
fadd v15.4s, v15.4s, v1.4s
fadd v16.4s, v16.4s, v0.4s
fadd v17.4s, v17.4s, v1.4s
fadd v18.4s, v18.4s, v0.4s
fadd v19.4s, v19.4s, v1.4s
fadd v20.4s, v20.4s, v0.4s
fadd v21.4s, v21.4s, v1.4s
fadd v22.4s, v22.4s, v0.4s
fadd v23.4s, v23.4s, v1.4s
fadd v24.4s, v24.4s, v0.4s
fadd v25.4s, v25.4s, v1.4s
fadd v26.4s, v26.4s, v0.4s
fadd v27.4s, v27.4s, v1.4s
fadd v28.4s, v28.4s, v0.4s
fadd v29.4s, v29.4s, v1.4s
fadd v30.4s, v30.4s, v0.4s
fadd v31.4s, v31.4s, v1.4s
Activation:
cmp w4, #2
beq Relu6
cmp w4, #1
beq Relu
b Write
Relu6:
mov w13, #6
dup v2.4s, w13
scvtf v2.4s, v2.4s
fmin v8.4s, v8.4s, v2.4s
fmin v9.4s, v9.4s, v2.4s
fmin v10.4s, v10.4s, v2.4s
fmin v11.4s, v11.4s, v2.4s
fmin v12.4s, v12.4s, v2.4s
fmin v13.4s, v13.4s, v2.4s
fmin v14.4s, v14.4s, v2.4s
fmin v15.4s, v15.4s, v2.4s
fmin v16.4s, v16.4s, v2.4s
fmin v17.4s, v17.4s, v2.4s
fmin v18.4s, v18.4s, v2.4s
fmin v19.4s, v19.4s, v2.4s
fmin v20.4s, v20.4s, v2.4s
fmin v21.4s, v21.4s, v2.4s
fmin v22.4s, v22.4s, v2.4s
fmin v23.4s, v23.4s, v2.4s
fmin v24.4s, v24.4s, v2.4s
fmin v25.4s, v25.4s, v2.4s
fmin v26.4s, v26.4s, v2.4s
fmin v27.4s, v27.4s, v2.4s
fmin v28.4s, v28.4s, v2.4s
fmin v29.4s, v29.4s, v2.4s
fmin v30.4s, v30.4s, v2.4s
fmin v31.4s, v31.4s, v2.4s
Relu:
dup v3.4s, wzr
fmax v8.4s, v8.4s, v3.4s
fmax v9.4s, v9.4s, v3.4s
fmax v10.4s, v10.4s, v3.4s
fmax v11.4s, v11.4s, v3.4s
fmax v12.4s, v12.4s, v3.4s
fmax v13.4s, v13.4s, v3.4s
fmax v14.4s, v14.4s, v3.4s
fmax v15.4s, v15.4s, v3.4s
fmax v16.4s, v16.4s, v3.4s
fmax v17.4s, v17.4s, v3.4s
fmax v18.4s, v18.4s, v3.4s
fmax v19.4s, v19.4s, v3.4s
fmax v20.4s, v20.4s, v3.4s
fmax v21.4s, v21.4s, v3.4s
fmax v22.4s, v22.4s, v3.4s
fmax v23.4s, v23.4s, v3.4s
fmax v24.4s, v24.4s, v3.4s
fmax v25.4s, v25.4s, v3.4s
fmax v26.4s, v26.4s, v3.4s
fmax v27.4s, v27.4s, v3.4s
fmax v28.4s, v28.4s, v3.4s
fmax v29.4s, v29.4s, v3.4s
fmax v30.4s, v30.4s, v3.4s
fmax v31.4s, v31.4s, v3.4s
Write:
cbnz x14, WriteWino
cbz x9, WriteC8
cmp w7, #1
beq Write1
cmp w7, #2
beq Write2
cmp w7, #3
beq Write3
cmp w7, #4
beq Write4
cmp w7, #5
beq Write5
cmp w7, #6
beq Write6
cmp w7, #7
beq Write7
b Write8
Write1:
str s8, [x18]
cmp w10, #1
beq WriteEnd
add x18, x18, x17
str s10, [x18]
cmp w10, #2
beq WriteEnd
add x18, x18, x17
str s12, [x18]
cmp w10, #3
beq WriteEnd
add x18, x18, x17
str s14, [x18]
cmp w10, #4
beq WriteEnd
add x18, x18, x17
str s16, [x18]
cmp w10, #5
beq WriteEnd
add x18, x18, x17
str s18, [x18]
cmp w10, #6
beq WriteEnd
add x18, x18, x17
str s20, [x18]
cmp w10, #7
beq WriteEnd
add x18, x18, x17
str s22, [x18]
cmp w10, #8
beq WriteEnd
add x18, x18, x17
str s24, [x18]
cmp w10, #9
beq WriteEnd
add x18, x18, x17
str s26, [x18]
cmp w10, #10
beq WriteEnd
add x18, x18, x17
str s28, [x18]
cmp w10, #11
beq WriteEnd
add x18, x18, x17
str s30, [x18]
add x18, x18, x17
b WriteEnd
Write2:
dup s9, v8.s[1]
stp s8, s9, [x18]
cmp w10, #1
beq WriteEnd
add x18, x18, x17
dup s11, v10.s[1]
stp s10, s11, [x18]
cmp w10, #2
beq WriteEnd
add x18, x18, x17
dup s13, v12.s[1]
stp s12, s13, [x18]
cmp w10, #3
beq WriteEnd
add x18, x18, x17
dup s15, v14.s[1]
stp s14, s15, [x18]
cmp w10, #4
beq WriteEnd
add x18, x18, x17
dup s17, v16.s[1]
stp s16, s17, [x18]
cmp w10, #5
beq WriteEnd
add x18, x18, x17
dup s19, v18.s[1]
stp s18, s19, [x18]
cmp w10, #6
beq WriteEnd
add x18, x18, x17
dup s21, v20.s[1]
stp s20, s21, [x18]
cmp w10, #7
beq WriteEnd
add x18, x18, x17
dup s23, v22.s[1]
stp s22, s23, [x18]
cmp w10, #8
beq WriteEnd
add x18, x18, x17
dup s25, v24.s[1]
stp s24, s25, [x18]
cmp w10, #9
beq WriteEnd
add x18, x18, x17
dup s27, v26.s[1]
stp s26, s27, [x18]
cmp w10, #10
beq WriteEnd
add x18, x18, x17
dup s29, v28.s[1]
stp s28, s29, [x18]
cmp w10, #11
beq WriteEnd
add x18, x18, x17
dup s31, v30.s[1]
stp s30, s31, [x18]
add x18, x18, x17
b WriteEnd
Write3:
add x13, x18, #8
dup s9, v8.s[1]
stp s8, s9, [x18]
add x18, x18, x17
st1 {v8.s}[2], [x13], x17
cmp w10, #1
beq WriteEnd
dup s11, v10.s[1]
stp s10, s11, [x18]
add x18, x18, x17
st1 {v10.s}[2], [x13], x17
cmp w10, #2
beq WriteEnd
dup s13, v12.s[1]
stp s12, s13, [x18]
add x18, x18, x17
st1 {v12.s}[2], [x13], x17
cmp w10, #3
beq WriteEnd
dup s15, v14.s[1]
stp s14, s15, [x18]
add x18, x18, x17
st1 {v14.s}[2], [x13], x17
cmp w10, #4
beq WriteEnd
dup s17, v16.s[1]
stp s16, s17, [x18]
add x18, x18, x17
st1 {v16.s}[2], [x13], x17
cmp w10, #5
beq WriteEnd
dup s19, v18.s[1]
stp s18, s19, [x18]
add x18, x18, x17
st1 {v18.s}[2], [x13], x17
cmp w10, #6
beq WriteEnd
dup s21, v20.s[1]
stp s20, s21, [x18]
add x18, x18, x17
st1 {v20.s}[2], [x13], x17
cmp w10, #7
beq WriteEnd
dup s23, v22.s[1]
stp s22, s23, [x18]
add x18, x18, x17
st1 {v22.s}[2], [x13], x17
cmp w10, #8
beq WriteEnd
dup s25, v24.s[1]
stp s24, s25, [x18]
add x18, x18, x17
st1 {v24.s}[2], [x13], x17
cmp w10, #9
beq WriteEnd
dup s27, v26.s[1]
stp s26, s27, [x18]
add x18, x18, x17
st1 {v26.s}[2], [x13], x17
cmp w10, #10
beq WriteEnd
dup s29, v28.s[1]
stp s28, s29, [x18]
add x18, x18, x17
st1 {v28.s}[2], [x13], x17
cmp w10, #11
beq WriteEnd
dup s31, v30.s[1]
stp s30, s31, [x18]
add x18, x18, x17
st1 {v30.s}[2], [x13]
b WriteEnd
Write4:
st1 {v8.4s}, [x18], x17
cmp w10, #1
beq WriteEnd
st1 {v10.4s}, [x18], x17
cmp w10, #2
beq WriteEnd
st1 {v12.4s}, [x18], x17
cmp w10, #3
beq WriteEnd
st1 {v14.4s}, [x18], x17
cmp w10, #4
beq WriteEnd
st1 {v16.4s}, [x18], x17
cmp w10, #5
beq WriteEnd
st1 {v18.4s}, [x18], x17
cmp w10, #6
beq WriteEnd
st1 {v20.4s}, [x18], x17
cmp w10, #7
beq WriteEnd
st1 {v22.4s}, [x18], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4s}, [x18], x17
cmp w10, #9
beq WriteEnd
st1 {v26.4s}, [x18], x17
cmp w10, #10
beq WriteEnd
st1 {v28.4s}, [x18], x17
cmp w10, #11
beq WriteEnd
st1 {v30.4s}, [x18], x17
b WriteEnd
Write5:
add x13, x18, #16
st1 {v8.4s}, [x18], x17
str s9, [x13]
cmp w10, #1
beq WriteEnd
add x13, x13, x17
st1 {v10.4s}, [x18], x17
str s11, [x13]
cmp w10, #2
beq WriteEnd
add x13, x13, x17
st1 {v12.4s}, [x18], x17
str s13, [x13]
cmp w10, #3
beq WriteEnd
add x13, x13, x17
st1 {v14.4s}, [x18], x17
str s15, [x13]
cmp w10, #4
beq WriteEnd
add x13, x13, x17
st1 {v16.4s}, [x18], x17
str s17, [x13]
cmp w10, #5
beq WriteEnd
add x13, x13, x17
st1 {v18.4s}, [x18], x17
str s19, [x13]
cmp w10, #6
beq WriteEnd
add x13, x13, x17
st1 {v20.4s}, [x18], x17
str s21, [x13]
cmp w10, #7
beq WriteEnd
add x13, x13, x17
st1 {v22.4s}, [x18], x17
str s23, [x13]
cmp w10, #8
beq WriteEnd
add x13, x13, x17
st1 {v24.4s}, [x18], x17
str s25, [x13]
cmp w10, #9
beq WriteEnd
add x13, x13, x17
st1 {v26.4s}, [x18], x17
str s27, [x13]
cmp w10, #10
beq WriteEnd
add x13, x13, x17
st1 {v28.4s}, [x18], x17
str s29, [x13]
cmp w10, #11
beq WriteEnd
add x13, x13, x17
st1 {v30.4s}, [x18], x17
str s31, [x13]
b WriteEnd
Write6:
add x13, x18, #16
st1 {v8.4s}, [x18], x17
dup s8, v9.s[1]
stp s9, s8, [x13]
cmp w10, #1
beq WriteEnd
add x13, x13, x17
st1 {v10.4s}, [x18], x17
dup s10, v11.s[1]
stp s11, s10, [x13]
cmp w10, #2
beq WriteEnd
add x13, x13, x17
st1 {v12.4s}, [x18], x17
dup s12, v13.s[1]
stp s13, s12, [x13]
cmp w10, #3
beq WriteEnd
add x13, x13, x17
st1 {v14.4s}, [x18], x17
dup s14, v15.s[1]
stp s15, s14, [x13]
cmp w10, #4
beq WriteEnd
add x13, x13, x17
st1 {v16.4s}, [x18], x17
dup s16, v17.s[1]
stp s17, s16, [x13]
cmp w10, #5
beq WriteEnd
add x13, x13, x17
st1 {v18.4s}, [x18], x17
dup s18, v19.s[1]
stp s19, s18, [x13]
cmp w10, #6
beq WriteEnd
add x13, x13, x17
st1 {v20.4s}, [x18], x17
dup s20, v21.s[1]
stp s21, s20, [x13]
cmp w10, #7
beq WriteEnd
add x13, x13, x17
st1 {v22.4s}, [x18], x17
dup s22, v23.s[1]
stp s23, s22, [x13]
cmp w10, #8
beq WriteEnd
add x13, x13, x17
st1 {v24.4s}, [x18], x17
dup s24, v25.s[1]
stp s25, s24, [x13]
cmp w10, #9
beq WriteEnd
add x13, x13, x17
st1 {v26.4s}, [x18], x17
dup s26, v27.s[1]
stp s27, s26, [x13]
cmp w10, #10
beq WriteEnd
add x13, x13, x17
st1 {v28.4s}, [x18], x17
dup s28, v29.s[1]
stp s29, s28, [x13]
cmp w10, #11
beq WriteEnd
add x13, x13, x17
st1 {v30.4s}, [x18], x17
dup s30, v31.s[1]
stp s31, s30, [x13]
b WriteEnd
Write7:
add x13, x18, #16
add x16, x18, #24
st1 {v8.4s}, [x18], x17
dup s8, v9.s[1]
stp s9, s8, [x13]
add x13, x13, x17
st1 {v9.s}[2], [x16], x17
cmp w10, #1
beq WriteEnd
st1 {v10.4s}, [x18], x17
dup s10, v11.s[1]
stp s11, s10, [x13]
add x13, x13, x17
st1 {v11.s}[2], [x16], x17
cmp w10, #2
beq WriteEnd
st1 {v12.4s}, [x18], x17
dup s12, v13.s[1]
stp s13, s12, [x13]
add x13, x13, x17
st1 {v13.s}[2], [x16], x17
cmp w10, #3
beq WriteEnd
st1 {v14.4s}, [x18], x17
dup s14, v15.s[1]
stp s15, s14, [x13]
add x13, x13, x17
st1 {v15.s}[2], [x16], x17
cmp w10, #4
beq WriteEnd
st1 {v16.4s}, [x18], x17
dup s16, v17.s[1]
stp s17, s16, [x13]
add x13, x13, x17
st1 {v17.s}[2], [x16], x17
cmp w10, #5
beq WriteEnd
st1 {v18.4s}, [x18], x17
dup s18, v19.s[1]
stp s19, s18, [x13]
add x13, x13, x17
st1 {v19.s}[2], [x16], x17
cmp w10, #6
beq WriteEnd
st1 {v20.4s}, [x18], x17
dup s20, v21.s[1]
stp s21, s20, [x13]
add x13, x13, x17
st1 {v21.s}[2], [x16], x17
cmp w10, #7
beq WriteEnd
st1 {v22.4s}, [x18], x17
dup s22, v23.s[1]
stp s23, s22, [x13]
add x13, x13, x17
st1 {v23.s}[2], [x16], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4s}, [x18], x17
dup s24, v25.s[1]
stp s25, s24, [x13]
add x13, x13, x17
st1 {v25.s}[2], [x16], x17
cmp w10, #9
beq WriteEnd
st1 {v26.4s}, [x18], x17
dup s26, v27.s[1]
stp s27, s26, [x13]
add x13, x13, x17
st1 {v27.s}[2], [x16], x17
cmp w10, #10
beq WriteEnd
st1 {v28.4s}, [x18], x17
dup s28, v29.s[1]
stp s29, s28, [x13]
add x13, x13, x17
st1 {v29.s}[2], [x16], x17
cmp w10, #11
beq WriteEnd
st1 {v30.4s}, [x18], x17
dup s30, v31.s[1]
stp s31, s30, [x13]
add x13, x13, x17
st1 {v31.s}[2], [x16], x17
b WriteEnd
WriteC8:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
b WriteEnd
WriteWino:
st1 {v8.4s, v9.4s}, [x18], x8
st1 {v10.4s, v11.4s}, [x18], x8
st1 {v12.4s, v13.4s}, [x18], x8
st1 {v14.4s, v15.4s}, [x18], x8
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
st1 {v24.4s, v25.4s}, [x18], x8
st1 {v26.4s, v27.4s}, [x18], x8
st1 {v28.4s, v29.4s}, [x18], x8
st1 {v30.4s, v31.4s}, [x18], x8
b WriteEnd
Write8:
st1 {v8.4s, v9.4s}, [x18], x17
cmp w10, #1
beq WriteEnd
st1 {v10.4s, v11.4s}, [x18], x17
cmp w10, #2
beq WriteEnd
st1 {v12.4s, v13.4s}, [x18], x17
cmp w10, #3
beq WriteEnd
st1 {v14.4s, v15.4s}, [x18], x17
cmp w10, #4
beq WriteEnd
st1 {v16.4s, v17.4s}, [x18], x17
cmp w10, #5
beq WriteEnd
st1 {v18.4s, v19.4s}, [x18], x17
cmp w10, #6
beq WriteEnd
st1 {v20.4s, v21.4s}, [x18], x17
cmp w10, #7
beq WriteEnd
st1 {v22.4s, v23.4s}, [x18], x17
cmp w10, #8
beq WriteEnd
st1 {v24.4s, v25.4s}, [x18], x17
cmp w10, #9
beq WriteEnd
st1 {v26.4s, v27.4s}, [x18], x17
cmp w10, #10
beq WriteEnd
st1 {v28.4s, v29.4s}, [x18], x17
cmp w10, #11
beq WriteEnd
st1 {v30.4s, v31.4s}, [x18], x17
WriteEnd:
subs w10, w10, #12 // lhs row - 12
bgt L2
End2:
subs w7, w7, #8 // rhs col - 8
add x1, x1, x15 // rhs ptr + stride
cbz x3, NoBiasStep
add x3, x3, #32 // bias ptr + stride
NoBiasStep:
cbnz x14, WinoDstStep
cbz x9, NoDstStep
add x2, x2, #32 // dst ptr + stride
b NoDstStep
WinoDstStep:
add x2, x2, x11
NoDstStep:
bgt L1
End1:
sub sp, sp, #128
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ret
#endif
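At a high level one call of MatmulFloatNeon64Opt computes C = act(A*B + bias). A plain C sketch of that contract, ignoring the 12x8 tiling, the packed operand layouts, and the C8/NHWC/winograd write-back modes (all names here are illustrative):

// Reference semantics only; the assembly consumes pre-packed blocks,
// not plain row-major arrays.
static float ActCap(float v, int act_type) {
  if (act_type == 1 || act_type == 2) v = v > 0.0f ? v : 0.0f;  // Relu and Relu6 share the lower bound
  if (act_type == 2) v = v < 6.0f ? v : 6.0f;                   // Relu6 upper bound
  return v;
}

void MatmulFloatRef(const float *a, const float *b, float *c, const float *bias,
                    int act_type, int depth, int row, int col) {
  for (int r = 0; r < row; r++) {
    for (int j = 0; j < col; j++) {
      float acc = bias ? bias[j] : 0.0f;  // per-column bias, as in the Bias block
      for (int d = 0; d < depth; d++) {
        acc += a[r * depth + d] * b[d * col + j];
      }
      c[r * col + j] = ActCap(acc, act_type);
    }
  }
}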

View File

@ -0,0 +1,144 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulFloatNeon64OptRemain
#ifndef __APPLE__
.type MatmulFloatNeon64OptRemain, %function
#endif
// void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth,
// int row, int col, size_t stride)
// x0: a
// x1: b
// x2: c
// x3: depth
// x4: row
// x5: col
// x6: stride
// only for winograd
MatmulFloatNeon64OptRemain:
mov x18, #32 // sizeof(float) * 8
mul x9, x3, x18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x18, #4
mul x8, x5, x6
mov x11, #8
mul x11, x11, x6
mul x8, x8, x18
mul x11, x11, x18
cmp x4, #4
ble LoopH4
LoopH8:
mov x10, x4 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
LoopW8:
mov x16, x1 // reload rhs ptr
mov x13, x3 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
LoopD8:
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v16.4s, v3.4s, v0.s[0]
fmla v18.4s, v3.4s, v0.s[1]
fmla v20.4s, v3.4s, v0.s[2]
fmla v22.4s, v3.4s, v0.s[3]
fmla v17.4s, v4.4s, v0.s[0]
fmla v19.4s, v4.4s, v0.s[1]
fmla v21.4s, v4.4s, v0.s[2]
fmla v23.4s, v4.4s, v0.s[3]
fmla v24.4s, v3.4s, v1.s[0]
fmla v26.4s, v3.4s, v1.s[1]
fmla v28.4s, v3.4s, v1.s[2]
fmla v30.4s, v3.4s, v1.s[3]
fmla v25.4s, v4.4s, v1.s[0]
fmla v27.4s, v4.4s, v1.s[1]
fmla v29.4s, v4.4s, v1.s[2]
fmla v31.4s, v4.4s, v1.s[3]
subs w13, w13, #1
bgt LoopD8
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
st1 {v24.4s, v25.4s}, [x18], x8
st1 {v26.4s, v27.4s}, [x18], x8
st1 {v28.4s, v29.4s}, [x18], x8
st1 {v30.4s, v31.4s}, [x18], x8
subs x10, x10, #8 // lhs row - 8
bgt LoopW8
subs x5, x5, #8 // rhs col - 8
add x1, x1, x9 // rhs ptr + stride
add x2, x2, x11
bgt LoopH8
ret
LoopH4:
mov x10, x4 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
LoopW4:
mov x16, x1 // reload rhs ptr
mov x13, x3 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
LoopD4:
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v16.4s, v3.4s, v0.s[0]
fmla v18.4s, v3.4s, v0.s[1]
fmla v20.4s, v3.4s, v0.s[2]
fmla v22.4s, v3.4s, v0.s[3]
fmla v17.4s, v4.4s, v0.s[0]
fmla v19.4s, v4.4s, v0.s[1]
fmla v21.4s, v4.4s, v0.s[2]
fmla v23.4s, v4.4s, v0.s[3]
subs x13, x13, #1
bgt LoopD4
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
subs x10, x10, #4 // lhs row - 4
bgt LoopW4
subs x5, x5, #8 // rhs col - 8
add x1, x1, x9 // rhs ptr + stride
add x2, x2, x11
bgt LoopH4
ret
#endif

View File

@ -24,7 +24,7 @@
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
// int multiplier, int left_shift, int right_shift);
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
@ -40,13 +40,18 @@
// w11: multiplier
// w12: left_shift
// w13: right_shift
// w14: row
// w15: col
// w24: stride
MatmulInt8Neon64:
sub sp, sp, #160
sub sp, sp, #192
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
stp x25, x26, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
@ -54,25 +59,28 @@ MatmulInt8Neon64:
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
ldr w24, [sp, #64]
mov w15, #0 // b col index
mov w16, #0 // a row index
mov w17, #4 // sizeof(int8)*4
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
mov w17, #1
mov x25, x2
L1:
cmp w15, w4
cmp w4, #0 // if at the end of col4
beq End1
mov w16, #0 // reset a row index
mov w16, w3 // reset a row4 counter
mov w23, w14 // reset a row counter
mov x17, x0 // reload a ptr
mov x22, x6 // reload a_sums ptr
L2:
cmp w16, w3
cmp w16, #0
beq End2
mov x18, x1 // reload b ptr
mov x19, x7 // reload bias ptr
mov x19, x7 // reload bias ptr
mov w20, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
@ -256,21 +264,128 @@ End3:
sqxtn v15.8b, v13.8h
sqxtn2 v15.16b, v14.8h
st1 {v15.16b}, [x2], #16
add w16, w16, #4 // a row index + 4
cmp w23, #4
blt Write // if rows < 4
cmp w15, #4
blt Write // if cols < 4
st1 {v15.s}[0], [x2], x24
st1 {v15.s}[1], [x2], x24
st1 {v15.s}[2], [x2], x24
st1 {v15.s}[3], [x2], x24
b Endwrite
Write:
cmp w15, #4
beq WriteCol4
cmp w15, #3
beq WriteCol3
cmp w15, #2
beq WriteCol2
cmp w15, #1
beq WriteCol1
WriteCol4:
st1 {v15.s}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.s}[1], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.s}[2], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.s}[3], [x2], x24
b Endwrite
WriteCol3:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
st1 {v15.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
st1 {v15.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
st1 {v15.b}[10], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
st1 {v15.b}[14], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol2:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol1:
st1 {v15.b}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.b}[4], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.b}[8], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.b}[12], [x2], x24
b Endwrite
Endwrite:
sub w16, w16, #4 // a row4 counter - 4
sub w23, w23, #4 // a row counter - 4
b L2
End2:
add w15, w15, #4 // b col index + 4
sub w4, w4, #4 // b col4 counter - 4
sub w15, w15, #4 // b col counter - 4
add x1, x1, x21 // b ptr + stride
add x7, x7, #16 // bias ptr + stride
add x25, x25, #4 // output + stride(4 * sizeof(int8))
mov x2, x25
b L1
End1:
sub sp, sp, #160
sub sp, sp, #192
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ldp x25, x26, [sp], #16
ret
#endif
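The new Write*/WriteCol* paths above boil down to an edge-clipped tile store. A hypothetical C sketch of that logic (WriteTileInt8 and its signature are ours, not the shipped API):

#include <stdint.h>
#include <string.h>

// Copy an up-to-4x4 block of int8 results into the row-major destination,
// clipping to the rows and columns that remain at the matrix edge.
static void WriteTileInt8(int8_t *dst, const int8_t tile[4][4], int rows_left,
                          int cols_left, int stride) {
  int rows = rows_left < 4 ? rows_left : 4;
  int cols = cols_left < 4 ? cols_left : 4;
  for (int r = 0; r < rows; r++) {
    memcpy(dst + r * stride, tile[r], (size_t)cols);  // partial row at the edge
  }
}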

View File

@ -0,0 +1,117 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDwFp16Row
#ifndef __APPLE__
.type ConvDwFp16Row, %function
#endif
// void ConvDwFp16Row(float16_t* output_ptr, const float16_t* input_ptr, const float16_t* filter_ptr,
// size_t num_pixels, size_t input_channel, size_t input_step)
// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels,
// x4: input_channel, x5: input_step
//
ConvDwFp16Row:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved,
// whereas our coding style does not permit that many parameters
cmp x3, #0
beq End
mov x9, x0
mov x12, #2 // sizeof(float16_t)
mul x5, x5, x12
LoopOutPixel:
mov x6, x1
mov x7, x2
mov x8, x4
LoopInputDepth32In:
cmp x8, #32
blt Loop8
sub x8, x8, #32
ld1 {v0.8h, v1.8h}, [x6], #32
ld1 {v2.8h, v3.8h}, [x7], #32
ld1 {v16.8h, v17.8h}, [x0], #32
cmp x8, #32
blt LoopInputDepth32Out
LoopInputDepth32:
fmla v16.8h, v0.8h, v2.8h
fmla v17.8h, v1.8h, v3.8h
st1 {v16.8h, v17.8h}, [x9], #32
ld1 {v4.8h, v5.8h}, [x6], #32
ld1 {v6.8h, v7.8h}, [x7], #32
ld1 {v18.8h, v19.8h}, [x0], #32
fmla v18.8h, v4.8h, v6.8h
fmla v19.8h, v5.8h, v7.8h
st1 {v18.8h, v19.8h}, [x9], #32
ld1 {v0.8h, v1.8h}, [x6], #32
ld1 {v2.8h, v3.8h}, [x7], #32
ld1 {v16.8h, v17.8h}, [x0], #32
sub x8, x8, #32
cmp x8, #32
bge LoopInputDepth32
LoopInputDepth32Out:
fmla v16.8h, v0.8h, v2.8h
fmla v17.8h, v1.8h, v3.8h
st1 {v16.8h, v17.8h}, [x9], #32
ld1 {v4.8h, v5.8h}, [x6], #32
ld1 {v6.8h, v7.8h}, [x7], #32
ld1 {v18.8h, v19.8h}, [x0], #32
fmla v18.8h, v4.8h, v6.8h
fmla v19.8h, v5.8h, v7.8h
st1 {v18.8h, v19.8h}, [x9], #32
Loop8:
cmp x8, #8
blt L0
LoopInputDepth8:
ld1 {v0.8h}, [x6], #16
ld1 {v2.8h}, [x7], #16
ld1 {v16.8h}, [x0], #16
fmla v16.8h, v0.8h, v2.8h
st1 {v16.8h}, [x9], #16
sub x8, x8, #8
cmp x8, #8
bge LoopInputDepth8
L0:
cmp x8, #0
beq Loop8LineEnd
LoopInputDepth0:
ldr h0, [x6], #2
ldr h1, [x7], #2
ldr h2, [x0], #2
fmul h0, h0, h1
fadd h2, h2, h0
str h2, [x9], #2
subs x8, x8, #1
bne LoopInputDepth0
Loop8LineEnd:
subs x3, x3, #1
add x1, x1, x5
bne LoopOutPixel
End:
ret
#endif
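Same shape as the int8 row kernel above, minus the quantization. A C sketch (illustrative; assumes an AArch64 toolchain where <arm_fp16.h> provides float16_t, as in the surrounding nnacl sources):

#include <stddef.h>
#include <arm_fp16.h>  // float16_t

// Plain fused multiply-accumulate over channels; no zero point in fp16.
void ConvDwFp16RowRef(float16_t *output_ptr, const float16_t *input_ptr,
                      const float16_t *filter_ptr, size_t num_pixels,
                      size_t input_channel, size_t input_step) {
  for (size_t p = 0; p < num_pixels; p++) {
    for (size_t c = 0; c < input_channel; c++) {
      output_ptr[p * input_channel + c] += input_ptr[c] * filter_ptr[c];
    }
    input_ptr += input_step;  // the assembly scales this step by sizeof(float16_t) itself
  }
}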

View File

@ -36,7 +36,7 @@ IndirectGemmInt8_24x4_dp:
ld1 {v17.4s}, [x22], x23
ld1 {v18.4s}, [x22], x23
ld1 {v19.4s}, [x22], x23
ld1{v20.4s}, [x22], x23
ld1 {v20.4s}, [x22], x23
ld1 {v21.4s}, [x22], x23
ld1 {v22.4s}, [x22], x23
ld1 {v23.4s}, [x22], x23
@ -404,7 +404,7 @@ IndirectGemmInt8_24x4_dp:
sshr v0.4s, v0.4s, #31
sqadd v8.4s, v8.4s, v0.4s
srshl v8.4s, v8.4s, v4.4s
and v0.16b, v4.16b, v9.16b
and v1.16b, v4.16b, v9.16b
sshr v1.4s, v1.4s, #31
sqadd v9.4s, v9.4s, v1.4s
srshl v9.4s, v9.4s, v4.4s
@ -420,7 +420,7 @@ IndirectGemmInt8_24x4_dp:
sshr v0.4s, v0.4s, #31
sqadd v12.4s, v12.4s, v0.4s
srshl v12.4s, v12.4s, v4.4s
and v0.16b, v4.16b, v13.16b
and v1.16b, v4.16b, v13.16b
sshr v1.4s, v1.4s, #31
sqadd v13.4s, v13.4s, v1.4s
srshl v13.4s, v13.4s, v4.4s
@ -436,7 +436,7 @@ IndirectGemmInt8_24x4_dp:
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
srshl v16.4s, v16.4s, v4.4s
and v0.16b, v4.16b, v17.16b
and v1.16b, v4.16b, v17.16b
sshr v1.4s, v1.4s, #31
sqadd v17.4s, v17.4s, v1.4s
srshl v17.4s, v17.4s, v4.4s
@ -452,7 +452,7 @@ IndirectGemmInt8_24x4_dp:
sshr v0.4s, v0.4s, #31
sqadd v20.4s, v20.4s, v0.4s
srshl v20.4s, v20.4s, v4.4s
and v0.16b, v4.16b, v21.16b
and v1.16b, v4.16b, v21.16b
sshr v1.4s, v1.4s, #31
sqadd v21.4s, v21.4s, v1.4s
srshl v21.4s, v21.4s, v4.4s
@ -468,7 +468,7 @@ IndirectGemmInt8_24x4_dp:
sshr v0.4s, v0.4s, #31
sqadd v24.4s, v24.4s, v0.4s
srshl v24.4s, v24.4s, v4.4s
and v0.16b, v4.16b, v25.16b
and v1.16b, v4.16b, v25.16b
sshr v1.4s, v1.4s, #31
sqadd v25.4s, v25.4s, v1.4s
srshl v25.4s, v25.4s, v4.4s
@ -484,7 +484,7 @@ IndirectGemmInt8_24x4_dp:
sshr v0.4s, v0.4s, #31
sqadd v28.4s, v28.4s, v0.4s
srshl v28.4s, v28.4s, v4.4s
and v0.16b, v4.16b, v29.16b
and v1.16b, v4.16b, v29.16b
sshr v1.4s, v1.4s, #31
sqadd v29.4s, v29.4s, v1.4s
srshl v29.4s, v29.4s, v4.4s

View File

@ -0,0 +1,820 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulInt8DpNeon64
#ifndef __APPLE__
.type MatmulInt8DpNeon64, %function
#endif
//
// int8 RHS 4x8 block
// /-----------------------------------------\
// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
// | ... ... |
// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
// \-----------------------------------------/
// int8 LHS 8x4 block
// /---------------------\ /-------------------------------------------\
// |v0.b[0] ... v0.b[3] | |v16.s[0] ... v16.s[3] v17.s[0] ... v17.s[3]|
// |v0.b[4] ... v0.b[7] | |v18.s[0] ... v18.s[3] v19.s[0] ... v19.s[3]|
// |v0.b[8] ... v0.b[11]| |v20.s[0] ... v20.s[3] v21.s[0] ... v21.s[3]|
// |v0.b[12] ... v0.b[15]| |v22.s[0] ... v22.s[3] v23.s[0] ... v23.s[3]|
// |v1.b[0] ... v1.b[3] | |v24.s[0] ... v24.s[3] v25.s[0] ... v25.s[3]|
// |v1.b[4] ... v1.b[7] | |v26.s[0] ... v26.s[3] v27.s[0] ... v27.s[3]|
// |v1.b[8] ... v1.b[11]| |v28.s[0] ... v28.s[3] v29.s[0] ... v29.s[3]|
// |v1.b[12] ... v1.b[15]| |v30.s[0] ... v30.s[3] v31.s[0] ... v31.s[3]|
// \---------------------/ \-------------------------------------------/
// int32 accumulators 8x8 block
// int8 RHS 16x8 block
// /-------------\
// |v2 v3 |
// |v6 v7 |
// |v10 v11 |
// |v14 v15 |
// \-------------/
// int8 LHS 8x16 block
// /--------------------\ /-------------\
// |v0 v4 v8 v12| | |
// |v1 v5 v9 v13| | |
// \--------------------/ \-------------/
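// Note on sdot: with an indexed operand, e.g. sdot v16.4s, v2.16b, v0.4b[0],
// each 32-bit lane i of v16 accumulates the dot product of the four int8 values
// in lane i of v2 with the four int8 values in element 0 of v0, so one
// instruction consumes a depth-4 slice of four RHS columns against one LHS row.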
//void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row8
// w4: col8
// w5: deep4
// x6: a_sums
// x7: bias
// w8: act_min
// w9: act_max
// w10: out_zp
// w11: multiplier
// w12: left_shift
// w13: right_shift
// w14: row
// w15: col
// w24: stride
MatmulInt8DpNeon64:
sub sp, sp, #192
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
stp x25, x26, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
ldr w24, [sp, #64]
mov w17, #8 // sizeof(int8)*8
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4
mov x25, x2
L1:
cmp w4, #0 // if at the end of col8
beq End1
mov w16, w3 // reset a row8 counter
mov w23, w14 // reset a row counter
mov x17, x0 // reload a ptr
mov x22, x6 // reload a_sums ptr
L2:
cmp w16, #0
beq End2
mov x18, x1 // reload b ptr
mov x19, x7 // reload bias ptr
mov w20, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
L3:
cmp w20, #16
blt LoopD4
LoopD16:
ld1 {v0.16b, v1.16b}, [x17], #32
ld1 {v2.16b, v3.16b}, [x18], #32
sdot v16.4s, v2.16b, v0.4b[0]
sdot v18.4s, v2.16b, v0.4b[1]
sdot v20.4s, v2.16b, v0.4b[2]
sdot v22.4s, v2.16b, v0.4b[3]
ld1 {v4.16b, v5.16b}, [x17], #32
sdot v24.4s, v2.16b, v1.4b[0]
sdot v26.4s, v2.16b, v1.4b[1]
sdot v28.4s, v2.16b, v1.4b[2]
sdot v30.4s, v2.16b, v1.4b[3]
ld1 {v6.16b, v7.16b}, [x18], #32
sdot v17.4s, v3.16b, v0.4b[0]
sdot v19.4s, v3.16b, v0.4b[1]
sdot v21.4s, v3.16b, v0.4b[2]
sdot v23.4s, v3.16b, v0.4b[3]
sdot v25.4s, v3.16b, v1.4b[0]
sdot v27.4s, v3.16b, v1.4b[1]
sdot v29.4s, v3.16b, v1.4b[2]
sdot v31.4s, v3.16b, v1.4b[3]
ld1 {v8.16b, v9.16b}, [x17], #32
sdot v16.4s, v6.16b, v4.4b[0]
sdot v18.4s, v6.16b, v4.4b[1]
sdot v20.4s, v6.16b, v4.4b[2]
sdot v22.4s, v6.16b, v4.4b[3]
sdot v24.4s, v6.16b, v5.4b[0]
sdot v26.4s, v6.16b, v5.4b[1]
sdot v28.4s, v6.16b, v5.4b[2]
sdot v30.4s, v6.16b, v5.4b[3]
ld1 {v10.16b, v11.16b}, [x18], #32
sdot v17.4s, v7.16b, v4.4b[0]
sdot v19.4s, v7.16b, v4.4b[1]
sdot v21.4s, v7.16b, v4.4b[2]
sdot v23.4s, v7.16b, v4.4b[3]
sdot v25.4s, v7.16b, v5.4b[0]
sdot v27.4s, v7.16b, v5.4b[1]
sdot v29.4s, v7.16b, v5.4b[2]
sdot v31.4s, v7.16b, v5.4b[3]
ld1 {v12.16b, v13.16b}, [x17], #32
sdot v16.4s, v10.16b, v8.4b[0]
sdot v18.4s, v10.16b, v8.4b[1]
sdot v20.4s, v10.16b, v8.4b[2]
sdot v22.4s, v10.16b, v8.4b[3]
sdot v24.4s, v10.16b, v9.4b[0]
sdot v26.4s, v10.16b, v9.4b[1]
sdot v28.4s, v10.16b, v9.4b[2]
sdot v30.4s, v10.16b, v9.4b[3]
ld1 {v14.16b, v15.16b}, [x18], #32
sdot v17.4s, v11.16b, v8.4b[0]
sdot v19.4s, v11.16b, v8.4b[1]
sdot v21.4s, v11.16b, v8.4b[2]
sdot v23.4s, v11.16b, v8.4b[3]
sdot v25.4s, v11.16b, v9.4b[0]
sdot v27.4s, v11.16b, v9.4b[1]
sdot v29.4s, v11.16b, v9.4b[2]
sdot v31.4s, v11.16b, v9.4b[3]
sdot v16.4s, v14.16b, v12.4b[0]
sdot v18.4s, v14.16b, v12.4b[1]
sdot v20.4s, v14.16b, v12.4b[2]
sdot v22.4s, v14.16b, v12.4b[3]
sdot v24.4s, v14.16b, v13.4b[0]
sdot v26.4s, v14.16b, v13.4b[1]
sdot v28.4s, v14.16b, v13.4b[2]
sdot v30.4s, v14.16b, v13.4b[3]
sdot v17.4s, v15.16b, v12.4b[0]
sdot v19.4s, v15.16b, v12.4b[1]
sdot v21.4s, v15.16b, v12.4b[2]
sdot v23.4s, v15.16b, v12.4b[3]
sdot v25.4s, v15.16b, v13.4b[0]
sdot v27.4s, v15.16b, v13.4b[1]
sdot v29.4s, v15.16b, v13.4b[2]
sdot v31.4s, v15.16b, v13.4b[3]
subs w20, w20, #16 // depth - 16
b L3
LoopD4:
cmp w20, #0
beq End3
ld1 {v0.16b, v1.16b}, [x17], #32
ld1 {v2.16b, v3.16b}, [x18], #32
sdot v16.4s, v2.16b, v0.4b[0]
sdot v18.4s, v2.16b, v0.4b[1]
sdot v20.4s, v2.16b, v0.4b[2]
sdot v22.4s, v2.16b, v0.4b[3]
sdot v24.4s, v2.16b, v1.4b[0]
sdot v26.4s, v2.16b, v1.4b[1]
sdot v28.4s, v2.16b, v1.4b[2]
sdot v30.4s, v2.16b, v1.4b[3]
sdot v17.4s, v3.16b, v0.4b[0]
sdot v19.4s, v3.16b, v0.4b[1]
sdot v21.4s, v3.16b, v0.4b[2]
sdot v23.4s, v3.16b, v0.4b[3]
sdot v25.4s, v3.16b, v1.4b[0]
sdot v27.4s, v3.16b, v1.4b[1]
sdot v29.4s, v3.16b, v1.4b[2]
sdot v31.4s, v3.16b, v1.4b[3]
subs w20, w20, #4 // depth - 4
b LoopD4
End3:
// Add (Bias+Depth*Za*Zb-Za*Bsums)
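// Derivation: with zero points Za (lhs) and Zb (rhs) the real accumulator is
//   sum_d (a_d - Za)*(b_d - Zb)
//     = sum_d a_d*b_d - Zb*Asum - Za*Bsum + depth*Za*Zb,
// so the vectors loaded from the bias pointer fold in
// bias + depth*Za*Zb - Za*Bsums per column, and the per-row Asums terms
// (premultiplied by Zb by the caller) are subtracted right after this block.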
ld1 {v15.4s}, [x19], #16
ld1 {v14.4s}, [x19], #16
add v16.4s, v16.4s, v15.4s
add v18.4s, v18.4s, v15.4s
add v20.4s, v20.4s, v15.4s
add v22.4s, v22.4s, v15.4s
add v24.4s, v24.4s, v15.4s
add v26.4s, v26.4s, v15.4s
add v28.4s, v28.4s, v15.4s
add v30.4s, v30.4s, v15.4s
add v17.4s, v17.4s, v14.4s
add v19.4s, v19.4s, v14.4s
add v21.4s, v21.4s, v14.4s
add v23.4s, v23.4s, v14.4s
add v25.4s, v25.4s, v14.4s
add v27.4s, v27.4s, v14.4s
add v29.4s, v29.4s, v14.4s
add v31.4s, v31.4s, v14.4s
// Subtract (Asums*Zb)
ld1 {v13.4s}, [x22], #16
ld1 {v12.4s}, [x22], #16
dup v0.4s, v13.s[0]
dup v1.4s, v13.s[1]
dup v2.4s, v13.s[2]
dup v3.4s, v13.s[3]
dup v4.4s, v12.s[0]
dup v5.4s, v12.s[1]
dup v6.4s, v12.s[2]
dup v7.4s, v12.s[3]
sub v16.4s, v16.4s, v0.4s
sub v17.4s, v17.4s, v0.4s
sub v18.4s, v18.4s, v1.4s
sub v19.4s, v19.4s, v1.4s
sub v20.4s, v20.4s, v2.4s
sub v21.4s, v21.4s, v2.4s
sub v22.4s, v22.4s, v3.4s
sub v23.4s, v23.4s, v3.4s
sub v24.4s, v24.4s, v4.4s
sub v25.4s, v25.4s, v4.4s
sub v26.4s, v26.4s, v5.4s
sub v27.4s, v27.4s, v5.4s
sub v28.4s, v28.4s, v6.4s
sub v29.4s, v29.4s, v6.4s
sub v30.4s, v30.4s, v7.4s
sub v31.4s, v31.4s, v7.4s
// Apply left shift
dup v11.4s, w12
sqshl v16.4s, v16.4s, v11.4s
sqshl v17.4s, v17.4s, v11.4s
sqshl v18.4s, v18.4s, v11.4s
sqshl v19.4s, v19.4s, v11.4s
sqshl v20.4s, v20.4s, v11.4s
sqshl v21.4s, v21.4s, v11.4s
sqshl v22.4s, v22.4s, v11.4s
sqshl v23.4s, v23.4s, v11.4s
sqshl v24.4s, v24.4s, v11.4s
sqshl v25.4s, v25.4s, v11.4s
sqshl v26.4s, v26.4s, v11.4s
sqshl v27.4s, v27.4s, v11.4s
sqshl v28.4s, v28.4s, v11.4s
sqshl v29.4s, v29.4s, v11.4s
sqshl v30.4s, v30.4s, v11.4s
sqshl v31.4s, v31.4s, v11.4s
// Apply the fixed-point part of the multiplier.
dup v10.4s, w11
sqrdmulh v16.4s, v16.4s, v10.4s
sqrdmulh v17.4s, v17.4s, v10.4s
sqrdmulh v18.4s, v18.4s, v10.4s
sqrdmulh v19.4s, v19.4s, v10.4s
sqrdmulh v20.4s, v20.4s, v10.4s
sqrdmulh v21.4s, v21.4s, v10.4s
sqrdmulh v22.4s, v22.4s, v10.4s
sqrdmulh v23.4s, v23.4s, v10.4s
sqrdmulh v24.4s, v24.4s, v10.4s
sqrdmulh v25.4s, v25.4s, v10.4s
sqrdmulh v26.4s, v26.4s, v10.4s
sqrdmulh v27.4s, v27.4s, v10.4s
sqrdmulh v28.4s, v28.4s, v10.4s
sqrdmulh v29.4s, v29.4s, v10.4s
sqrdmulh v30.4s, v30.4s, v10.4s
sqrdmulh v31.4s, v31.4s, v10.4s
// Apply right shift
dup v9.4s, w13
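// v9 holds the right shift as a negative srshl amount; the and/sshr/sqadd fixup
// below adds -1 to every negative accumulator first, turning srshl's
// round-half-up into round-to-nearest with ties away from zero
// (the usual RoundingDivideByPOT behaviour).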
and v0.16b, v9.16b, v16.16b
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
srshl v16.4s, v16.4s, v9.4s
and v1.16b, v9.16b, v17.16b
sshr v1.4s, v1.4s, #31
sqadd v17.4s, v17.4s, v1.4s
srshl v17.4s, v17.4s, v9.4s
and v2.16b, v9.16b, v18.16b
sshr v2.4s, v2.4s, #31
sqadd v18.4s, v18.4s, v2.4s
srshl v18.4s, v18.4s, v9.4s
and v3.16b, v9.16b, v19.16b
sshr v3.4s, v3.4s, #31
sqadd v19.4s, v19.4s, v3.4s
srshl v19.4s, v19.4s, v9.4s
and v0.16b, v9.16b, v20.16b
sshr v0.4s, v0.4s, #31
sqadd v20.4s, v20.4s, v0.4s
srshl v20.4s, v20.4s, v9.4s
and v1.16b, v9.16b, v21.16b
sshr v1.4s, v1.4s, #31
sqadd v21.4s, v21.4s, v1.4s
srshl v21.4s, v21.4s, v9.4s
and v2.16b, v9.16b, v22.16b
sshr v2.4s, v2.4s, #31
sqadd v22.4s, v22.4s, v2.4s
srshl v22.4s, v22.4s, v9.4s
and v3.16b, v9.16b, v23.16b
sshr v3.4s, v3.4s, #31
sqadd v23.4s, v23.4s, v3.4s
srshl v23.4s, v23.4s, v9.4s
and v0.16b, v9.16b, v24.16b
sshr v0.4s, v0.4s, #31
sqadd v24.4s, v24.4s, v0.4s
srshl v24.4s, v24.4s, v9.4s
and v1.16b, v9.16b, v25.16b
sshr v1.4s, v1.4s, #31
sqadd v25.4s, v25.4s, v1.4s
srshl v25.4s, v25.4s, v9.4s
and v2.16b, v9.16b, v26.16b
sshr v2.4s, v2.4s, #31
sqadd v26.4s, v26.4s, v2.4s
srshl v26.4s, v26.4s, v9.4s
and v3.16b, v9.16b, v27.16b
sshr v3.4s, v3.4s, #31
sqadd v27.4s, v27.4s, v3.4s
srshl v27.4s, v27.4s, v9.4s
and v0.16b, v9.16b, v28.16b
sshr v0.4s, v0.4s, #31
sqadd v28.4s, v28.4s, v0.4s
srshl v28.4s, v28.4s, v9.4s
and v1.16b, v9.16b, v29.16b
sshr v1.4s, v1.4s, #31
sqadd v29.4s, v29.4s, v1.4s
srshl v29.4s, v29.4s, v9.4s
and v2.16b, v9.16b, v30.16b
sshr v2.4s, v2.4s, #31
sqadd v30.4s, v30.4s, v2.4s
srshl v30.4s, v30.4s, v9.4s
and v3.16b, v9.16b, v31.16b
sshr v3.4s, v3.4s, #31
sqadd v31.4s, v31.4s, v3.4s
srshl v31.4s, v31.4s, v9.4s
// Add the destination zero point
dup v8.4s, w10
add v16.4s, v16.4s, v8.4s
add v17.4s, v17.4s, v8.4s
add v18.4s, v18.4s, v8.4s
add v19.4s, v19.4s, v8.4s
add v20.4s, v20.4s, v8.4s
add v21.4s, v21.4s, v8.4s
add v22.4s, v22.4s, v8.4s
add v23.4s, v23.4s, v8.4s
add v24.4s, v24.4s, v8.4s
add v25.4s, v25.4s, v8.4s
add v26.4s, v26.4s, v8.4s
add v27.4s, v27.4s, v8.4s
add v28.4s, v28.4s, v8.4s
add v29.4s, v29.4s, v8.4s
add v30.4s, v30.4s, v8.4s
add v31.4s, v31.4s, v8.4s
// Apply the act_min bound
dup v7.4s, w8
smax v16.4s, v16.4s, v7.4s
smax v17.4s, v17.4s, v7.4s
smax v18.4s, v18.4s, v7.4s
smax v19.4s, v19.4s, v7.4s
smax v20.4s, v20.4s, v7.4s
smax v21.4s, v21.4s, v7.4s
smax v22.4s, v22.4s, v7.4s
smax v23.4s, v23.4s, v7.4s
smax v24.4s, v24.4s, v7.4s
smax v25.4s, v25.4s, v7.4s
smax v26.4s, v26.4s, v7.4s
smax v27.4s, v27.4s, v7.4s
smax v28.4s, v28.4s, v7.4s
smax v29.4s, v29.4s, v7.4s
smax v30.4s, v30.4s, v7.4s
smax v31.4s, v31.4s, v7.4s
// Apply the act_max bound
dup v6.4s, w9
smin v16.4s, v16.4s, v6.4s
smin v17.4s, v17.4s, v6.4s
smin v18.4s, v18.4s, v6.4s
smin v19.4s, v19.4s, v6.4s
smin v20.4s, v20.4s, v6.4s
smin v21.4s, v21.4s, v6.4s
smin v22.4s, v22.4s, v6.4s
smin v23.4s, v23.4s, v6.4s
smin v24.4s, v24.4s, v6.4s
smin v25.4s, v25.4s, v6.4s
smin v26.4s, v26.4s, v6.4s
smin v27.4s, v27.4s, v6.4s
smin v28.4s, v28.4s, v6.4s
smin v29.4s, v29.4s, v6.4s
smin v30.4s, v30.4s, v6.4s
smin v31.4s, v31.4s, v6.4s
// int32 -> int16
sqxtn v0.4h, v16.4s
sqxtn2 v0.8h, v17.4s
sqxtn v1.4h, v18.4s
sqxtn2 v1.8h, v19.4s
sqxtn v2.4h, v20.4s
sqxtn2 v2.8h, v21.4s
sqxtn v3.4h, v22.4s
sqxtn2 v3.8h, v23.4s
sqxtn v4.4h, v24.4s
sqxtn2 v4.8h, v25.4s
sqxtn v5.4h, v26.4s
sqxtn2 v5.8h, v27.4s
sqxtn v6.4h, v28.4s
sqxtn2 v6.8h, v29.4s
sqxtn v7.4h, v30.4s
sqxtn2 v7.8h, v31.4s
// int16 -> int8
sqxtn v8.8b, v0.8h
sqxtn2 v8.16b, v1.8h
sqxtn v9.8b, v2.8h
sqxtn2 v9.16b, v3.8h
sqxtn v10.8b, v4.8h
sqxtn2 v10.16b, v5.8h
sqxtn v11.8b, v6.8h
sqxtn2 v11.16b, v7.8h
cmp w23, #8
blt Write // if rows < 8
cmp w15, #8
blt Write // if cols < 8
st1 {v8.d}[0], [x2], x24
st1 {v8.d}[1], [x2], x24
st1 {v9.d}[0], [x2], x24
st1 {v9.d}[1], [x2], x24
st1 {v10.d}[0], [x2], x24
st1 {v10.d}[1], [x2], x24
st1 {v11.d}[0], [x2], x24
st1 {v11.d}[1], [x2], x24
b Endwrite
Write:
cmp w15, #8
bge WriteCol8
cmp w15, #7
beq WriteCol7
cmp w15, #6
beq WriteCol6
cmp w15, #5
beq WriteCol5
cmp w15, #4
beq WriteCol4
cmp w15, #3
beq WriteCol3
cmp w15, #2
beq WriteCol2
cmp w15, #1
beq WriteCol1
WriteCol8:
st1 {v8.d}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v8.d}[1], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v9.d}[0], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v9.d}[1], [x2], x24
cmp w23, #4
beq Endwrite
st1 {v10.d}[0], [x2], x24
cmp w23, #5
beq Endwrite
st1 {v10.d}[1], [x2], x24
cmp w23, #6
beq Endwrite
st1 {v11.d}[0], [x2], x24
cmp w23, #7
beq Endwrite
st1 {v11.d}[1], [x2], x24
b Endwrite
WriteCol7:
mov x26, x2
st1 {v8.s}[0], [x26], #4
st1 {v8.h}[2], [x26], #2
st1 {v8.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v8.s}[2], [x26], #4
st1 {v8.h}[6], [x26], #2
st1 {v8.b}[14], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v9.s}[0], [x26], #4
st1 {v9.h}[2], [x26], #2
st1 {v9.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v9.s}[2], [x26], #4
st1 {v9.h}[6], [x26], #2
st1 {v9.b}[14], [x26], #1
add x2, x2, x24
cmp w23, #4
beq Endwrite
mov x26, x2
st1 {v10.s}[0], [x26], #4
st1 {v10.h}[2], [x26], #2
st1 {v10.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #5
beq Endwrite
mov x26, x2
st1 {v10.s}[2], [x26], #4
st1 {v10.h}[6], [x26], #2
st1 {v10.b}[14], [x26], #1
add x2, x2, x24
cmp w23, #6
beq Endwrite
mov x26, x2
st1 {v11.s}[0], [x26], #4
st1 {v11.h}[2], [x26], #2
st1 {v11.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #7
beq Endwrite
mov x26, x2
st1 {v11.s}[2], [x26], #4
st1 {v11.h}[6], [x26], #2
st1 {v11.b}[14], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol6:
mov x26, x2
st1 {v8.s}[0], [x26], #4
st1 {v8.h}[2], [x26], #2
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v8.s}[2], [x26], #4
st1 {v8.h}[6], [x26], #2
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v9.s}[0], [x26], #4
st1 {v9.h}[2], [x26], #2
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v9.s}[2], [x26], #4
st1 {v9.h}[6], [x26], #2
add x2, x2, x24
cmp w23, #4
beq Endwrite
mov x26, x2
st1 {v10.s}[0], [x26], #4
st1 {v10.h}[2], [x26], #2
add x2, x2, x24
cmp w23, #5
beq Endwrite
mov x26, x2
st1 {v10.s}[2], [x26], #4
st1 {v10.h}[6], [x26], #2
add x2, x2, x24
cmp w23, #6
beq Endwrite
mov x26, x2
st1 {v11.s}[0], [x26], #4
st1 {v11.h}[2], [x26], #2
add x2, x2, x24
cmp w23, #7
beq Endwrite
mov x26, x2
st1 {v11.s}[2], [x26], #4
st1 {v11.h}[6], [x26], #2
add x2, x2, x24
b Endwrite
WriteCol5:
mov x26, x2
st1 {v8.s}[0], [x26], #4
st1 {v8.b}[4], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v8.s}[2], [x26], #4
st1 {v8.b}[12], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v9.s}[0], [x26], #4
st1 {v9.b}[4], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v9.s}[2], [x26], #4
st1 {v9.b}[12], [x26], #1
add x2, x2, x24
cmp w23, #4
beq Endwrite
mov x26, x2
st1 {v10.s}[0], [x26], #4
st1 {v10.b}[4], [x26], #1
add x2, x2, x24
cmp w23, #5
beq Endwrite
mov x26, x2
st1 {v10.s}[2], [x26], #4
st1 {v10.b}[12], [x26], #1
add x2, x2, x24
cmp w23, #6
beq Endwrite
mov x26, x2
st1 {v11.s}[0], [x26], #4
st1 {v11.b}[4], [x26], #1
add x2, x2, x24
cmp w23, #7
beq Endwrite
mov x26, x2
st1 {v11.s}[2], [x26], #4
st1 {v11.b}[12], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol4:
st1 {v8.s}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v8.s}[2], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v9.s}[0], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v9.s}[2], [x2], x24
cmp w23, #4
beq Endwrite
st1 {v10.s}[0], [x2], x24
cmp w23, #5
beq Endwrite
st1 {v10.s}[2], [x2], x24
cmp w23, #6
beq Endwrite
st1 {v11.s}[0], [x2], x24
cmp w23, #7
beq Endwrite
st1 {v11.s}[2], [x2], x24
b Endwrite
WriteCol3:
mov x26, x2
st1 {v8.h}[0], [x26], #2
st1 {v8.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v8.h}[4], [x26], #2
st1 {v8.b}[10], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v9.h}[0], [x26], #2
st1 {v9.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v9.h}[4], [x26], #2
st1 {v9.b}[10], [x26], #1
add x2, x2, x24
cmp w23, #4
beq Endwrite
mov x26, x2
st1 {v10.h}[0], [x26], #2
st1 {v10.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #5
beq Endwrite
mov x26, x2
st1 {v10.h}[4], [x26], #2
st1 {v10.b}[10], [x26], #1
add x2, x2, x24
cmp w23, #6
beq Endwrite
mov x26, x2
st1 {v11.h}[0], [x26], #2
st1 {v11.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #7
beq Endwrite
mov x26, x2
st1 {v11.h}[4], [x26], #2
st1 {v11.b}[10], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol2:
st1 {v8.h}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v8.h}[4], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v9.h}[0], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v9.h}[4], [x2], x24
cmp w23, #4
beq Endwrite
st1 {v10.h}[0], [x2], x24
cmp w23, #5
beq Endwrite
st1 {v10.h}[4], [x2], x24
cmp w23, #6
beq Endwrite
st1 {v11.h}[0], [x2], x24
cmp w23, #7
beq Endwrite
st1 {v11.h}[4], [x2], x24
b Endwrite
WriteCol1:
st1 {v8.b}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v8.b}[8], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v9.b}[0], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v9.b}[8], [x2], x24
cmp w23, #4
beq Endwrite
st1 {v10.b}[0], [x2], x24
cmp w23, #5
beq Endwrite
st1 {v10.b}[8], [x2], x24
cmp w23, #6
beq Endwrite
st1 {v11.b}[0], [x2], x24
cmp w23, #7
beq Endwrite
st1 {v11.b}[8], [x2], x24
b Endwrite
Endwrite:
sub w16, w16, #8 // a row8 counter - 8
sub w23, w23, #8 // a row counter - 8
b L2
End2:
sub w4, w4, #8 // b col8 counter - 8
sub w15, w15, #8 // b col counter - 8
add x1, x1, x21 // b ptr + stride
add x7, x7, #32 // bias ptr + stride
add x25, x25, #8 // output + stride(8 * sizeof(int8))
mov x2, x25
b L1
End1:
sub sp, sp, #192
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ldp x25, x26, [sp], #16
ret
#endif
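The sqshl/sqrdmulh/and/sshr/sqadd/srshl sequence above is the standard gemmlowp-style requantization. A per-value C sketch under two stated assumptions: the helper names are ours, and the saturating left shift (sqshl) is written unsaturated for brevity:

#include <stdint.h>

static int32_t SatRoundDoublingHighMul(int32_t a, int32_t b) {  // ~ sqrdmulh
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;  // the only overflow case
  int64_t ab = (int64_t)a * b;
  int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
  return (int32_t)((ab + nudge) / (1ll << 31));
}

static int32_t RoundingDivideByPOT(int32_t x, int exponent) {  // ~ and/sshr/sqadd/srshl
  int32_t mask = ((int32_t)1 << exponent) - 1;
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

static int8_t RequantizeOne(int32_t acc, int left_shift, int32_t multiplier, int right_shift,
                            int32_t out_zp, int32_t act_min, int32_t act_max) {
  acc = SatRoundDoublingHighMul(acc * (1 << left_shift), multiplier);  // sqshl saturates in the asm
  acc = RoundingDivideByPOT(acc, -right_shift);  // srshl needs right_shift passed as a negative value
  acc += out_zp;
  if (acc < act_min) acc = act_min;
  if (acc > act_max) acc = act_max;
  return (int8_t)acc;  // sqxtn/sqxtn2 narrow with saturation
}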

View File

@ -228,19 +228,3 @@ void IndirectGemmFp32_Comm(float *output, const float *input, const float *weigh
return;
}
void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp) {
/* (int32_t)row8x8-major * multiplier => (int8_t)row-major */
for (int r = 0; r < plane; r++) {
for (int c = 0; c < oc; c++) {
int c8div = c / 8, c8mod = c % 8;
int src_index = c8div * plane8 * 8 + r * 8 + c8mod;
int dst_index = r * oc + c;
int32_t value = in[src_index];
value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp;
value = MSMIN(CHAR_MAX, value);
value = MSMAX(CHAR_MIN, value);
out[dst_index] = (int8_t)value;
}
}
}

View File

@ -32,8 +32,6 @@ typedef struct ConvParameter {
int stride_w_;
int dilation_h_;
int dilation_w_;
int pad_h_;
int pad_w_;
int pad_u_;
int pad_d_;
int pad_l_;
@ -51,8 +49,7 @@ typedef struct ConvParameter {
int thread_num_;
int input_unit_;
int output_unit_;
bool is_relu_;
bool is_relu6_;
ActType act_type_;
} ConvParameter;
typedef struct SlidingWindowParam {

View File

@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/activation_fp16.h"
#include "nnacl/errorcode.h"
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
int eight_block = UP_DIV(ele_num, C8NUM);
int i;
for (i = 0; i < eight_block - 1; i++) {
int index = i * C8NUM;
#ifdef ENABLE_NEON
float16x8_t relu_src = vld1q_f16(src + index);
float16x8_t zero_src = vdupq_n_f16(0);
relu_src = vmaxq_f16(relu_src, zero_src);
vst1q_f16(dst + index, relu_src);
#else
int j;
for (j = 0; j < C8NUM; j++) {
dst[index + j] = src[index + j] < 0 ? 0 : src[index + j];
}
#endif
}
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
dst[j] = src[j] < 0 ? 0 : src[j];
}
return NNACL_OK;
}
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
int eight_block = UP_DIV(ele_num, C8NUM);
int i;
for (i = 0; i < eight_block - 1; i++) {
int index = i * C8NUM;
#ifdef ENABLE_NEON
float16x8_t relu6_data = vld1q_f16(data + index);
float16x8_t zero_data = vdupq_n_f16(0);
float16x8_t six_data = vdupq_n_f16(6);
relu6_data = vmaxq_f16(relu6_data, zero_data);
relu6_data = vminq_f16(relu6_data, six_data);
vst1q_f16(dst + index, relu6_data);
#else
int j;
for (j = 0; j < C8NUM; ++j) {
dst[index + j] = data[index + j] < 0 ? 0 : data[index + j];
dst[index + j] = dst[index + j] > 6 ? 6 : dst[index + j];
}
#endif
}
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
dst[j] = data[j] < 0 ? 0 : data[j];
dst[j] = dst[j] > 6 ? 6 : dst[j];
}
return NNACL_OK;
}
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) {
for (int i = 0; i < ele_num; ++i) {
dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha);
}
return NNACL_OK;
}
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
for (int i = 0; i < ele_num; ++i) {
dst[i] = (float16_t)1.0f / (float16_t)(1.0f + exp(-src[i]));
}
return NNACL_OK;
}
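/* TanhFp16 below uses the single-exp identity
 *   tanh(x) = 1 - 2 / (exp(2x) + 1),
 * which is equivalent to (e^x - e^-x) / (e^x + e^-x) but costs one exp() per element. */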
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
for (int i = 0; i < ele_num; ++i) {
dst[i] = (float16_t)1.0f - (float16_t)2.0f / (float16_t)(exp(2 * src[i]) + 1);
}
return NNACL_OK;
}
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) {
for (int i = 0; i < ele_num; ++i) {
float16_t in = src[i];
float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6);
dst[i] = in * relu6 / (float16_t)6.0f;
}
return NNACL_OK;
}

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/quantization/fixed_point.h"
typedef struct ActivationParameter {
OpParameter op_parameter_;
int type_;
float alpha_;
} ActivationParameter;
#ifdef __cplusplus
extern "C" {
#endif
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num);
int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num);
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha);
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num);
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num);
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_ACTIVATION_FP16_H_

View File

@ -74,33 +74,48 @@ int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, i
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = in0 * in1;
}
for (int i = 0; i < C8NUM; ++i) {
output[i] = in0_opt * input1[i];
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
output[index] = in0 * in1;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt * input1[index];
}
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vmulq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = input0[i] * in1_opt;
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] * in1_opt;
}
}
return NNACL_OK;
@ -113,7 +128,6 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
#ifdef ENABLE_NEON
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
@ -143,39 +157,58 @@ int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else
float16_t res;
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
res = in0 * in1;
output[i] = res > 0 ? res : 0;
}
float16_t res;
for (int i = 0; i < C8NUM; ++i) {
res = in0_opt * input1[i];
output[i] = res > 0 ? res : 0;
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
float16_t res = in0 * in1;
output[index] = res > 0 ? res : 0;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = in0_opt * input1[index];
output[index] = res > 0 ? res : 0;
}
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else
float16_t res;
for (int i = 0; i < C8NUM; ++i) {
res = input0[i] * in1_opt;
output[i] = res > 0 ? res : 0;
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = input0[index] * in1_opt;
output[index] = res > 0 ? res : 0;
}
}
return NNACL_OK;
@ -216,37 +249,52 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMIN(MSMAX(in0 * in1, 0), 6);
}
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6);
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
output[index] = MSMIN(MSMAX(in0 * in1, 0), 6);
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6);
}
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vmulq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6);
}
}
return NNACL_OK;
@ -255,7 +303,6 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
@ -280,34 +327,50 @@ int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, i
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = in0 + in1;
}
for (int i = 0; i < C8NUM; ++i) {
output[i] = in0_opt + input1[i];
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
output[index] = in0 + in1;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt + input1[index];
}
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vaddq_f16(vin0, vin1);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = input0[i] + in1_opt;
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] + in1_opt;
}
}
return NNACL_OK;
}
@ -345,37 +408,54 @@ int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMAX(in0 + in1, 0);
}
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(in0_opt + input1[i], 0);
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
float16_t res = in0 + in1;
output[index] = res > 0 ? res : 0;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = in0_opt + input1[index];
output[index] = res > 0 ? res : 0;
}
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMAX(input0[i] + in1_opt, 0);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = input0[index] + in1_opt;
output[index] = res > 0 ? res : 0;
}
}
return NNACL_OK;
}
@ -415,39 +495,54 @@ int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0);
float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
float16x8_t vin0 = vin0_opt;
float16x8_t vin1 = vld1q_f16(input1);
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i];
output[i] = MSMIN(MSMAX(in0 + in1, 0), 6);
}
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(in0_opt + input1[i], 0), 6);
}
#endif
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(in0_opt + input1[index], 0), 6);
}
} else {
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
float16x8_t vin0 = vld1q_f16(input0);
float16x8_t vin1 = vin1_opt;
float16x8_t vout = vaddq_f16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output, vout);
#else
for (int i = 0; i < C8NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] + in1_opt, 0), 6);
}
#endif
input0 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] + in1_opt, 0), 6);
}
}
for (int index = 0; index < block_mod; ++index) {
float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index];
float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index];
output[index] = MSMIN(MSMAX(in0 + in1, 0), 6);
}
return NNACL_OK;
}
@ -479,11 +574,11 @@ int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, i
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
#ifdef ENABLE_NEON
@ -542,11 +637,11 @@ int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
@ -609,11 +704,11 @@ int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif
@ -680,11 +775,11 @@ int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, i
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
if (param->in_elements_num1_ == 1) {
@ -765,12 +860,11 @@ int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu
ArithmeticParameter *param) {
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
#endif
for (int index = 0; index < block_c8; index += C8NUM) {
@ -855,11 +949,11 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
#ifdef ENABLE_NEON
float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]};
float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]};
float16_t in0_opt = input0[0];
float16_t in1_opt = input1[0];
float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0};
float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6};
#endif

View File

@ -17,8 +17,8 @@
#include "nnacl/fp16/batchnorm_fp16.h"
#include <math.h>
void BatchNormFp16(const void *input, const void *mean, const void *variance,
BatchNormParameter *param, int task_id, void *output) {
void BatchNormFp16(const float16_t *input, const void *mean, const void *variance,
BatchNormParameter *param, int task_id, float16_t *output) {
int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
int completed_units = task_id * units_per_thread;
int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
@ -27,8 +27,9 @@ void BatchNormFp16(const void *input, const void *mean, const void *variance,
for (int i = 0; i < cur_unit; i++) {
for (int c = 0; c < param->channel_; c++) {
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
((float16_t *)output)[cur_offset + c] =
(((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
if (variance_sqrt != 0) {
output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
}
}
cur_offset += param->channel_;
}
@ -44,8 +45,12 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset
for (int i = 0; i < cur_unit; i++) {
for (int c = 0; c < param->channel_; c++) {
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
float16_t norm_val = (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
((float16_t *)output)[cur_offset + c] = norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
if (variance_sqrt != 0) {
float16_t norm_val =
(((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
((float16_t *)output)[cur_offset + c] =
norm_val * ((const float16_t *)scale)[c] + ((const float16_t *)offset)[c];
}
}
cur_offset += param->channel_;
}
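For reference, the value computed here is the standard per-channel normalization y = (x - mean_c) / sqrt(var_c + epsilon), with the fused variant applying y = scale_c * norm + offset_c afterwards. The new variance_sqrt != 0 guard skips a channel whose denominator collapses to zero, leaving those outputs untouched rather than dividing by zero. A scalar sketch of the fused step with the same guard (plain float for clarity; illustrative only):

#include <math.h>

static void fused_bn_channel(const float *x, float mean, float var, float eps,
                             float scale, float offset, float *y, int n) {
  float denom = sqrtf(var + eps);
  if (denom == 0.0f) {
    return;  /* mirror the patch: skip the write instead of dividing by zero */
  }
  for (int i = 0; i < n; ++i) {
    y[i] = scale * ((x[i] - mean) / denom) + offset;
  }
}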

View File

@ -25,8 +25,8 @@
extern "C" {
#endif
void BatchNormFp16(const void *input, const void *mean, const void *variance, BatchNormParameter *param, int task_id,
void *output);
void BatchNormFp16(const float16_t *input, const void *mean, const void *variance, BatchNormParameter *param,
int task_id, float16_t *output);
void FusedBatchNormFp16(const void *input, const void *scale, const void *offset, const void *mean,
const void *variance, BatchNormParameter *param, int task_id, void *output);

View File

@ -1,61 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/common_func.h"
void ReluFp16(float16_t *data, float16_t *dst, int ele_num) {
int eight_block = UP_DIV(ele_num, C8NUM);
for (int i = 0; i < eight_block - 1; i++) {
int index = i * C8NUM;
#ifdef ENABLE_NEON
float16x8_t relu_data = vld1q_f16(data + index);
float16x8_t zero_data = vdupq_n_f16(0);
relu_data = vmaxq_f16(relu_data, zero_data);
vst1q_f16(dst + index, relu_data);
#else
data[index] = data[index] < 0 ? 0 : data[index];
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
#endif
}
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
data[j] = data[j] < 0 ? 0 : data[j];
}
}
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num) {
int eight_block = UP_DIV(ele_num, C8NUM);
for (int i = 0; i < eight_block - 1; i++) {
int index = i * C8NUM;
#ifdef ENABLE_NEON
float16x8_t relu6_data = vld1q_f16(data + index);
float16x8_t zero_data = vdupq_n_f16(0);
float16x8_t six_data = vdupq_n_f16(6);
relu6_data = vmaxq_f16(relu6_data, zero_data);
relu6_data = vminq_f16(relu6_data, six_data);
vst1q_f16(dst + index, relu6_data);
#else
for (int j = 0; j < C8NUM; ++j) {
data[index + j] = data[index + j] < 0 ? 0 : data[index + j];
data[index + j] = data[index + j] > 6 ? 6 : data[index + j];
}
#endif
}
for (int j = (eight_block - 1) * C8NUM; j < ele_num; ++j) {
data[j] = data[j] < 0 ? 0 : data[j];
data[j] = data[j] > 6 ? 6 : data[j];
}
}

View File

@ -1,51 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_COMMON_FUNC_H_
#define MINDSPORE_LITE_NNACL_FP16_COMMON_FUNC_H_
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
#ifdef ENABLE_ARM64
void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu,
size_t relu6);
void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step,
size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step,
size_t relu, size_t relu6);
void DeconvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width,
size_t in_kh_step, size_t in_kw_step, size_t kernel_w);
void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
#endif
void ReluFp16(float16_t *data, float16_t *dst, int ele_num);
void Relu6Fp16(float16_t *data, float16_t *dst, int ele_num);
#ifdef __cplusplus
}
#endif
#endif /* MINDSPORE_LITE_NNACL_FP32_COMMON_FUNC_H_ */

View File

@ -15,8 +15,62 @@
*/
#include "nnacl/fp16/conv_depthwise_fp16.h"
#include <arm_neon.h>
#include "nnacl/fp16/common_func.h"
#include <string.h>
#include "nnacl/fp16/activation_fp16.h"
void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, int task_id) {
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
int h_start = h_step * task_id;
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
for (int b = 0; b < conv_param->output_batch_; b++) {
const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
for (int oh = h_start; oh < h_end; oh++) {
float16_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
for (int ow = 0; ow < conv_param->output_w_; ow++) {
memcpy(dst_data + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(float16_t));
}
for (int kh = start_kh; kh < end_kh; kh++) {
int ih = ih_origin + conv_param->dilation_h_ * kh;
const float16_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
const float16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
int out_w_start = MSMAX(
0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
conv_param->stride_w_);
float16_t *dst_w = dst_data + out_w_start * conv_param->output_channel_;
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;
const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_;
int num_pixels = out_w_end - out_w_start;
ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
weight_kh += conv_param->output_channel_;
}
}
if (relu) {
ReluFp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
}
if (relu6) {
Relu6Fp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
}
}
}
}
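ConvDwFp16 seeds each output row with the bias via memcpy, clips the kernel taps to the valid input rows, and hands every surviving filter row to ConvDwFp16Row (ARM64 assembly, declared in the header further down). A scalar reference for what that row kernel is assumed to compute, again in plain float (a sketch, not the shipped implementation):

/* output_ptr: num_pixels output pixels of `channels` values each.
   input_ptr advances by input_step (= stride_w * input_channel) per pixel. */
static void conv_dw_row_ref(float *output_ptr, const float *input_ptr,
                            const float *filter_ptr, int num_pixels,
                            int channels, int input_step) {
  for (int p = 0; p < num_pixels; ++p) {
    for (int c = 0; c < channels; ++c) {
      output_ptr[c] += input_ptr[c] * filter_ptr[c];  /* depthwise: per-channel MAC */
    }
    output_ptr += channels;
    input_ptr += input_step;
  }
}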
/*conv depthwise fp16 begin*/
void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
@ -53,16 +107,18 @@ void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float1
void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param,
const SlidingWindowParam *sliding) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float16_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float16_t *src_h = src + ih * sliding->in_h_step_;
float16_t *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float16_t *src_w = src_h + iw * sliding->block_channel_;
@ -72,11 +128,10 @@ void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *
#ifdef ENABLE_ARM64
ConvDwFp16Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t),
conv_param->kernel_w_ * C8NUM * sizeof(float16_t), conv_param->is_relu_, conv_param->is_relu6_);
conv_param->kernel_w_ * C8NUM * sizeof(float16_t), relu, relu6);
#else
DepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM,
conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM, relu, relu6);
#endif
dst_kernel += sliding->block_channel_;
} // width loop
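The start_kh / end_kh (and start_kw / end_kw) expressions above are the recurring nnacl idiom for clipping kernel taps at a border: tap k is valid iff 0 <= pos + k * dilation < limit, and the UP_DIV ceiling division turns that into a closed-form range instead of a per-tap branch. A self-contained sketch (macro definitions copied from their nnacl usage; the helper name is illustrative):

#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define MSMAX(a, b) ((a) > (b) ? (a) : (b))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

/* First and one-past-last tap k with pos + k * dilation inside [0, limit). */
static void valid_tap_range(int pos, int limit, int kernel, int dilation,
                            int *start, int *end) {
  *start = MSMAX(0, UP_DIV(-pos, dilation));            /* pos + k*dilation >= 0    */
  *end = MSMIN(kernel, UP_DIV(limit - pos, dilation));  /* pos + k*dilation < limit */
}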
@ -139,6 +194,8 @@ void DepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *
void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
const float16_t *src = input_data;
float16_t *dst = output_data;
for (int b = 0; b < conv_param->output_batch_; b++) {
@ -157,8 +214,8 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo
conv_param->output_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
float16_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
@ -166,12 +223,12 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t),
sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t),
sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t),
sliding->in_kw_step_ * sizeof(float16_t), conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kw_step_ * sizeof(float16_t), relu, relu6);
#else
DepthwiseCenterFp16(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_,
sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_,
sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kh_step_, sliding->in_kw_step_, relu, relu6);
#endif
}
} // output C8 loop
@ -210,14 +267,14 @@ void DeconvDepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float
const SlidingWindowParam *sliding) {
const float16_t *src_h = src + top * sliding->out_h_step_;
for (int ih = top; ih < bottom; ih++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
float16_t *dst_h = dst + oh * sliding->in_h_step_;
const float16_t *src_kernel = src_h + left * sliding->block_channel_;
for (int iw = left; iw < right; iw++) {
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
float16_t *dst_w = dst_h + ow * sliding->block_channel_;
@ -282,12 +339,14 @@ void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float
void DeconvDepthwisePostFuncFp16(float16_t *dst, const float16_t *bias, int block_channel,
const ConvParameter *conv_param) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float16_t *dst_k = dst;
for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
for (int c = 0; c < C8NUM; c++) {
dst_k[c] += bias[c];
dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
}
dst_k += block_channel;
}
@ -315,8 +374,8 @@ void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const f
conv_param->input_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
float16_t *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const float16_t *in_t =
src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

View File

@ -23,6 +23,26 @@
#ifdef __cplusplus
extern "C" {
#endif
#ifdef ENABLE_ARM64
void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels,
size_t input_channel, size_t input_step);
void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu,
size_t relu6);
void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step,
size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step,
size_t relu, size_t relu6);
void DeconvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width,
size_t in_kh_step, size_t in_kw_step, size_t kernel_w);
void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
#endif
void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, int task_id);
void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id);

View File

@ -173,16 +173,18 @@ void SWBorderPixel(float16_t *dst, const float16_t *src, const float16_t *weight
void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float16_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float16_t *src_h = src + ih * sliding->in_h_step_;
float16_t *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float16_t *src_w = src_h + iw * sliding->ic4_channel_;
@ -192,7 +194,7 @@ void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight,
SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_,
sliding->ic4_channel_, conv_param->is_relu_, conv_param->is_relu6_);
sliding->ic4_channel_, relu, relu6);
dst_kernel += sliding->block_channel_;
} // width loop
@ -273,6 +275,8 @@ void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight,
void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data,
float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param,
SlidingWindowParam *slidingWindow_param) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
int oc4_res = conv_param->output_channel_ % C4NUM;
const float16_t *src = input_data;
float16_t *dst;
@ -299,8 +303,8 @@ void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, con
if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float16_t *in_t =
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
float16_t *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
@ -310,7 +314,7 @@ void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, con
conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_,
slidingWindow_param->ic4_channel_, slidingWindow_param->in_sh_step_,
slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
slidingWindow_param->in_kw_step_, relu, relu6);
}
} // output C4 loop
src += slidingWindow_param->in_step_;
@ -330,8 +334,8 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
bool relu = conv_param->is_relu_;
bool relu6 = conv_param->is_relu6_;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
int thread_count = conv_param->thread_num_;
const int tile_n = 16;
int output_count = out_h * out_w;
@ -365,9 +369,10 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
out_channel * sizeof(float16_t), 0, 0, relu, relu6);
} else {
// res part
IndirectGemmFp16_16x8(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
float16_t *tmp_out_ptr = tmp_out_block + task_id * tile_n * out_channel;
IndirectGemmFp16_16x8(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel,
out_channel * sizeof(float16_t), 0, 0, relu, relu6);
memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float16_t));
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float16_t));
}
}
}
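The tmp_out_ptr change above is a thread-safety fix: previously every task wrote its residual-tile GEMM result to the same tmp_out_block base before the memcpy, so concurrent tasks could clobber each other's partial output. Offsetting by task_id * tile_n * out_channel gives each task a disjoint slice, which in turn means the scratch buffer must be sized for all threads. A sizing sketch (plain float and illustrative names; tile_n is 16 as in ConvFp16):

#include <stdlib.h>

/* One tile_n * out_channel slice per task; task i writes at base + i * slice. */
static float *alloc_conv_scratch(int thread_count, int tile_n, int out_channel) {
  size_t slice = (size_t)tile_n * (size_t)out_channel;
  return (float *)malloc((size_t)thread_count * slice * sizeof(float));
}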
@ -395,7 +400,6 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = batch * oc8 * C8NUM * out_w_block * out_h_block * output_unit * output_unit;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * tile_num;

View File

@ -73,8 +73,8 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
@ -112,7 +112,7 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
} /*ih*/
} /*oc8*/
PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
return NNACL_OK;
}

View File

@ -21,14 +21,14 @@
void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param) {
/* support nhwc */
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
if (src_h < 0 || src_h >= conv_param->input_h_) {
continue;
}
const float16_t *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_;
float16_t *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_;
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
if (src_w < 0 || src_w >= conv_param->input_w_) {
continue;
}
@ -46,44 +46,40 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int channel_block = UP_DIV(in_channel, 4);
int kernel_plane = kernel_h * kernel_w;
int ic4 = UP_DIV(in_channel, 4);
memset(packed_input, 0, kernel_w * kernel_h * ic4 * C4NUM * 16 * sizeof(float16_t));
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
for (int j = 0; j < kernel_h; j++) {
int input_y = input_h + j * dilation_h;
if (input_y < 0 || input_y >= in_h) {
continue;
}
int input_y_stride = input_y * in_w * channel_block * C4NUM;
for (int n = 0; n < kernel_w; n++) {
int input_x = input_w + n * dilation_w;
if (input_x < 0 || input_x >= in_w) {
continue;
}
int input_x_stride = input_y_stride + input_x * channel_block * C4NUM;
int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * channel_block + i * C4NUM;
for (int m = 0; m < channel_block; m++) {
int input_stride = input_h * in_w * ic4 * C4NUM + input_w * ic4 * C4NUM;
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * ic4 * C4NUM + input_stride;
for (int n = kw_s; n < kw_e; n++) {
int input_x_stride = input_y_stride + n * dilation_w * ic4 * C4NUM;
int input_plane_offset = (j * kernel_w + n) * 16 * C4NUM * ic4 + i * C4NUM;
for (int m = 0; m < ic4; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
int channel_block_offset = input_plane_offset + m * 16 * C4NUM;
#ifdef ENABLE_ARM64
vst1_f16(packed_input + channel_block_offset, vld1_f16(input_data + channel_block_stride));
#else
(packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0];
(packed_input + channel_block_offset)[1] = (input_data + channel_block_stride)[1];
(packed_input + channel_block_offset)[2] = (input_data + channel_block_stride)[2];
(packed_input + channel_block_offset)[3] = (input_data + channel_block_stride)[3];
for (int l = 0; l < C4NUM; ++l) {
(packed_input + channel_block_offset)[l] = (input_data + channel_block_stride)[l];
}
#endif
} // channel_block loop
} // kernel_w loop
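Two things changed in Im2ColPackUnitFp16: the per-tap bounds checks became closed-form kh_s / kh_e / kw_s / kw_e ranges (the same UP_DIV idiom as the depthwise border code), and the hand-unrolled four-element copy became a C4NUM loop with a vst1_f16 fast path on ARM64. The destination offsets imply a packed layout of [kernel_plane][ic4][tile of 16 pixels][C4NUM lanes]; a hedged index helper spelling that out (assumes the tile width of 16 and C4NUM = 4 used above):

/* Flat offset of lane l of channel-block m, tile pixel i, kernel plane p. */
static inline int packed_im2col_index(int p, int m, int i, int l, int ic4) {
  return p * 16 * 4 * ic4  /* (j * kernel_w + n) * 16 * C4NUM * ic4 */
         + m * 16 * 4      /* m * 16 * C4NUM                        */
         + i * 4           /* i * C4NUM                             */
         + l;
}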
@ -221,6 +217,19 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
}
}
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
for (int n = 0; n < batch; n++) {
for (int c = 0; c < channel; c++) {
for (int hw = 0; hw < plane; hw++) {
int nhwc_index = n * channel * plane + hw * channel + c;
int nchw_index = n * channel * plane + c * plane + hw;
((float16_t *)(dst))[nhwc_index] = ((const float16_t *)(src))[nchw_index];
}
}
}
return;
}
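The new PackNCHWToNHWCFp16 is a plain layout transpose: within one batch, element (c, hw) moves from index c * plane + hw to hw * channel + c. A tiny usage sketch (hypothetical values; assumes the declaration from the pack header below):

static void demo_nchw_to_nhwc(void) {
  float16_t nchw[6] = {1, 2, 3, 4, 5, 6};  /* channel 0: {1,2,3}, channel 1: {4,5,6} */
  float16_t nhwc[6];
  PackNCHWToNHWCFp16(nchw, nhwc, /*batch=*/1, /*plane=*/3, /*channel=*/2);
  /* nhwc is now {1, 4, 2, 5, 3, 6}: each pixel holds its channels contiguously */
}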
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) {
int ic4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = ic4 * C4NUM * plane;

View File

@ -41,6 +41,8 @@ void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel);

View File

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/softmax_fp16.h"
#include <math.h>
#include <float.h>
// output = exp(input - max(input)) / reduce_sum(exp(input - max(input)), axis)
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter) {
int32_t axis = parameter->axis_;
int n_dim = parameter->n_dim_;
int ele_size = parameter->element_size_;
int *input_shape = parameter->input_shape_;
float16_t max_data = input_ptr[0];
for (int i = 0; i < ele_size; i++) {
max_data = max_data > input_ptr[i] ? max_data : input_ptr[i];
}
for (int i = 0; i < ele_size; i++) {
output_ptr[i] = exp(input_ptr[i] - max_data);
}
int inner_size = 1, outter_size = 1;
for (int i = 0; i < axis; i++) {
outter_size *= input_shape[i];
}
for (int i = axis + 1; i < n_dim; i++) {
inner_size *= input_shape[i];
}
for (int i = 0; i < outter_size; i++) {
int outter_offset = i * input_shape[axis] * inner_size;
int sum_outter_offset = i * inner_size;
for (int k = 0; k < inner_size; k++) {
int inner_offset = outter_offset + k;
for (int j = 0; j < input_shape[axis]; j++) {
int axis_offset = inner_offset + j * inner_size;
sum_data[k + sum_outter_offset] += output_ptr[axis_offset];
}
}
}
for (int i = 0; i < outter_size; i++) {
int outter_offset = i * input_shape[axis] * inner_size;
int sum_outter_offset = i * inner_size;
for (int j = 0; j < input_shape[axis]; j++) {
int axis_offset = outter_offset + j * inner_size;
for (int k = 0; k < inner_size; k++) {
int inner_offset = axis_offset + k;
output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset];
}
}
}
}
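This is the numerically stable softmax, output_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)); subtracting the global max keeps exp() from overflowing in float16. One caller obligation worth noting: the reduction uses +=, so sum_data (outter_size * inner_size elements) must arrive zeroed. A hedged usage sketch for a 2-D input normalized along its last axis (assumes SoftmaxParameter carries an inline input_shape_ array, as elsewhere in nnacl):

#include <string.h>

static void softmax_rows(const float16_t *in, float16_t *out,
                         float16_t *sum_buf, int rows, int cols) {
  SoftmaxParameter p;
  p.axis_ = 1;  /* reduce along the last dimension */
  p.n_dim_ = 2;
  p.element_size_ = rows * cols;
  p.input_shape_[0] = rows;
  p.input_shape_[1] = cols;
  memset(sum_buf, 0, (size_t)rows * sizeof(float16_t));  /* += needs zeroed sums */
  SoftmaxFp16(in, out, sum_buf, &p);
}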

View File

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_
#include "nnacl/op_base.h"
#include "nnacl/softmax_parameter.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif
void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, SoftmaxParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_SOFTMAX_FP16_H_

View File

@ -230,8 +230,8 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
int input_height = conv_param->input_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_l_;
int pad_h = conv_param->pad_u_;
int ic8 = UP_DIV(input_channel, C8NUM);
if (out_w_block == 0) {
return;
@ -576,8 +576,8 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
int ic8 = UP_DIV(in_channel, C8NUM);
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int input_h = conv_param->input_h_;
int input_w = conv_param->input_w_;
if (out_w_block_num == 0) {
@ -607,7 +607,7 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * ic8 * C8NUM;
int dst_x_offset = dst_y_offset + j * C8NUM;
float16_t *src_addr = input_data + src_x_offset;
const float16_t *src_addr = input_data + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
#ifdef ENABLE_NEON
vst1q_f16(dst_addr, vld1q_f16(src_addr));

View File

@ -43,15 +43,33 @@ int LRelu(const float *src, int length, float *dst, float alpha) {
}
int Sigmoid(const float *src, int length, float *dst) {
const float upper_bound = 16.619047164916992188f;
const float lower_bound = -9.0f;
for (int i = 0; i < length; ++i) {
dst[i] = 1.0f / (1.0f + exp(-src[i]));
float input_val = src[i];
float result;
if (input_val > upper_bound) {
result = 1.0f;
} else if (input_val < lower_bound) {
result = exp(input_val);
} else {
result = 1.0f / (1.0f + exp(-input_val));
}
dst[i] = result;
}
return NNACL_OK;
}
int Tanh(const float *src, int length, float *dst) {
for (int i = 0; i < length; ++i) {
dst[i] = 1.0f - 2.0f / (exp(2 * src[i]) + 1);
float tmp_in = src[i];
if (tmp_in > 5.0) {
dst[i] = 1.0f;
} else if (tmp_in < -5.0) {
dst[i] = -1.0f;
} else {
dst[i] = 1.0f - 2.0f / (exp(2 * tmp_in) + 1);
}
}
return NNACL_OK;
}
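These clamps trade a bounded approximation error for skipping exp() at saturated inputs: above 16.619..., the exact sigmoid sits within about one fp32 ulp of 1; below -9, sigmoid(x) = exp(x) / (1 + exp(x)), so returning exp(x) carries a relative error under exp(-9), roughly 1.2e-4; and beyond +/-5, tanh is within roughly 9.1e-5 of +/-1. A standalone check of those bounds (illustrative):

#include <math.h>
#include <stdio.h>

int main(void) {
  float hi = 16.619047164916992188f;
  printf("1 - sigmoid(hi) = %.3g\n", 1.0 - 1.0 / (1.0 + exp(-(double)hi)));
  double lo = -9.0;
  printf("sigmoid(lo) = %.6g vs exp(lo) = %.6g\n", exp(lo) / (1.0 + exp(lo)), exp(lo));
  printf("1 - tanh(5)  = %.3g\n", 1.0 - tanh(5.0));
  return 0;
}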

View File

@ -20,53 +20,453 @@
#define ACCURACY_DATA 0.00000001
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
#endif
if (param->in_elements_num0_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[0] * input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmulq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = in0_opt * input1[i];
}
#endif
input1 += C4NUM;
output += C4NUM;
}
} else if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] * input1[0];
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt * input1[index];
}
} else {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] * input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vmulq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = input0[i] * in1_opt;
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] * in1_opt;
}
}
return NNACL_OK;
}
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
float32x4_t zeros = {0, 0, 0, 0};
#endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1), zeros);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMAX(in0_opt * input1[i], 0);
}
#endif
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(in0_opt * input1[index], 0);
}
} else {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1), zeros);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMAX(input0[i] * in1_opt, 0);
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(input0[index] * in1_opt, 0);
}
}
return NNACL_OK;
}
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
float32x4_t zeros = {0, 0, 0, 0};
float32x4_t bounds = {6, 6, 6, 6};
#endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6);
}
#endif
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6);
}
} else {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6);
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6);
}
}
return NNACL_OK;
}
int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
#endif
if (param->in_elements_num0_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[0] - input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vsubq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = in0_opt - input1[i];
}
#endif
input1 += C4NUM;
output += C4NUM;
}
} else if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] - input1[0];
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt - input1[index];
}
} else {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] - input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vsubq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = input0[i] - in1_opt;
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] - in1_opt;
}
}
return NNACL_OK;
}
int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
float32x4_t zeros = {0, 0, 0, 0};
#endif
if (param->in_elements_num0_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[0] + input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmaxq_f32(vsubq_f32(vin0, vin1), zeros);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMAX(in0_opt - input1[i], 0);
}
#endif
input1 += C4NUM;
output += C4NUM;
}
} else if (param->in_elements_num1_ == 1) {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] + input1[0];
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(in0_opt - input1[index], 0);
}
} else {
for (int i = 0; i < element_size; ++i) {
output[i] = input0[i] + input1[i];
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vmaxq_f32(vsubq_f32(vin0, vin1), zeros);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMAX(input0[i] - in1_opt, 0);
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(input0[index] - in1_opt, 0);
}
}
return NNACL_OK;
}
int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
float32x4_t zeros = {0, 0, 0, 0};
float32x4_t bounds = {6, 6, 6, 6};
#endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMIN(MSMAX(in0_opt - input1[i], 0), 6);
}
#endif
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(in0_opt - input1[index], 0), 6);
}
} else {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] - in1_opt, 0), 6);
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] - in1_opt, 0), 6);
}
}
return NNACL_OK;
}
int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
#endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vaddq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = in0_opt + input1[i];
}
#endif
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = in0_opt + input1[index];
}
} else {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vaddq_f32(vin0, vin1);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = input0[i] + in1_opt;
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] + in1_opt;
}
}
return NNACL_OK;
}
int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
float32x4_t zeros = {0, 0, 0, 0};
#endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1), zeros);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMAX(in0_opt + input1[i], 0);
}
#endif
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(in0_opt + input1[index], 0);
}
} else {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1), zeros);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMAX(input0[i] + in1_opt, 0);
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMAX(input0[index] + in1_opt, 0);
}
}
return NNACL_OK;
}
int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
float in0_opt = input0[0];
float in1_opt = input1[0];
#ifdef ENABLE_NEON
float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]};
float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]};
float32x4_t zeros = {0, 0, 0, 0};
float32x4_t bounds = {6, 6, 6, 6};
#endif
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vin0_opt;
float32x4_t vin1 = vld1q_f32(input1);
float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMIN(MSMAX(in0_opt + input1[i], 0), 6);
}
#endif
input1 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(in0_opt + input1[index], 0), 6);
}
} else {
for (int index = 0; index < block_c4; index += C4NUM) {
#ifdef ENABLE_NEON
float32x4_t vin0 = vld1q_f32(input0);
float32x4_t vin1 = vin1_opt;
float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds);
vst1q_f32(output, vout);
#else
for (int i = 0; i < C4NUM; ++i) {
output[i] = MSMIN(MSMAX(input0[i] + in1_opt, 0), 6);
}
#endif
input0 += C4NUM;
output += C4NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] + in1_opt, 0), 6);
}
}
return NNACL_OK;
}
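// Editor's note: the Relu/Relu6 variants differ from the plain kernels only in
// the fused clamp (MSMAX against 0, then MSMIN against 6). A sketch of
// selecting the variant from the ActType enum (values assumed from
// nnacl/op_base.h):
typedef int (*ElementOptFunc)(float *, float *, float *, int, ArithmeticParameter *);

static ElementOptFunc PickOptAdd(ActType act) {
  switch (act) {
    case ActType_Relu:  return ElementOptAddRelu;   // clamp low at 0
    case ActType_Relu6: return ElementOptAddRelu6;  // clamp into [0, 6]
    default:            return ElementOptAdd;       // no activation
  }
}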

View File

@ -27,8 +27,14 @@
extern "C" {
#endif
int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
int ElementMul(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
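// Editor's note: the ElementOpt* declarations above share the
// (input0, input1, output, element_size, param) signature; the
// in_elements_num0_/in_elements_num1_ fields pick which side broadcasts.
// Usage sketch (tensor minus broadcast scalar, values illustrative):
static void ExampleOptSub(void) {
  float data[64], out[64], scalar = 1.5f;  // fill data before use
  ArithmeticParameter param = {0};
  param.in_elements_num0_ = 64;  // input0 is the full tensor
  param.in_elements_num1_ = 1;   // input1 is the broadcast scalar
  ElementOptSub(data, &scalar, out, 64, &param);
}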

View File

@ -15,7 +15,6 @@
*/
#include "nnacl/fp32/batchnorm.h"
#include "nnacl/fp16/batchnorm_fp16.h"
#include <math.h>
#include "nnacl/batchnorm_parameter.h"
#include "nnacl/op_base.h"

View File

@ -18,6 +18,7 @@
#include <string.h>
#include "nnacl/fp32/common_func.h"
#include "nnacl/winograd_transform.h"
#include "nnacl/fp32/matmul.h"
void SWBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic4, bool is_relu, bool is_relu6) {
@ -57,16 +58,18 @@ void SWBorderPixel(float *dst, const float *src, const float *weight, const floa
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
int ic4 = sliding->ic4_channel_ / C4NUM;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float *src_h = src + ih * sliding->in_h_step_;
float *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float *src_w = src_h + iw * sliding->ic4_channel_;
@ -75,8 +78,8 @@ void SWBorder(float *dst, const float *src, const float *weight, const float *bi
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_;
SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4,
conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4, relu,
relu6);
dst_kernel += sliding->block_channel_;
} // width loop
@ -144,6 +147,8 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param) {
int ic4 = slidingWindow_param->ic4_channel_ / C4NUM;
int oc4_res = conv_param->output_channel_ % C4NUM;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
const float *src = input_data;
float *dst = NULL;
if (oc4_res == 0) {
@ -169,28 +174,26 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float *in_t =
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
float *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
slidingWindow_param->left_ * slidingWindow_param->block_channel_;
#ifdef ENABLE_ARM64
ConvSwFp32Center(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
conv_param->kernel_w_, slidingWindow_param->out_h_step_ * sizeof(float),
slidingWindow_param->block_channel_ * sizeof(float), ic4,
slidingWindow_param->in_sh_step_ * sizeof(float),
slidingWindow_param->in_sw_step_ * sizeof(float),
slidingWindow_param->in_kh_step_ * sizeof(float),
slidingWindow_param->in_kw_step_ * sizeof(float),
conv_param->is_relu_, conv_param->is_relu6_);
ConvSwFp32Center(
out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
slidingWindow_param->out_h_step_ * sizeof(float), slidingWindow_param->block_channel_ * sizeof(float), ic4,
slidingWindow_param->in_sh_step_ * sizeof(float), slidingWindow_param->in_sw_step_ * sizeof(float),
slidingWindow_param->in_kh_step_ * sizeof(float), slidingWindow_param->in_kw_step_ * sizeof(float), relu,
relu6);
#else
SWCenter(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
slidingWindow_param->in_sh_step_, slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
slidingWindow_param->in_kw_step_, relu, relu6);
#endif
}
} // output C4 loop
@ -219,6 +222,8 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
int kernel_plane = kernel_h * kernel_w;
int unit_size = kernel_plane * ic4 * C4NUM;
int packed_input_size = output_tile_count * TILE_NUM * unit_size;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
// we accumulate 4 channels at a time for input blocks
int conv_depth = kernel_h * kernel_w;
@ -240,23 +245,18 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
if (real_cal_num == TILE_NUM) {
float *gemm_output = output_data + out_offset;
gemm_func(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0,
conv_param->is_relu_, conv_param->is_relu6_);
relu, relu6);
} else {
// res part
gemm_func(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0,
0, conv_param->is_relu_, conv_param->is_relu6_);
memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float));
float *tmp_out_ptr = tmp_out_block + task_id * TILE_NUM * out_channel;
gemm_func(tmp_out_ptr, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0,
relu, relu6);
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel * sizeof(float));
}
}
}
}
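// Editor's note: the res-part change above gives each task its own slice of
// tmp_out_block (offset task_id * TILE_NUM * out_channel), removing the race
// on a shared tail buffer. Hypothetical caller-side sizing this implies
// (needs <stdlib.h>):
static float *AllocResTileScratch(const ConvParameter *conv_param, int out_channel) {
  // one TILE_NUM x out_channel slab per task, indexed by task_id
  return (float *)malloc((size_t)conv_param->thread_num_ * TILE_NUM * out_channel * sizeof(float));
}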
// fp32 conv1x1 strassen matmul
int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
StrassenMatMulParameter matmul_param) {
return StrassenMatmul(input_data, weight_data, output_data, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_ptr);
}
// fp32 conv winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
@ -270,38 +270,46 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int output_tile_count = UP_DIV(output_count, C12NUM);
int out_channel = conv_param->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);
int input_unit_square = input_unit * input_unit;
size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float);
float *trans_input = buffer_list[0];
float *gemm_out = buffer_list[1];
float *tmp_out_data = buffer_list[2];
float *tmp_data = buffer_list[3];
int trans_input_offset = TILE_NUM * input_unit_square * ic4 * C4NUM;
int gemm_out_offset = TILE_NUM * input_unit_square * oc4 * C4NUM;
float *col_buffer = buffer_list[4];
int trans_input_offset = C12NUM * input_unit_square * ic4 * C4NUM;
int gemm_out_offset = C12NUM * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C4NUM;
int col_buffer_offset = C12NUM * ic4 * C4NUM;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_tile_index = thread_id * TILE_NUM;
int cal_num = output_count - thread_id * TILE_NUM;
cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num;
int out_tile_index = thread_id * C12NUM;
int cal_num = output_count - thread_id * C12NUM;
cal_num = cal_num > C12NUM ? C12NUM : cal_num;
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
input_trans_func);
// step 3 : gemm
gemm_func(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, trans_weight, NULL,
input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0);
float *src_ptr = trans_input + task_id * trans_input_offset;
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
cal_num, oc8 * C8NUM, input_unit_square, 2);
}
// step 4 : output transform
WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
cal_num, out_tile_index, out_w_block, conv_param, output_trans_func);
WinogradOutputTransform(dst_ptr, tmp_out_data + tmp_out_batch_offset, bias_data, cal_num, out_tile_index,
out_w_block, conv_param, output_trans_func);
}
}
}
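// Editor's note: ConvWinogardFp32 now expects five scratch buffers instead of
// four; per-task sizes below are read off the *_offset expressions above.
// Hypothetical caller-side glue:
static void SetupWinogradBuffers(float *buffer_list[5], float *trans_input, float *gemm_out,
                                 float *tmp_out_data, float *tmp_data, float *col_buffer) {
  buffer_list[0] = trans_input;   // C12NUM * input_unit_square * ic4 * C4NUM per task
  buffer_list[1] = gemm_out;      // C12NUM * input_unit_square * oc8 * C8NUM per task
  buffer_list[2] = tmp_out_data;  // per-batch staging for the output transform
  buffer_list[3] = tmp_data;      // input_unit_square * C4NUM per task
  buffer_list[4] = col_buffer;    // C12NUM * ic4 * C4NUM per task (column-panel scratch)
}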
@ -442,24 +450,28 @@ void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int heig
}
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func) {
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func) {
int thread_count = conv_param->thread_num_;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int output_channel = conv_param->output_channel_;
int oc4 = UP_DIV(output_channel, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int output_tile_count = UP_DIV(output_count, C12NUM);
const int input_unit_square = 4 * 4;
float *tile_buffer = buffer_list[0];
float *block_unit_buffer = buffer_list[1];
float *tmp_dst_buffer = buffer_list[2];
float *nc4hw4_out = buffer_list[3];
int tile_buffer_offset = TILE_NUM * input_unit_square * ic4 * C4NUM;
float *col_buffer = buffer_list[4];
int tile_buffer_offset = C12NUM * input_unit_square * ic4 * C4NUM;
int block_unit_buffer_offset = input_unit_square * C4NUM;
int tmp_dst_buffer_offset = TILE_NUM * input_unit_square * oc4 * C4NUM;
int tmp_dst_buffer_offset = C12NUM * input_unit_square * oc8 * C8NUM;
int col_buffer_offset = C12NUM * ic4 * C4NUM;
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
@ -467,18 +479,22 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * TILE_NUM;
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
int start_index = thread_id * C12NUM;
int real_cal_num = (output_count - start_index) < C12NUM ? (output_count - start_index) : C12NUM;
Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
gemm_func(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
transed_weight, NULL, input_unit_square, ic4, oc4 * C4NUM,
oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0);
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
float *src_ptr = tile_buffer + task_id * tile_buffer_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
}
Conv3x3Fp32OutputTransform(dst_ptr, nc4hw4_out + nc4hw4_buffer_offset, bias_data, start_index, real_cal_num,
out_w_block, conv_param);
}
}
}
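// Editor's note: worked example of the tile bookkeeping above, assuming
// OUPUT_UNIT == 2, C12NUM == 12 and a 224x224 output plane:
static void Conv3x3TileExample(void) {
  int out_w_block = UP_DIV(224, 2);                  // 112
  int out_h_block = UP_DIV(224, 2);                  // 112
  int output_count = out_w_block * out_h_block;      // 12544 4x4 tiles
  int output_tile_count = UP_DIV(output_count, 12);  // 1046 rounds of <= 12 tiles
}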

View File

@ -24,7 +24,6 @@
#include "nnacl/op_base.h"
#include "nnacl/common_func.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/fp32/strassen_matmul.h"
#include "nnacl/winograd_utils.h"
#include "nnacl/fp32/conv_depthwise.h"
@ -52,10 +51,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
GEMM_FUNC_FP32 gemm_func);
// fp32 conv1x1 strassen matmul
int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
StrassenMatMulParameter matmul_param);
// fp32 convolution winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
@ -70,8 +65,8 @@ void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int heig
int output_unit);
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func);
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func);
#ifdef __cplusplus
}
#endif

View File

@ -38,13 +38,15 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
int h_start = h_step * task_id;
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
for (int b = 0; b < conv_param->output_batch_; b++) {
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
for (int oh = h_start; oh < h_end; oh++) {
float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
@ -60,13 +62,13 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
int out_w_start = MSMAX(
0, (conv_param->pad_w_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_w_ -
0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
conv_param->stride_w_);
float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_w_ + conv_param->dilation_w_ * kw;
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;
const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
int num_pixels = out_w_end - out_w_start;
@ -75,10 +77,10 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
weight_kh += conv_param->output_channel_;
}
}
if (conv_param->is_relu_) {
if (relu) {
ReluFp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
}
if (conv_param->is_relu6_) {
if (relu6) {
Relu6Fp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
}
}
@ -91,16 +93,16 @@ void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_par
int top = 0;
int bottom = conv_param->output_h_;
for (; left * conv_param->stride_w_ < conv_param->pad_w_; left++) {
for (; left * conv_param->stride_w_ < conv_param->pad_l_; left++) {
}
for (; (right - 1) * conv_param->stride_w_ - conv_param->pad_w_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
for (; (right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
conv_param->input_w_ &&
right > left;
right--) {
}
for (; top * conv_param->stride_h_ < conv_param->pad_h_; top++) {
for (; top * conv_param->stride_h_ < conv_param->pad_u_; top++) {
}
for (; (bottom - 1) * conv_param->stride_h_ - conv_param->pad_h_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
for (; (bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
conv_param->input_h_ &&
bottom > top;
bottom--) {
@ -181,16 +183,18 @@ void DepthwiseBorderPixel(float *dst, const float *src, const float *weight, con
void DepthwiseBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom,
int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float *src_h = src + ih * sliding->in_h_step_;
float *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float *src_w = src_h + iw * sliding->block_channel_;
@ -201,11 +205,10 @@ void DepthwiseBorder(float *dst, const float *src, const float *weight, const fl
#ifdef ENABLE_ARM64
ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
conv_param->kernel_w_ * C4NUM * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_);
conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
#else
DepthwiseBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM,
conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6);
#endif
dst_kernel += sliding->block_channel_;
} // width loop
@ -259,6 +262,8 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl
// conv depthwise fp32: sliding window
void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
const float *src = input_data;
float *dst = output_data;
for (int b = 0; b < conv_param->output_batch_; b++) {
@ -277,8 +282,8 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
conv_param->output_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
@ -286,12 +291,12 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
sliding->in_kw_step_ * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kw_step_ * sizeof(float), relu, relu6);
#else
DepthwiseCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_,
conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu,
relu6);
#endif
}
} // output C4 loop
@ -302,399 +307,6 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
}
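// Editor's note: the pad_h_/pad_w_ -> pad_u_/pad_l_ renames throughout these
// hunks switch ConvParameter to per-edge padding, so asymmetric SAME padding is
// representable. Sketch, ignoring dilation; pad_d_ is assumed to exist
// alongside pad_u_/pad_l_:
static void SetSamePadHeight(ConvParameter *conv_param) {
  int pad_along_h = MSMAX(0, (conv_param->output_h_ - 1) * conv_param->stride_h_ +
                                 conv_param->kernel_h_ - conv_param->input_h_);
  conv_param->pad_u_ = pad_along_h / 2;                   // top edge: floor half
  conv_param->pad_d_ = pad_along_h - conv_param->pad_u_;  // bottom edge: the rest
}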
/*conv depthwise fp32 end*/
/*conv depthwise 3x3 fp32 begin*/
void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4) {
for (int c = 0; c < oc4; c++) {
float *src = weight + c * C4NUM * 9;
float *dst = trans_weight + c * C4NUM * 16;
#ifdef ENABLE_ARM
float32x4_t g00 = vld1q_f32(src);
float32x4_t g01 = vld1q_f32(src + 4);
float32x4_t g02 = vld1q_f32(src + 2 * 4);
float32x4_t g10 = vld1q_f32(src + 3 * 4);
float32x4_t g11 = vld1q_f32(src + 4 * 4);
float32x4_t g12 = vld1q_f32(src + 5 * 4);
float32x4_t g20 = vld1q_f32(src + 6 * 4);
float32x4_t g21 = vld1q_f32(src + 7 * 4);
float32x4_t g22 = vld1q_f32(src + 8 * 4);
float32x4_t dst00 = g00;
float32x4_t dst01 = g01;
float32x4_t dst02 = g02;
float32x4_t dst10 = vaddq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5));
dst10 = vaddq_f32(dst10, vmulq_n_f32(g20, 0.5));
float32x4_t dst11 = vaddq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5));
dst11 = vaddq_f32(dst11, vmulq_n_f32(g21, 0.5));
float32x4_t dst12 = vaddq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5));
dst12 = vaddq_f32(dst12, vmulq_n_f32(g22, 0.5));
float32x4_t dst20 = vsubq_f32(vmulq_n_f32(g00, 0.5), vmulq_n_f32(g10, 0.5));
dst20 = vaddq_f32(dst20, vmulq_n_f32(g20, 0.5));
float32x4_t dst21 = vsubq_f32(vmulq_n_f32(g01, 0.5), vmulq_n_f32(g11, 0.5));
dst21 = vaddq_f32(dst21, vmulq_n_f32(g21, 0.5));
float32x4_t dst22 = vsubq_f32(vmulq_n_f32(g02, 0.5), vmulq_n_f32(g12, 0.5));
dst22 = vaddq_f32(dst22, vmulq_n_f32(g22, 0.5));
float32x4_t dst30 = g20;
float32x4_t dst31 = g21;
float32x4_t dst32 = g22;
float32x4_t m00 = dst00;
float32x4_t m01 = vaddq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5));
m01 = vaddq_f32(m01, vmulq_n_f32(dst02, 0.5));
float32x4_t m02 = vsubq_f32(vmulq_n_f32(dst00, 0.5), vmulq_n_f32(dst01, 0.5));
m02 = vaddq_f32(m02, vmulq_n_f32(dst02, 0.5));
float32x4_t m03 = dst02;
float32x4_t m10 = dst10;
float32x4_t m11 = vaddq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5));
m11 = vaddq_f32(m11, vmulq_n_f32(dst12, 0.5));
float32x4_t m12 = vsubq_f32(vmulq_n_f32(dst10, 0.5), vmulq_n_f32(dst11, 0.5));
m12 = vaddq_f32(m12, vmulq_n_f32(dst12, 0.5));
float32x4_t m13 = dst12;
float32x4_t m20 = dst20;
float32x4_t m21 = vaddq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5));
m21 = vaddq_f32(m21, vmulq_n_f32(dst22, 0.5));
float32x4_t m22 = vsubq_f32(vmulq_n_f32(dst20, 0.5), vmulq_n_f32(dst21, 0.5));
m22 = vaddq_f32(m22, vmulq_n_f32(dst22, 0.5));
float32x4_t m23 = dst22;
float32x4_t m30 = dst30;
float32x4_t m31 = vaddq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5));
m31 = vaddq_f32(m31, vmulq_n_f32(dst32, 0.5));
float32x4_t m32 = vsubq_f32(vmulq_n_f32(dst30, 0.5), vmulq_n_f32(dst31, 0.5));
m32 = vaddq_f32(m32, vmulq_n_f32(dst32, 0.5));
float32x4_t m33 = dst32;
vst1q_f32(dst, m00);
vst1q_f32(dst + 4, m01);
vst1q_f32(dst + 8, m02);
vst1q_f32(dst + 12, m03);
vst1q_f32(dst + 16, m10);
vst1q_f32(dst + 20, m11);
vst1q_f32(dst + 24, m12);
vst1q_f32(dst + 28, m13);
vst1q_f32(dst + 32, m20);
vst1q_f32(dst + 36, m21);
vst1q_f32(dst + 40, m22);
vst1q_f32(dst + 44, m23);
vst1q_f32(dst + 48, m30);
vst1q_f32(dst + 52, m31);
vst1q_f32(dst + 56, m32);
vst1q_f32(dst + 60, m33);
#else
for (int j = 0; j < C4NUM; j++) {
float *local_ptr = src + j;
float dst00 = local_ptr[0];
float dst01 = (local_ptr + 4)[0];
float dst02 = (local_ptr + 8)[0];
const float dst10 = 0.5f * local_ptr[0] + 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
const float dst11 = 0.5f * (local_ptr + 4)[0] + 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
const float dst12 = 0.5f * (local_ptr + 8)[0] + 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
const float dst20 = 0.5f * local_ptr[0] - 0.5f * (local_ptr + 12)[0] + 0.5f * (local_ptr + 24)[0];
const float dst21 = 0.5f * (local_ptr + 4)[0] - 0.5f * (local_ptr + 16)[0] + 0.5f * (local_ptr + 28)[0];
const float dst22 = 0.5f * (local_ptr + 8)[0] - 0.5f * (local_ptr + 20)[0] + 0.5f * (local_ptr + 32)[0];
float dst30 = (local_ptr + 24)[0];
float dst31 = (local_ptr + 28)[0];
float dst32 = (local_ptr + 32)[0];
float m00 = dst00;
const float m01 = 0.5f * dst00 + 0.5f * dst01 + 0.5f * dst02;
const float m02 = 0.5f * dst00 - 0.5f * dst01 + 0.5f * dst02;
float m03 = dst02;
float m10 = dst10;
const float m11 = 0.5f * dst10 + 0.5f * dst11 + 0.5f * dst12;
const float m12 = 0.5f * dst10 - 0.5f * dst11 + 0.5f * dst12;
float m13 = dst12;
float m20 = dst20;
const float m21 = 0.5f * dst20 + 0.5f * dst21 + 0.5f * dst22;
const float m22 = 0.5f * dst20 - 0.5f * dst21 + 0.5f * dst22;
float m23 = dst22;
float m30 = dst30;
const float m31 = 0.5f * dst30 + 0.5f * dst31 + 0.5f * dst32;
const float m32 = 0.5f * dst30 - 0.5f * dst31 + 0.5f * dst32;
float m33 = dst32;
*(dst + j) = m00;
*(dst + j + 4) = m01;
*(dst + j + 8) = m02;
*(dst + j + 12) = m03;
*(dst + j + 16) = m10;
*(dst + j + 20) = m11;
*(dst + j + 24) = m12;
*(dst + j + 28) = m13;
*(dst + j + 32) = m20;
*(dst + j + 36) = m21;
*(dst + j + 40) = m22;
*(dst + j + 44) = m23;
*(dst + j + 48) = m30;
*(dst + j + 52) = m31;
*(dst + j + 56) = m32;
*(dst + j + 60) = m33;
}
#endif
}
}
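/* Editor's note: the 0.5 coefficients above are the Winograd F(2x2,3x3) filter
 * transform G g G^T applied per channel lane, with
 *   G = |  1    0    0  |
 *       | 1/2  1/2  1/2 |
 *       | 1/2 -1/2  1/2 |
 *       |  0    0    1  |
 * lifting each 3x3 filter g to the 4x4 tile stored in trans_weight. */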
void ConvDw3x3Fp32InputTrans(const float *input_data, float *trans_input, float *block_buffer, int out_h_block,
int out_w_block, const ConvParameter *conv_param) {
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
const int input_unit = 4;
memset(trans_input, 0, out_h_block * out_w_block * 16 * C4NUM * sizeof(float));
for (int oh = 0; oh < out_h_block; oh++) {
int ih = oh * 2 - conv_param->pad_h_;
int real_h_start = ih > 0 ? 0 : -ih;
int real_h_end = (ih + input_unit) < conv_param->input_h_ ? input_unit : (conv_param->input_h_ - ih);
for (int ow = 0; ow < out_w_block; ow++) {
int iw = ow * 2 - conv_param->pad_w_;
int real_w_start = iw > 0 ? 0 : -iw;
int real_w_end = (iw + input_unit) < conv_param->input_w_ ? input_unit : (conv_param->input_w_ - iw);
memset(block_buffer, 0, 16 * C4NUM * sizeof(float));
int src_plane_offset = ic4 * C4NUM * (ih * conv_param->input_w_ + iw);
for (int h = real_h_start; h < real_h_end; h++) {
int src_h_offset = src_plane_offset + (h * conv_param->input_w_) * ic4 * C4NUM;
int dst_h_offset = (h * input_unit) * C4NUM;
for (int w = real_w_start; w < real_w_end; w++) {
int src_w_offset = src_h_offset + w * ic4 * C4NUM;
int dst_w_offset = dst_h_offset + w * C4NUM;
float *src_addr = (float *)(input_data) + src_w_offset;
float *dst_addr = block_buffer + dst_w_offset;
#ifdef ENABLE_NEON
vst1q_f32(dst_addr, vld1q_f32(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
(dst_addr + k)[0] = (src_addr + k)[0];
}
#endif
}
}
int trans_offset = (oh * out_w_block + ow) * 16 * C4NUM;
Conv3x3Fp32InputUnit(block_buffer, trans_input + trans_offset, C4NUM);
}
}
}
void ConvDw3x3Fp32Winograd(float *trans_buffer, const float *weight, int out_h_block, int out_w_block) {
const int unit = 4;
for (int oh = 0; oh < out_h_block; oh++) {
float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM;
for (int ow = 0; ow < out_w_block; ow++) {
float *buf_ow = buf_oh + ow * 16 * C4NUM;
for (int kh = 0; kh < unit; kh++) {
float *buf_kh = buf_ow + kh * unit * C4NUM;
const float *weight_kh = weight + kh * unit * C4NUM;
for (int kw = 0; kw < unit; kw++) {
float *buf_kw = buf_kh + kw * C4NUM;
const float *weight_kw = weight_kh + kw * C4NUM;
for (int c = 0; c < C4NUM; c++) {
buf_kw[c] = buf_kw[c] * weight_kw[c];
}
}
}
}
}
}
void ConvDw3x3Fp32OutputUnit(float *src_buf, float *dst_output, const float *bias, int channel, int output_w,
bool h_in_range, bool w_in_range, bool is_relu, bool is_relu6) {
#ifdef ENABLE_ARM
float32x4_t bias_ptr = vld1q_f32(bias);
float32x4_t s00 = vld1q_f32(src_buf);
float32x4_t s01 = vld1q_f32(src_buf + 4);
float32x4_t s02 = vld1q_f32(src_buf + 8);
float32x4_t s03 = vld1q_f32(src_buf + 12);
float32x4_t s10 = vld1q_f32(src_buf + 16);
float32x4_t s11 = vld1q_f32(src_buf + 20);
float32x4_t s12 = vld1q_f32(src_buf + 24);
float32x4_t s13 = vld1q_f32(src_buf + 28);
float32x4_t s20 = vld1q_f32(src_buf + 32);
float32x4_t s21 = vld1q_f32(src_buf + 36);
float32x4_t s22 = vld1q_f32(src_buf + 40);
float32x4_t s23 = vld1q_f32(src_buf + 44);
float32x4_t s30 = vld1q_f32(src_buf + 48);
float32x4_t s31 = vld1q_f32(src_buf + 52);
float32x4_t s32 = vld1q_f32(src_buf + 56);
float32x4_t s33 = vld1q_f32(src_buf + 60);
float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20);
float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21);
float32x4_t t02 = vaddq_f32(vaddq_f32(s02, s12), s22);
float32x4_t t03 = vaddq_f32(vaddq_f32(s03, s13), s23);
float32x4_t t10 = vsubq_f32(vsubq_f32(s10, s20), s30);
float32x4_t t11 = vsubq_f32(vsubq_f32(s11, s21), s31);
float32x4_t t12 = vsubq_f32(vsubq_f32(s12, s22), s32);
float32x4_t t13 = vsubq_f32(vsubq_f32(s13, s23), s33);
float32x4_t d00 = vaddq_f32(vaddq_f32(vaddq_f32(t00, t01), t02), bias_ptr);
float32x4_t d01 = vaddq_f32(vsubq_f32(vsubq_f32(t01, t02), t03), bias_ptr);
float32x4_t d10 = vaddq_f32(vaddq_f32(vaddq_f32(t10, t11), t12), bias_ptr);
float32x4_t d11 = vaddq_f32(vsubq_f32(vsubq_f32(t11, t12), t13), bias_ptr);
float32x4_t zeros = {0, 0, 0, 0};
float32x4_t bounds = {6, 6, 6, 6};
if (is_relu) {
d00 = vmaxq_f32(d00, zeros);
d01 = vmaxq_f32(d01, zeros);
d10 = vmaxq_f32(d10, zeros);
d11 = vmaxq_f32(d11, zeros);
}
if (is_relu6) {
d00 = vminq_f32(vmaxq_f32(d00, zeros), bounds);
d01 = vminq_f32(vmaxq_f32(d01, zeros), bounds);
d10 = vminq_f32(vmaxq_f32(d10, zeros), bounds);
d11 = vminq_f32(vmaxq_f32(d11, zeros), bounds);
}
vst1q_f32(dst_output, d00);
if (w_in_range) {
vst1q_f32(dst_output + channel, d01);
}
if (h_in_range) {
vst1q_f32(dst_output + output_w * channel, d10);
if (w_in_range) {
vst1q_f32(dst_output + output_w * channel + channel, d11);
}
}
#else
for (int i = 0; i < C4NUM; i++) {
const float *local_ptr = src_buf + i;
const float *bias_ptr = bias + i;
float s00 = local_ptr[0];
float s01 = (local_ptr + 4)[0];
float s02 = (local_ptr + 8)[0];
float s03 = (local_ptr + 12)[0];
float s10 = (local_ptr + 16)[0];
float s11 = (local_ptr + 20)[0];
float s12 = (local_ptr + 24)[0];
float s13 = (local_ptr + 28)[0];
float s20 = (local_ptr + 32)[0];
float s21 = (local_ptr + 36)[0];
float s22 = (local_ptr + 40)[0];
float s23 = (local_ptr + 44)[0];
float s30 = (local_ptr + 48)[0];
float s31 = (local_ptr + 52)[0];
float s32 = (local_ptr + 56)[0];
float s33 = (local_ptr + 60)[0];
float t00 = s00 + s10 + s20;
float t01 = s01 + s11 + s21;
float t02 = s02 + s12 + s22;
float t03 = s03 + s13 + s23;
float t10 = s10 - s20 - s30;
float t11 = s11 - s21 - s31;
float t12 = s12 - s22 - s32;
float t13 = s13 - s23 - s33;
float d00 = t00 + t01 + t02 + bias_ptr[0];
float d01 = t01 - t02 - t03 + bias_ptr[0];
float d10 = t10 + t11 + t12 + bias_ptr[0];
float d11 = t11 - t12 - t13 + bias_ptr[0];
if (is_relu) {
d00 = MSMAX(d00, 0);
d01 = MSMAX(d01, 0);
d10 = MSMAX(d10, 0);
d11 = MSMAX(d11, 0);
}
if (is_relu6) {
d00 = MSMIN(MSMAX(d00, 0), 6);
d01 = MSMIN(MSMAX(d01, 0), 6);
d10 = MSMIN(MSMAX(d10, 0), 6);
d11 = MSMIN(MSMAX(d11, 0), 6);
}
(dst_output + i)[0] = d00;
if (w_in_range) {
(dst_output + i + channel)[0] = d01;
}
if (h_in_range) {
(dst_output + i + output_w * channel)[0] = d10;
if (w_in_range) {
(dst_output + i + output_w * channel + channel)[0] = d11;
}
}
}
#endif
}
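/* Editor's note: the output unit above is the matching inverse transform
 * A^T m A, with
 *   A^T = | 1  1  1  0 |
 *         | 0  1 -1 -1 |
 * collapsing each 4x4 tile m into a 2x2 output patch; the bias add and the
 * optional Relu/Relu6 clamps are fused before the stores. */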
void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const float *bias, int out_h_block,
int out_w_block, const ConvParameter *conv_param) {
int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
bool h_in_range = true;
for (int oh = 0; oh < out_h_block; oh++) {
const int real_oh = 2 * oh;
if ((oh + 1) * 2 > conv_param->output_h_) {
h_in_range = false;
}
bool w_in_range = true;
float *buf_oh = trans_buffer + oh * out_w_block * 16 * C4NUM;
float *output_oh = output_data + real_oh * conv_param->output_w_ * oc4 * C4NUM;
for (int ow = 0; ow < out_w_block; ow++) {
const int real_ow = 2 * ow;
if ((ow + 1) * 2 > conv_param->output_w_) {
w_in_range = false;
}
float *buf_ow = buf_oh + ow * 16 * C4NUM;
float *output_ow = output_oh + real_ow * oc4 * C4NUM;
ConvDw3x3Fp32OutputUnit(buf_ow, output_ow, bias, oc4 * C4NUM, conv_param->output_w_, h_in_range, w_in_range,
conv_param->is_relu_, conv_param->is_relu6_);
}
}
}
void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id) {
int thread_count = conv_param->thread_num_;
int output_channel = conv_param->output_channel_;
int oc4 = UP_DIV(output_channel, C4NUM);
int out_h_block = UP_DIV(conv_param->output_h_, 2);
int out_w_block = UP_DIV(conv_param->output_w_, 2);
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
const float *input = input_data + batch * conv_param->input_h_ * conv_param->input_w_ *
UP_DIV(conv_param->input_channel_, C4NUM) * C4NUM;
float *output = output_data + batch * conv_param->output_h_ * conv_param->output_w_ *
UP_DIV(conv_param->output_channel_, C4NUM) * C4NUM;
for (int oc = task_id; oc < oc4; oc += thread_count) {
const float *weight = weight_data + oc * 16 * C4NUM;
const float *bias = bias_data + oc * C4NUM;
ConvDw3x3Fp32InputTrans(input + oc * C4NUM, trans_buffer, block_buffer, out_h_block, out_w_block, conv_param);
ConvDw3x3Fp32Winograd(trans_buffer, weight, out_h_block, out_w_block);
ConvDw3x3Fp32OutputTrans(trans_buffer, output + oc * C4NUM, bias, out_h_block, out_w_block, conv_param);
}
}
}
/*conv depthwise 3x3 fp32 end*/
/*deconv depthwise fp32 begin*/
void DeconvDepthwiseBorderPixel(float *dst, const float *src, const float *weight, int height, int width,
int in_kh_step, int in_kw_step, int kernel_w_step) {
@ -727,14 +339,14 @@ void DeconvDepthwiseBorder(float *dst, const float *src, const float *weight, in
const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
const float *src_h = src + top * sliding->out_h_step_;
for (int ih = top; ih < bottom; ih++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
float *dst_h = dst + oh * sliding->in_h_step_;
const float *src_kernel = src_h + left * sliding->block_channel_;
for (int iw = left; iw < right; iw++) {
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
float *dst_w = dst_h + ow * sliding->block_channel_;
@ -790,12 +402,14 @@ void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, in
#endif
void DeconvDepthwisePostFunc(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float *dst_k = dst;
for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
for (int c = 0; c < C4NUM; c++) {
dst_k[c] += bias[c];
dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
}
dst_k += block_channel;
}
@ -821,8 +435,8 @@ void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *we
conv_param->input_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

View File

@ -48,11 +48,6 @@ void DepthwiseBorder(float *dst, const float *src, const float *weight, const fl
void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
void ConvDw3x3Fp32FilterTrans(float *trans_weight, float *weight, int oc4);
void ConvDw3x3Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
float *trans_buffer, float *block_buffer, const ConvParameter *conv_param, int task_id);
void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);

View File

@ -33,18 +33,18 @@ void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, in
return;
}
int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
int oc8 = UP_ROUND(output_channel, C8NUM);
int in_plane8 = UP_ROUND(input_plane, C8NUM);
int in_plane12 = UP_ROUND(input_plane, C12NUM);
int src_iw_stride = C8NUM;
int src_ih_stride = conv_param->input_w_ * C8NUM;
int src_kw_stride = in_plane8 * C8NUM;
int src_kh_stride = in_plane8 * conv_param->kernel_w_ * C8NUM;
int src_kw_stride = in_plane12 * C8NUM;
int src_kh_stride = in_plane12 * conv_param->kernel_w_ * C8NUM;
int dst_oh_stride = conv_param->output_w_ * C8NUM;
int dst_ow_stride = C8NUM;
int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
@ -52,13 +52,13 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
for (int c = 0; c < oc8; c += 8) {
float *dst_ptr = tmp + c * output_plane;
const float *src_ptr = src + c * in_plane8 * kernel_plane;
const float *src_ptr = src + c * in_plane12 * kernel_plane;
memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float));
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
@ -97,45 +97,7 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
} /*ih*/
} /*oc8*/
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
return NNACL_OK;
}
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param) {
int oc4 = UP_DIV(output_channel, C4NUM);
for (int c = 0; c < oc4; c++) {
float *dst_ptr = tmp_c4 + c * output_plane * C4NUM;
const float *src_ptr = src + c * input_plane * kernel_plane * C4NUM;
memset(dst_ptr, 0, output_plane * C4NUM * sizeof(float));
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
for (int kh = kh_start; kh < kh_end; kh++) {
for (int kw = kw_start; kw < kw_end; kw++) {
int src_index = ih * conv_param->input_w_ * C4NUM + iw * C4NUM +
kh * input_plane * conv_param->kernel_w_ * C4NUM + kw * input_plane * C4NUM;
int dst_index = oh * conv_param->output_w_ * C4NUM + ow * C4NUM +
kh * conv_param->dilation_h_ * conv_param->output_w_ * C4NUM +
kw * conv_param->dilation_w_ * C4NUM;
for (int i = 0; i < C4NUM; i++) {
dst_ptr[dst_index + i] += src_ptr[src_index + i];
}
} /*kw*/
} /*kh*/
} /*iw*/
} /*ih*/
} /*oc4*/
PostConvFuncFp32C4(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
return NNACL_OK;
}
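// Editor's note: the in_plane8 -> in_plane12 stride change follows from the
// 12x8 gemm padding its row dimension to C12NUM. Worked check, assuming
// input_plane = 50 and kernel_w_ = 3:
static void DeconvStrideExample(void) {
  int in_plane12 = UP_ROUND(50, C12NUM);       // 60 (was UP_ROUND(50, C8NUM) == 56)
  int src_kw_stride = in_plane12 * C8NUM;      // 480 floats between kw taps
  int src_kh_stride = in_plane12 * 3 * C8NUM;  // 1440 floats between kh taps
}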

View File

@ -16,20 +16,19 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_DECONV_H_
#define MINDSPORE_LITE_NNACL_FP32_DECONV_H_
#include <string.h>
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/fp32/strassen_matmul.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/common_func.h"
#ifdef __cplusplus
extern "C" {
#endif
void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane);
int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param);
int DeConvPostFp32C8x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
int DeConvPostFp32C12x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif

View File

@ -19,19 +19,13 @@
#include "nnacl/op_base.h"
typedef struct GatherParameter {
OpParameter op_parameter_;
int axis_;
int batchDims_;
} GatherParameter;
#ifdef __cplusplus
extern "C" {
#endif
int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
float *output);
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices,
int indices_element_size, int32_t *output);
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
int32_t *output);
#ifdef __cplusplus
}
#endif

View File

@ -28,6 +28,129 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
return;
}
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col) {
for (int r = 0; r < row; r++) {
float *src = src_ptr + r * col;
for (int c = 0; c < col; c++) {
int cd12 = c / C12NUM;
int cm12 = c % C12NUM;
dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
}
}
return;
}
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t row_up_12 = UP_ROUND(row, C12NUM);
size_t row12 = row / C12NUM * C12NUM;
size_t col4 = col / C4NUM * C4NUM;
float *src_r = src_ptr;
float *dst_r = dst_ptr;
size_t ri = 0;
for (; ri < row12; ri += C12NUM) {
size_t ci = 0;
for (; ci < col4; ci += C4NUM) {
float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C12NUM;
/* 12x4 row-major to col-major */
#ifdef ENABLE_ARM64
size_t stride = col * sizeof(float);
asm volatile(
"mov x10, %[src_c]\n"
"mov x11, %[dst_c]\n"
"ld1 {v0.4s}, [x10], %[stride]\n"
"ld1 {v1.4s}, [x10], %[stride]\n"
"ld1 {v2.4s}, [x10], %[stride]\n"
"ld1 {v3.4s}, [x10], %[stride]\n"
"ld1 {v4.4s}, [x10], %[stride]\n"
"ld1 {v5.4s}, [x10], %[stride]\n"
"ld1 {v6.4s}, [x10], %[stride]\n"
"ld1 {v7.4s}, [x10], %[stride]\n"
"zip1 v12.4s, v0.4s, v1.4s\n"
"zip2 v13.4s, v0.4s, v1.4s\n"
"zip1 v14.4s, v2.4s, v3.4s\n"
"zip2 v15.4s, v2.4s, v3.4s\n"
"ld1 {v8.4s}, [x10], %[stride]\n"
"ld1 {v9.4s}, [x10], %[stride]\n"
"ld1 {v10.4s}, [x10], %[stride]\n"
"ld1 {v11.4s}, [x10], %[stride]\n"
"zip1 v16.4s, v4.4s, v5.4s\n"
"zip2 v17.4s, v4.4s, v5.4s\n"
"zip1 v18.4s, v6.4s, v7.4s\n"
"zip2 v19.4s, v6.4s, v7.4s\n"
"trn1 v20.2d, v12.2d, v14.2d\n"
"trn2 v23.2d, v12.2d, v14.2d\n"
"trn1 v26.2d, v13.2d, v15.2d\n"
"trn2 v29.2d, v13.2d, v15.2d\n"
"trn1 v21.2d, v16.2d, v18.2d\n"
"trn2 v24.2d, v16.2d, v18.2d\n"
"trn1 v27.2d, v17.2d, v19.2d\n"
"trn2 v30.2d, v17.2d, v19.2d\n"
"zip1 v12.4s, v8.4s, v9.4s\n"
"zip2 v13.4s, v8.4s, v9.4s\n"
"zip1 v14.4s, v10.4s, v11.4s\n"
"zip2 v15.4s, v10.4s, v11.4s\n"
"trn1 v22.2d, v12.2d, v14.2d\n"
"trn2 v25.2d, v12.2d, v14.2d\n"
"trn1 v28.2d, v13.2d, v15.2d\n"
"trn2 v31.2d, v13.2d, v15.2d\n"
"st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
"st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
"st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
#else
for (int tr = 0; tr < C12NUM; tr++) {
for (int tc = 0; tc < C4NUM; tc++) {
dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
}
}
#endif
}
for (; ci < col; ci++) {
float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C12NUM;
for (size_t i = 0; i < C12NUM; i++) {
dst_c[i] = src_c[i * col];
}
}
src_r += C12NUM * col;
dst_r += C12NUM * col;
}
for (; ri < row; ri++) {
for (size_t i = 0; i < col; i++) {
dst_r[i * C12NUM] = src_r[i];
}
src_r += col;
dst_r += 1;
}
for (; ri < row_up_12; ri++) {
for (size_t i = 0; i < col; i++) {
dst_r[i * C12NUM] = 0;
}
dst_r += 1;
}
return;
}
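// Editor's note: layout produced above -- element (r, c) of the row-major
// source lands at dst[(r / 12) * 12 * col + c * 12 + (r % 12)], i.e. 12-row
// column panels, zero-padded to a multiple of 12 rows. Scalar reference the
// aarch64 block (an unrolled 12x4 transpose) must agree with:
static void RowMajor2Col12MajorRef(const float *src, float *dst, size_t row, size_t col) {
  for (size_t r = 0; r < UP_ROUND(row, 12); ++r) {
    for (size_t c = 0; c < col; ++c) {
      float v = (r < row) ? src[r * col + c] : 0.0f;  // pad tail rows with zeros
      dst[(r / 12) * 12 * col + c * 12 + (r % 12)] = v;
    }
  }
}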
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t row8 = row / C8NUM * C8NUM;
size_t col4 = col / C4NUM * C4NUM;
@ -221,18 +344,18 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
return;
}
void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, bool write_nhwc) {
if (write_nhwc) {
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, int out_type) {
if (out_type == OutType_Nhwc) {
/* col12-major * row8-major => col-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r8div = r / 8, r8mod = r % 8;
int r12div = r / 12, r12mod = r % 12;
int c8div = c / 8, c8mod = c % 8;
size_t ci = r * stride + c;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t ai = r12div * deep * 12 + d * 12 + r12mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
@ -242,22 +365,41 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
dst[ci] = value;
}
}
} else if (out_type == OutType_C8) {
/* col12-major * row8-major => col12x8-major */
int col_8 = UP_ROUND(col, C8NUM);
int row_12 = UP_ROUND(row, C12NUM);
for (int r = 0; r < row_12; r++) {
for (int c = 0; c < col_8; c++) {
int r12div = r / C12NUM, r12mod = r % C12NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
} else {
/* col12-major * row8-major => col8-tiled output */
int col_8 = UP_ROUND(col, C8NUM);
int row_8 = UP_ROUND(row, C8NUM);
for (int r = 0; r < row_8; r++) {
for (int c = 0; c < col_8; c++) {
int r8div = r / 8, r8mod = r % 8;
int c8div = c / 8, c8mod = c % 8;
size_t ci = c8div * row_8 * 8 + r * 8 + c8mod;
for (int i = 0; i < row; ++i) {
int src_r_offset = i;
int dst_r_offset = i * col * stride;
for (int j = 0; j < col; ++j) {
int c8div = j / 8, c8mod = j % 8;
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
for (int d = 0; d < deep; ++d) {
size_t ai = src_r_offset + d * C12NUM;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (bias != NULL) value += bias[j];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
@ -267,11 +409,16 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
return;
}
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
int stride, bool write_nhwc) {
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
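/* assumption: out_type 2 corresponds to the tiled-C8 output, so rows <= 8 take the remain kernel */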
MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
if (out_type == 2 && row <= 8) {
MatmulFloatNeon64OptRemain(a, b, c, deep, row, col, stride);
} else {
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_TileC8));
}
#else
MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
}

View File

@ -26,14 +26,20 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
int stride, bool write_nhwc);
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
#ifdef ENABLE_ARM64
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, bool write_nhwc);
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t write_nhwc, size_t write_c4);
void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride);
#endif
#ifdef __cplusplus
}

View File

@ -130,8 +130,103 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int c4 = UP_DIV(channel, C4NUM); /* oc && ic */
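/* split the output plane into TILE_NUM chunks handled round-robin by threads; channels go 4 at a time with a scalar tail */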
for (int batch = 0; batch < output_batch; batch++) {
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
const float *src_plane_ptr = src_b_ptr;
float *dst_plane_ptr = dst_b_ptr + index * channel;
int real_win_h_start = MSMAX(0, -in_h_index);
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
int real_win_w_start = MSMAX(0, -in_w_index);
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
for (int ci = 0; ci < c4 - 1; ci++) {
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
#ifdef ENABLE_NEON
float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX);
#else
float tmp_max1 = -FLT_MAX;
float tmp_max2 = -FLT_MAX;
float tmp_max3 = -FLT_MAX;
float tmp_max4 = -FLT_MAX;
#endif
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
#ifdef ENABLE_NEON
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr));
#else
tmp_max1 = fmax(tmp_max1, src_win_ptr[0]);
tmp_max2 = fmax(tmp_max2, src_win_ptr[1]);
tmp_max3 = fmax(tmp_max3, src_win_ptr[2]);
tmp_max4 = fmax(tmp_max4, src_win_ptr[3]);
#endif
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
vst1q_f32(dst_c_ptr, tmp_max);
#else
dst_c_ptr[0] = tmp_max1;
dst_c_ptr[1] = tmp_max2;
dst_c_ptr[2] = tmp_max3;
dst_c_ptr[3] = tmp_max4;
#endif
} // ic4-1 loop
int channel_s = (c4 - 1) * C4NUM;
for (int ci = channel_s; ci < channel; ci++) {
float *dst_c_ptr = dst_plane_ptr + ci;
const float *src_c_ptr = src_plane_ptr + ci;
float tmp_max = -FLT_MAX;
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
tmp_max = fmax(tmp_max, src_win_ptr[0]);
} // win_w loop
} // win_h loop
dst_c_ptr[0] = tmp_max;
} // channel_res loop
} // real_cal_num loop
} // out_plane loop
} // out_batch loop
}
void AvgPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int c4 = UP_DIV(channel, C4NUM);
// input channel is equal to output channel
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
#ifdef ENABLE_NEON
float32x4_t zeros = vdupq_n_f32(0);
#endif
for (int batch = 0; batch < output_batch; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
@ -149,6 +244,121 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
for (int j = 0; j < c4 - 1; j++) {
int in_channel_offset = in_batch_offset + j * C4NUM;
int out_channel_offset = out_plane_offset + j * C4NUM;
#ifdef ENABLE_NEON
float32x4_t tmp_avg = vdupq_n_f32(0);
#else
float tmp_avg1 = 0;
float tmp_avg2 = 0;
float tmp_avg3 = 0;
float tmp_avg4 = 0;
#endif
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(input_ptr + in_offset));
#else
tmp_avg1 += *(input_ptr + in_offset);
tmp_avg2 += *(input_ptr + in_offset + 1);
tmp_avg3 += *(input_ptr + in_offset + 2);
tmp_avg4 += *(input_ptr + in_offset + 3);
#endif
++real_count;
}
} // win_w loop
} // win_h loop
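/* relu applied to the accumulated sum before dividing by the valid-tap count; equivalent to relu(avg) since real_count > 0 */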
#ifdef ENABLE_NEON
tmp_avg = vmaxq_f32(tmp_avg, zeros);
vst1q_f32(output_ptr + out_channel_offset, tmp_avg / vdupq_n_f32(real_count));
#else
tmp_avg1 = fmax(tmp_avg1, 0);
tmp_avg2 = fmax(tmp_avg2, 0);
tmp_avg3 = fmax(tmp_avg3, 0);
tmp_avg4 = fmax(tmp_avg4, 0);
*(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count;
*(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count;
*(output_ptr + out_channel_offset + 2) = tmp_avg3 / (float)real_count;
*(output_ptr + out_channel_offset + 3) = tmp_avg4 / (float)real_count;
#endif
} // ic4-1 loop
int channel_s = (c4 - 1) * C4NUM;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
float tmp_avg = 0;
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_avg += *(input_ptr + in_offset);
++real_count;
}
} // win_w loop
} // win_h loop
tmp_avg = fmax(tmp_avg, 0);
*(output_ptr + out_channel_offset) = tmp_avg / (float)real_count;
} // channel_res loop
} // real_cal_num loop
} // out_plane loop
} // out_batch loop
}
void MaxPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int c4 = UP_DIV(channel, C4NUM);
#ifdef ENABLE_NEON
float32x4_t zeros = vdupq_n_f32(0);
#endif
for (int batch = 0; batch < output_batch; batch++) {
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
const float *src_plane_ptr = src_b_ptr;
float *dst_plane_ptr = dst_b_ptr + index * channel;
int real_win_h_start = MSMAX(0, -in_h_index);
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
int real_win_w_start = MSMAX(0, -in_w_index);
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
for (int ci = 0; ci < c4 - 1; ci++) {
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
#ifdef ENABLE_NEON
float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX);
#else
@ -157,6 +367,105 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
float tmp_max3 = -FLT_MAX;
float tmp_max4 = -FLT_MAX;
#endif
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
#ifdef ENABLE_NEON
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr));
#else
tmp_max1 = fmax(tmp_max1, src_win_ptr[0]);
tmp_max2 = fmax(tmp_max2, src_win_ptr[1]);
tmp_max3 = fmax(tmp_max3, src_win_ptr[2]);
tmp_max4 = fmax(tmp_max4, src_win_ptr[3]);
#endif
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
tmp_max = vmaxq_f32(tmp_max, zeros);
vst1q_f32(dst_c_ptr, tmp_max);
#else
// relu:
tmp_max1 = fmax(tmp_max1, 0);
tmp_max2 = fmax(tmp_max2, 0);
tmp_max3 = fmax(tmp_max3, 0);
tmp_max4 = fmax(tmp_max4, 0);
dst_c_ptr[0] = tmp_max1;
dst_c_ptr[1] = tmp_max2;
dst_c_ptr[2] = tmp_max3;
dst_c_ptr[3] = tmp_max4;
#endif
} // ic4-1 loop
int channel_s = (c4 - 1) * C4NUM;
for (int ci = channel_s; ci < channel; ci++) {
float *dst_c_ptr = dst_plane_ptr + ci;
const float *src_c_ptr = src_plane_ptr + ci;
float tmp_max = -FLT_MAX;
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
tmp_max = fmax(tmp_max, src_win_ptr[0]);
} // win_w loop
} // win_h loop
dst_c_ptr[0] = tmp_max;
} // channel_res loop
} // real_cal_num loop
} // out_plane loop
} // out_batch loop
}
void AvgPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int c4 = UP_DIV(channel, C4NUM);
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
// input channel is equal to output channel
#ifdef ENABLE_NEON
float32x4_t zeros = vdupq_n_f32(0);
float32x4_t bounds = vdupq_n_f32(6);
#endif
for (int batch = 0; batch < output_batch; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
int out_batch_offset = batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int out_plane_offset = out_batch_offset + index * channel;
for (int j = 0; j < c4 - 1; j++) {
int in_channel_offset = in_batch_offset + j * C4NUM;
int out_channel_offset = out_plane_offset + j * C4NUM;
#ifdef ENABLE_NEON
float32x4_t tmp_avg = vdupq_n_f32(0);
#else
float tmp_avg1 = 0;
float tmp_avg2 = 0;
float tmp_avg3 = 0;
float tmp_avg4 = 0;
#endif
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
@ -165,30 +474,48 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(input_ptr + in_offset));
tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(input_ptr + in_offset));
#else
tmp_max1 = fmax(tmp_max1, *(input_ptr + in_offset));
tmp_max2 = fmax(tmp_max2, *(input_ptr + in_offset + 1));
tmp_max3 = fmax(tmp_max3, *(input_ptr + in_offset + 2));
tmp_max4 = fmax(tmp_max4, *(input_ptr + in_offset + 3));
tmp_avg1 += *(input_ptr + in_offset);
tmp_avg2 += *(input_ptr + in_offset + 1);
tmp_avg3 += *(input_ptr + in_offset + 2);
tmp_avg4 += *(input_ptr + in_offset + 3);
#endif
++real_count;
}
} // win_w loop
} // win_h loop
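/* relu6 epilogue: average over the valid taps, then clamp to [0, 6] */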
#ifdef ENABLE_NEON
vst1q_f32(output_ptr + out_channel_offset, tmp_max);
tmp_avg = tmp_avg / vdupq_n_f32(real_count);
tmp_avg = vmaxq_f32(tmp_avg, zeros);
tmp_avg = vminq_f32(tmp_avg, bounds);
vst1q_f32(output_ptr + out_channel_offset, tmp_avg);
#else
*(output_ptr + out_channel_offset) = tmp_max1;
*(output_ptr + out_channel_offset + 1) = tmp_max2;
*(output_ptr + out_channel_offset + 2) = tmp_max3;
*(output_ptr + out_channel_offset + 3) = tmp_max4;
tmp_avg1 /= (float)real_count;
tmp_avg2 /= (float)real_count;
tmp_avg3 /= (float)real_count;
tmp_avg4 /= (float)real_count;
tmp_avg1 = fmax(tmp_avg1, 0);
tmp_avg2 = fmax(tmp_avg2, 0);
tmp_avg3 = fmax(tmp_avg3, 0);
tmp_avg4 = fmax(tmp_avg4, 0);
tmp_avg1 = fmin(tmp_avg1, 6);
tmp_avg2 = fmin(tmp_avg2, 6);
tmp_avg3 = fmin(tmp_avg3, 6);
tmp_avg4 = fmin(tmp_avg4, 6);
*(output_ptr + out_channel_offset) = tmp_avg1;
*(output_ptr + out_channel_offset + 1) = tmp_avg2;
*(output_ptr + out_channel_offset + 2) = tmp_avg3;
*(output_ptr + out_channel_offset + 3) = tmp_avg4;
#endif
} // ic4-1 loop
int channel_s = (c4 - 1) * C4NUM;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
float tmp_max = -FLT_MAX;
float tmp_avg = 0;
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
@ -196,11 +523,125 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_max = fmax(tmp_max, *(input_ptr + in_offset));
tmp_avg += *(input_ptr + in_offset);
++real_count;
}
} // win_w loop
} // win_h loop
*(output_ptr + out_channel_offset) = tmp_max;
tmp_avg /= (float)real_count;
tmp_avg = fmax(tmp_avg, 0);
tmp_avg = fmin(tmp_avg, 6);
*(output_ptr + out_channel_offset) = tmp_avg;
} // channel_res loop
} // real_cal_num loop
} // out_plane loop
} // out_batch loop
}
void MaxPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int c4 = UP_DIV(channel, C4NUM);
#ifdef ENABLE_NEON
float32x4_t zeros = vdupq_n_f32(0);
float32x4_t bounds = vdupq_n_f32(6);
#endif
for (int batch = 0; batch < output_batch; batch++) {
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
const float *src_plane_ptr = src_b_ptr;
float *dst_plane_ptr = dst_b_ptr + index * channel;
int real_win_h_start = MSMAX(0, -in_h_index);
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
int real_win_w_start = MSMAX(0, -in_w_index);
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
for (int ci = 0; ci < c4 - 1; ci++) {
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
#ifdef ENABLE_NEON
float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX);
#else
float tmp_max1 = -FLT_MAX;
float tmp_max2 = -FLT_MAX;
float tmp_max3 = -FLT_MAX;
float tmp_max4 = -FLT_MAX;
#endif
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
#ifdef ENABLE_NEON
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr));
#else
tmp_max1 = fmax(tmp_max1, src_win_ptr[0]);
tmp_max2 = fmax(tmp_max2, src_win_ptr[1]);
tmp_max3 = fmax(tmp_max3, src_win_ptr[2]);
tmp_max4 = fmax(tmp_max4, src_win_ptr[3]);
#endif
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
tmp_max = vmaxq_f32(tmp_max, zeros);
tmp_max = vminq_f32(tmp_max, bounds);
vst1q_f32(dst_c_ptr, tmp_max);
#else
// relu6: clamp to [0, 6]
tmp_max1 = fmax(tmp_max1, 0);
tmp_max2 = fmax(tmp_max2, 0);
tmp_max3 = fmax(tmp_max3, 0);
tmp_max4 = fmax(tmp_max4, 0);
tmp_max1 = fmin(tmp_max1, 6);
tmp_max2 = fmin(tmp_max2, 6);
tmp_max3 = fmin(tmp_max3, 6);
tmp_max4 = fmin(tmp_max4, 6);
dst_c_ptr[0] = tmp_max1;
dst_c_ptr[1] = tmp_max2;
dst_c_ptr[2] = tmp_max3;
dst_c_ptr[3] = tmp_max4;
#endif
} // ic4-1 loop
int channel_s = (c4 - 1) * C4NUM;
for (int ci = channel_s; ci < channel; ci++) {
float *dst_c_ptr = dst_plane_ptr + ci;
const float *src_c_ptr = src_plane_ptr + ci;
float tmp_max = -FLT_MAX;
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
tmp_max = fmax(tmp_max, src_win_ptr[0]);
} // win_w loop
} // win_h loop
dst_c_ptr[0] = tmp_max;
} // channel_res loop
} // real_cal_num loop
} // out_plane loop

View File

@ -30,6 +30,14 @@ extern "C" {
void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
void AvgPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
void AvgPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
#ifdef __cplusplus
}
#endif

View File

@ -17,17 +17,15 @@
#include "nnacl/fp32/resize.h"
#include "nnacl/common_func.h"
#include "nnacl/errorcode.h"
int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
bool align_corners, int tid, int thread_num) {
if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL) {
int PrepareResizeBilinear(const int *input_shape, const int *output_shape, bool align_corners, int *y_bottoms,
int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights, float *x_left_weights) {
if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL ||
x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) {
return NNACL_NULL_PTR;
}
int in_n = input_shape[0];
int in_h = input_shape[1];
int in_w = input_shape[2];
int in_c = input_shape[3];
int new_height = output_shape[1];
int new_width = output_shape[2];
@ -40,65 +38,119 @@ int ResizeBilinear(const float *input_data, float *output_data, const int *input
width_scale = (float)(in_w - 1) / (new_width - 1);
}
int n, h, w, c;
for (n = 0; n < in_n; n++) {
for (h = tid; h < new_height; h += thread_num) {
float actual_y = (float)h * height_scale;
int y_bottom = (int)(floor(actual_y));
int y_top = y_bottom + 1 < in_h ? (y_bottom + 1) : (in_h - 1);
float y_top_weight = actual_y - (float)(y_bottom);
const float y_bottom_weight = 1.0f - y_top_weight;
for (w = 0; w < new_width; w++) {
float actual_x = (float)(w)*width_scale;
int x_left = (int)(floor(actual_x));
int x_right = x_left + 1 < in_w ? (x_left + 1) : (in_w - 1);
float x_right_weight = actual_x - (float)(x_left);
const float x_left_weight = 1.0f - x_right_weight;
c = 0;
int h, w;
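/* cache per-output-row (bottom/top source rows, bottom weight) and per-output-column (left/right columns, left weight) so the resize kernel avoids recomputing them */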
for (h = 0; h < new_height; h++) {
float actual_y = (float)h * height_scale;
int y_bottom = (int)(floor(actual_y));
int y_top = y_bottom + 1 < in_h ? (y_bottom + 1) : (in_h - 1);
float y_top_weight = actual_y - (float)(y_bottom);
const float y_bottom_weight = 1.0f - y_top_weight;
y_bottoms[h] = y_bottom;
y_tops[h] = y_top;
y_bottom_weights[h] = y_bottom_weight;
}
for (w = 0; w < new_width; w++) {
float actual_x = (float)(w)*width_scale;
int x_left = (int)(floor(actual_x));
int x_right = x_left + 1 < in_w ? (x_left + 1) : (in_w - 1);
float x_right_weight = actual_x - (float)(x_left);
const float x_left_weight = 1.0f - x_right_weight;
x_lefts[w] = x_left;
x_rights[w] = x_right;
x_left_weights[w] = x_left_weight;
}
return NNACL_OK;
}
int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights,
float *x_left_weights, int n_h_begin, int n_h_end) {
if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL ||
y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) {
return NNACL_NULL_PTR;
}
int in_w = input_shape[2];
int in_c = input_shape[3];
int new_height = output_shape[1];
int new_width = output_shape[2];
int n_h, n, h, w, c;
n = n_h_begin / new_height;
h = n_h_begin % new_height;
int n_h_stride = new_width * in_c;
int out_offset = n_h_begin * n_h_stride;
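/* walk the flattened (batch, output-row) range [n_h_begin, n_h_end), wrapping h at new_height */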
for (n_h = n_h_begin; n_h < n_h_end; n_h++, h++) {
if (h == new_height) {
h = 0;
n++;
}
int y_bottom = y_bottoms[h];
int y_top = y_tops[h];
float y_bottom_weight = y_bottom_weights[h];
float y_top_weight = 1.0f - y_bottom_weight;
for (w = 0; w < new_width; w++) {
int x_left = x_lefts[w];
int x_right = x_rights[w];
float x_left_weight = x_left_weights[w];
float x_right_weight = 1.0f - x_left_weight;
float top_left_weight = y_top_weight * x_left_weight;
float top_right_weight = y_top_weight * x_right_weight;
float bottom_left_weight = y_bottom_weight * x_left_weight;
float bottom_right_weight = y_bottom_weight * x_right_weight;
c = 0;
int in_bottom_left_offset = offset(input_shape, n, y_bottom, x_left, c);
int in_bottom_right_offset = in_bottom_left_offset + (x_right - x_left) * in_c;
int in_top_left_offset = in_bottom_left_offset + (y_top - y_bottom) * in_w * in_c;
int in_top_right_offset = in_bottom_right_offset + (y_top - y_bottom) * in_w * in_c;
#ifdef ENABLE_NEON
for (; c <= in_c - 4; c += 4) {
float32x4_t bottom_left = vld1q_f32(input_data + offset(input_shape, n, y_bottom, x_left, c));
float32x4_t bottom_right = vld1q_f32(input_data + offset(input_shape, n, y_bottom, x_right, c));
float32x4_t top_left = vld1q_f32(input_data + offset(input_shape, n, y_top, x_left, c));
float32x4_t top_right = vld1q_f32(input_data + offset(input_shape, n, y_top, x_right, c));
float32x4_t top_left_w = vdupq_n_f32(top_left_weight);
float32x4_t top_right_w = vdupq_n_f32(top_right_weight);
float32x4_t bottom_left_w = vdupq_n_f32(bottom_left_weight);
float32x4_t bottom_right_w = vdupq_n_f32(bottom_right_weight);
float32x4_t y_top_w = vdupq_n_f32(y_top_weight);
float32x4_t y_bottom_w = vdupq_n_f32(y_bottom_weight);
float32x4_t x_left_w = vdupq_n_f32(x_left_weight);
float32x4_t x_right_w = vdupq_n_f32(x_right_weight);
for (; c <= in_c - 4; c += 4) {
float32x4_t bottom_left = vld1q_f32(input_data + in_bottom_left_offset + c);
float32x4_t bottom_right = vld1q_f32(input_data + in_bottom_right_offset + c);
float32x4_t top_left = vld1q_f32(input_data + in_top_left_offset + c);
float32x4_t top_right = vld1q_f32(input_data + in_top_right_offset + c);
float32x4_t interp_value = vdupq_n_f32(0.0);
float32x4_t tmp = vmulq_f32(bottom_left, y_bottom_w);
tmp = vmulq_f32(tmp, x_left_w);
interp_value = vaddq_f32(interp_value, tmp);
float32x4_t interp_value = vdupq_n_f32(0.0);
tmp = vmulq_f32(bottom_right, y_bottom_w);
tmp = vmulq_f32(tmp, x_right_w);
interp_value = vaddq_f32(interp_value, tmp);
float32x4_t tmp = vmulq_f32(bottom_left, bottom_left_w);
interp_value = vaddq_f32(interp_value, tmp);
tmp = vmulq_f32(top_left, y_top_w);
tmp = vmulq_f32(tmp, x_left_w);
interp_value = vaddq_f32(interp_value, tmp);
tmp = vmulq_f32(bottom_right, bottom_right_w);
interp_value = vaddq_f32(interp_value, tmp);
tmp = vmulq_f32(top_right, y_top_w);
tmp = vmulq_f32(tmp, x_right_w);
interp_value = vaddq_f32(interp_value, tmp);
vst1q_f32(output_data + offset(output_shape, n, h, w, c), interp_value);
}
tmp = vmulq_f32(top_left, top_left_w);
interp_value = vaddq_f32(interp_value, tmp);
tmp = vmulq_f32(top_right, top_right_w);
interp_value = vaddq_f32(interp_value, tmp);
vst1q_f32(output_data + out_offset, interp_value);
out_offset += 4;
}
#endif
for (; c < in_c; c++) {
float bottom_left = input_data[offset(input_shape, n, y_bottom, x_left, c)];
float bottom_right = input_data[offset(input_shape, n, y_bottom, x_right, c)];
float top_left = input_data[offset(input_shape, n, y_top, x_left, c)];
float top_right = input_data[offset(input_shape, n, y_top, x_right, c)];
float interp_value = bottom_left * y_bottom_weight * x_left_weight +
bottom_right * y_bottom_weight * x_right_weight +
top_left * y_top_weight * x_left_weight + top_right * y_top_weight * x_right_weight;
output_data[offset(output_shape, n, h, w, c)] = interp_value;
}
for (; c < in_c; c++) {
float bottom_left = input_data[in_bottom_left_offset + c];
float bottom_right = input_data[in_bottom_right_offset + c];
float top_left = input_data[in_top_left_offset + c];
float top_right = input_data[in_top_right_offset + c];
float interp_value = bottom_left * bottom_left_weight + bottom_right * bottom_right_weight +
top_left * top_left_weight + top_right * top_right_weight;
output_data[out_offset] = interp_value;
out_offset++;
}
}
}
return NNACL_OK;
}

View File

@ -25,9 +25,12 @@
#ifdef __cplusplus
extern "C" {
#endif
int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
bool align_corners, int tid, int thread_num);
int PrepareResizeBilinear(const int *input_shape, const int *output_shape, bool align_corners, int *y_bottoms,
int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights, float *x_left_weights);
int ResizeBilinear(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float *y_bottom_weights,
float *x_left_weights, int n_h_begin, int n_h_end);
int ResizeNearestNeighbor(const float *input_data, float *output_data, const int *input_shape, const int *output_shape,
int tid, int thread_num);
#ifdef __cplusplus

View File

@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/scale.h"
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
void ScaleInner(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end,
int axis_size, int inner_size) {
for (int out = outer_start; out < outer_end; out++) {
int out_offset = out * axis_size * inner_size;
for (int i = 0; i < axis_size; i++) {
int axis_offset = out_offset + i * inner_size;
int in_index = 0;
#ifdef ENABLE_ARM64
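/* vfmaq_f32(offset, data, scale) computes data * scale + offset per lane */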
for (; in_index < inner_size - 4; in_index += 4) {
int in_offset = axis_offset + in_index;
float32x4_t data = vld1q_f32(in_data + in_offset);
float32x4_t scale_4 = vdupq_n_f32(scale[i]);
float32x4_t offset_4 = vdupq_n_f32(offset[i]);
float32x4_t result = vfmaq_f32(offset_4, data, scale_4);
vst1q_f32(out_data + in_offset, result);
}
#endif
for (; in_index < inner_size; in_index++) {
int in_offset = axis_offset + in_index;
out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i];
}
}
}
}
void ScaleAxis(float *in_data, float *out_data, float *scale, float *offset, int outer_start, int outer_end,
int axis_size) {
for (int out = outer_start; out < outer_end; out++) {
int out_offset = out * axis_size;
int index = 0;
#ifdef ENABLE_ARM64
for (; index < axis_size - 4; index += 4) {
int in_offset = out_offset + index;
float32x4_t data = vld1q_f32(in_data + in_offset);
float32x4_t scale_4 = vld1q_f32(scale + index);
float32x4_t offset_4 = vld1q_f32(offset + index);
float32x4_t result = vfmaq_f32(offset_4, data, scale_4);
vst1q_f32(out_data + in_offset, result);
}
#endif
for (; index < axis_size; index++) {
int in_offset = out_offset + index;
out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index];
}
}
}
void DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param) {
int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
int outer_start = task_id * outer_step;
int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
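/* inner_size_ == 1 means the scale axis is innermost, so scale/offset can be loaded as vectors directly */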
if (scale_param->inner_size_ == 1) {
ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_);
} else {
ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
scale_param->inner_size_);
}
}

View File

@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_SCALE_FP32_H_
#define MINDSPORE_LITE_NNACL_SCALE_FP32_H_
#include "nnacl/op_base.h"
#include "nnacl/scale.h"
#ifdef __cplusplus
extern "C" {
#endif
void DoScale(float *in_data, float *out_data, float *scale, float *offset, int task_id, ScaleParameter *scale_param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_SCALE_FP32_H_

View File

@ -16,132 +16,79 @@
#include "nnacl/fp32/space_to_batch.h"
#include "nnacl/arithmetic_common.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/concat.h"
#include "nnacl/op_base.h"
int EnumElement(int *shape, int n_dims) {
int total = 1;
for (int i = 0; i < n_dims; i++) {
total *= shape[i];
}
return total;
}
void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
int *output_shape, int h_start, int h_end) {
const int stride0 = strides[perm[0]];
const int stride1 = strides[perm[1]];
const int stride2 = strides[perm[2]];
const int stride3 = strides[perm[3]];
const int stride4 = strides[perm[4]];
const int out_stride0 = out_strides[0];
const int out_stride1 = out_strides[1];
const int out_stride2 = out_strides[2];
const int out_stride3 = out_strides[3];
const int out_stride4 = out_strides[4];
const int output0 = output_shape[0];
const int output2 = output_shape[2];
const int output3 = output_shape[3];
const int output4 = output_shape[4];
for (int i = 0; i < output0; ++i) {
int out_stride0_i = i * out_stride0;
int stride0_i = i * stride0;
for (int j = h_start; j < h_end; ++j) {
int out_stride1_j = j * out_stride1;
int stride1_j = j * stride1;
for (int k = 0; k < output2; ++k) {
int out_stride2_k = k * out_stride2;
int stride2_k = k * stride2;
for (int m = 0; m < output3; ++m) {
int out_stride3_m = m * out_stride3;
int stride3_m = m * stride3;
for (int n = 0; n < output4; ++n) {
int out_stride4_n = n * out_stride4;
int stride4_n = n * stride4;
memcpy(out_data + out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n,
in_data + stride0_i + stride1_j + stride2_k + stride3_m + stride4_n, stride4 * sizeof(float));
}
}
void DoSpaceToBatchNHWC(const float *input, float *output, SpaceToBatchParameter *param, int *in_shape,
int *out_shape) {
int out_dim0 = out_shape[0];
int out_dim1 = out_shape[1];
int out_dim2 = out_shape[2];
int copy_num = out_shape[3];
int block_w = param->block_sizes_[1];
int block_h = param->block_sizes_[0];
int in_strides[4];
ComputeStrides(in_shape, in_strides, 4);
int out_strides[4];
ComputeStrides(out_shape, out_strides, 4);
size_t copy_size = copy_num * sizeof(float);
size_t out_offset = 0;
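/* each output batch n unpacks to an input batch (n % in_n) plus block offsets (stride_h, stride_w) derived from n / in_n */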
for (int n = 0; n < out_dim0; ++n) {
int in_n = n % in_shape[0];
int32_t stride_w = (n / in_shape[0]) % block_w;
int32_t stride_h = (n / in_shape[0]) / block_w;
size_t in_offset0 = in_n * in_strides[0];
for (int h = 0; h < out_dim1; ++h) {
size_t in_offset1 = in_offset0 + (h * block_h + stride_h) * in_strides[1];
for (int w = 0; w < out_dim2; ++w) {
size_t in_offset2 = in_offset1 + (w * block_w + stride_w) * in_strides[2];
memcpy(output + out_offset, input + in_offset2, copy_size);
out_offset += copy_num;
}
}
}
}
int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_sizes, int h_start,
int h_end) {
int trans_in_shape[6] = {in_shape[0], in_shape[1] / block_sizes[0],
block_sizes[0], in_shape[2] / block_sizes[1],
block_sizes[1], in_shape[3]};
int trans_out_shape[6] = {
in_shape[0], block_sizes[0], block_sizes[1], in_shape[1] / block_sizes[0], in_shape[2] / block_sizes[1],
in_shape[3]};
int in_strides[C4NUM + 2];
ComputeStrides(trans_in_shape, in_strides, shape_size + 2);
int out_strides[C4NUM + 2];
ComputeStrides(trans_out_shape, out_strides, shape_size + 2);
int perm[6] = {0, 2, 4, 1, 3, 5};
TransposeForNHWC(input, output, in_strides, out_strides, perm, trans_out_shape, h_start, h_end);
return NNACL_OK;
}
void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]) {
float *tmp = padded_input;
(void)memcpy(tmp, input, param.num_elements_ * sizeof(float));
float *target = tmp_space[0];
float *tmp_zeros = tmp_space[1];
float *tmp2 = NULL;
int cur_shape[param.n_dims_], cur_start_shape[param.n_dims_], cur_end_shape[param.n_dims_],
cur_target_shape[param.n_dims_];
float *concat_inputs[3];
int *concat_shapes[4];
for (int i = 0; i < param.n_dims_; i++) {
cur_shape[i] = param.in_shape_[i];
cur_start_shape[i] = param.in_shape_[i];
cur_end_shape[i] = param.in_shape_[i];
cur_target_shape[i] = param.in_shape_[i];
}
for (int i = 0; i < param.n_space_dims_; ++i) {
if (param.padded_in_shape_[i + 1] > param.in_shape_[i + 1]) {
int concat_idx = 0;
cur_target_shape[i + 1] = 0;
if (param.paddings_[2 * i] != 0) {
cur_start_shape[i + 1] = param.paddings_[2 * i];
concat_inputs[concat_idx] = tmp_zeros;
concat_shapes[concat_idx++] = cur_start_shape;
cur_target_shape[i + 1] += cur_start_shape[i + 1];
void DoSpaceToBatchPaddingNHWC(const float *input, float *output, int *in_shape, int *padding, int *out_shape,
const float *padding_h_data, const float *padding_w_data) {
int in_h = in_shape[1];
int in_w = in_shape[2];
int in_c = in_shape[3];
int out_w = out_shape[2];
int out_c = out_shape[3];
size_t pad_h_num = out_w * out_c;
size_t pad_h_size = pad_h_num * sizeof(float);
size_t pad_w_size = out_c * sizeof(float);
size_t out_offset = 0;
int in_strides[4];
ComputeStrides(in_shape, in_strides, 4);
int out_strides[4];
ComputeStrides(out_shape, out_strides, 4);
size_t copy_size = in_c * sizeof(float);
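/* emit top padding rows, then each input row framed by left/right padding, then bottom padding rows */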
for (int i = 0; i < in_shape[0]; ++i) {
size_t in_offset0 = i * in_strides[0];
for (int pad_h_top = 0; pad_h_top < padding[0]; ++pad_h_top) {
memcpy(output + out_offset, padding_h_data, pad_h_size);
out_offset += pad_h_num;
}
for (int j = 0; j < in_h; ++j) {
size_t in_offset1 = in_offset0 + j * in_strides[1];
for (int pad_w_left = 0; pad_w_left < padding[2]; ++pad_w_left) {
memcpy(output + out_offset, padding_w_data, pad_w_size);
out_offset += out_c;
}
concat_inputs[concat_idx] = tmp;
concat_shapes[concat_idx++] = cur_shape;
cur_target_shape[i + 1] += cur_shape[i + 1];
if (param.paddings_[2 * i + 1] != 0) {
cur_end_shape[i + 1] = param.paddings_[2 * i + 1];
concat_inputs[concat_idx] = tmp_zeros;
concat_shapes[concat_idx++] = cur_end_shape;
cur_target_shape[i + 1] += cur_end_shape[i + 1];
for (int k = 0; k < in_w; ++k) {
size_t in_offset2 = in_offset1 + k * in_strides[2];
memcpy(output + out_offset, input + in_offset2, copy_size);
out_offset += in_c;
}
concat_shapes[concat_idx] = cur_target_shape;
Concat((void **)concat_inputs, concat_idx, i + 1, concat_shapes, param.n_dims_, target);
tmp2 = tmp;
tmp = target;
target = tmp2;
cur_start_shape[i + 1] = cur_end_shape[i + 1] = cur_shape[i + 1] = concat_shapes[concat_idx][i + 1];
for (int pad_w_right = 0; pad_w_right < padding[3]; ++pad_w_right) {
memcpy(output + out_offset, padding_w_data, pad_w_size);
out_offset += out_c;
}
}
for (int pad_h_bottom = 0; pad_h_bottom < padding[1]; ++pad_h_bottom) {
memcpy(output + out_offset, padding_h_data, pad_h_size);
out_offset += pad_h_num;
}
}
if (padded_input != tmp) {
memcpy(padded_input, tmp, param.num_elements_padded_ * sizeof(float));
}
}
int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end) {
if (input == NULL || output == NULL) {
return NNACL_NULL_PTR;
}
int ret =
SpaceToBatchForNHWC(input, output, param.padded_in_shape_, param.n_dims_, param.block_sizes_, h_start, h_end);
return ret;
}

View File

@ -22,26 +22,17 @@
typedef struct SpaceToBatchParameter {
OpParameter op_parameter_;
int block_sizes_[8];
int paddings_[8];
int n_dims_;
int num_elements_;
int num_elements_padded_;
int n_space_dims_;
int in_shape_[8];
int padded_in_shape_[8];
bool need_paddings_;
int block_sizes_[4];
int paddings_[4];
} SpaceToBatchParameter;
#ifdef __cplusplus
extern "C" {
#endif
int SpaceToBatch(const float *input, float *output, SpaceToBatchParameter param, int h_start, int h_end);
int SpaceToBatchForNHWC(const float *input, float *output, int *in_shape, int shape_size, int *block_size, int h_start,
int h_end);
void TransposeForNHWC(const float *in_data, float *out_data, int *strides, int *out_strides, int *perm,
int *output_shape, int h_start, int h_end);
void DoPadding(const float *input, float *padded_input, SpaceToBatchParameter param, float *tmp_space[]);
int EnumElement(int *shape, int n_dims);
void DoSpaceToBatchNHWC(const float *input, float *output, SpaceToBatchParameter *param, int *in_shape,
int *out_shape);
void DoSpaceToBatchPaddingNHWC(const float *input, float *output, int *in_shape, int *padding, int *out_shape,
const float *padding_h_data, const float *padding_w_data);
#ifdef __cplusplus
}
#endif

View File

@ -1,204 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/strassen_matmul.h"
bool CheckRecursion(int row, int col, int deep, int max_recursion, int cur_recursion) {
if (cur_recursion >= max_recursion) {
return false;
}
if (row % 2 != 0 || col % 2 != 0 || deep % 2 != 0) {
return false;
}
int row2 = row / 2;
int col2 = col / 2;
int deep2 = deep / 2;
float save_cost = row * col * 4 * deep * 4 * 2 + row * col * 4 -
7 * (row2 * col2 * 4 * deep2 * 4 * 2 - row2 * col2 * 4) - 4 * (row2 * deep2 * 4 * 3) -
4 * (deep2 * 4 * col2 * 4 * 3) - 7 * (row2 * col2 * 4 * 3);
return (save_cost > 0.f);
}
void GemmMatMulComm(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
int c_stride) {
int row4mod = row % 4;
int row4div = row / 4;
for (int r = 0; r < row; r++) {
int r4mod = r % 4;
int r4div = r / 4;
for (int c = 0; c < col * 4; c++) {
float value = 0;
int ic = c / 4 * c_stride + r * 4 + c % 4;
for (int d = 0; d < deep * 4; d++) {
int d4mod = d % 4;
int d4div = d / 4;
int a_stride = (r < (row4div * 4)) ? 4 : row4mod;
int ai = r4div * 4 * deep * 4 + d4div * a_stride * 4 + r4mod * 4 + d4mod;
int bi = c / 4 * b_stride + d * 4 + c % 4;
value = value + a_ptr[ai] * b_ptr[bi];
}
dst_ptr[ic] = value;
}
}
return;
}
void GemmMatMul(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
int c_stride) {
int row4mod = row % 4;
int row4div = row / 4;
if (row4div > 0) {
GemmMatMulComm(a_ptr, b_ptr, dst_ptr, row4div * 4, col, deep, b_stride, c_stride);
}
if (row4mod != 0) {
GemmMatMulComm(a_ptr + row4div * deep * 4 * 4, b_ptr, dst_ptr + row4div * 4 * 4, row4mod, col, deep, b_stride,
c_stride);
}
return;
}
int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int cur_recursion, float *tmp_a_ptr) {
size_t row2 = matmul_param->row_ / 2;
size_t deep2 = matmul_param->deep_ / 2;
size_t col2 = matmul_param->col_ / 2;
size_t a_stride = matmul_param->a_stride_;
size_t b_stride = matmul_param->b_stride_;
size_t c_stride = matmul_param->c_stride_;
StrassenMatMulParameter rec_matmul;
rec_matmul.row_ = row2;
rec_matmul.deep_ = deep2;
rec_matmul.col_ = col2;
float *x_ptr = (float *)(malloc(row2 * MSMAX(deep2, col2) * FP32_STRASSEN_UINT * sizeof(float)));
if (x_ptr == NULL) {
return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
}
float *y_ptr = (float *)(malloc(col2 * deep2 * FP32_STRASSEN_WEIGHT_UINT * sizeof(float)));
if (y_ptr == NULL) {
free(x_ptr);
return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
}
size_t x_stride = row2 * FP32_STRASSEN_UINT;
size_t y_stride = deep2 * FP32_STRASSEN_WEIGHT_UINT;
const float *a11 = a_ptr;
const float *a12 = a_ptr + deep2 * a_stride;
const float *a21 = a_ptr + row2 * FP32_STRASSEN_UINT;
const float *a22 = a_ptr + deep2 * a_stride + row2 * FP32_STRASSEN_UINT;
const float *b11 = b_ptr;
const float *b12 = b_ptr + col2 * b_stride;
const float *b21 = b_ptr + deep2 * FP32_STRASSEN_WEIGHT_UINT;
const float *b22 = b_ptr + col2 * b_stride + deep2 * FP32_STRASSEN_WEIGHT_UINT;
float *c11 = c_ptr;
float *c12 = c_ptr + col2 * c_stride;
float *c21 = c_ptr + row2 * FP32_STRASSEN_UINT;
float *c22 = c_ptr + col2 * c_stride + row2 * FP32_STRASSEN_UINT;
/* S3 = A11 - A21 */
MatrixSub(a11, a21, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
/* T3 = B22 - B12 */
MatrixSub(b22, b12, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
/* P7 = S3T3 */
rec_matmul.a_stride_ = x_stride;
rec_matmul.b_stride_ = y_stride;
rec_matmul.c_stride_ = c_stride;
StrassenMatmul(x_ptr, y_ptr, c21, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* S1 = A21 + A22 */
MatrixAdd(a21, a22, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
/* T1 = B12 - B11 */
MatrixSub(b12, b11, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
/* P5 = S1T1 */
StrassenMatmul(x_ptr, y_ptr, c22, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* S2 = S1 - A11 */
MatrixSub(x_ptr, a11, x_ptr, x_stride, a_stride, x_stride, row2, deep2);
/* T2 = B22 - T1 */
MatrixSub(b22, y_ptr, y_ptr, b_stride, y_stride, y_stride, deep2 * 4, col2);
/* P6 = S2T2 */
StrassenMatmul(x_ptr, y_ptr, c12, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* S4 = A12 - S2 */
MatrixSub(a12, x_ptr, x_ptr, a_stride, x_stride, x_stride, row2, deep2);
/* P3 = S4B22 */
rec_matmul.b_stride_ = b_stride;
StrassenMatmul(x_ptr, b22, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* P1 = A11B11 */
rec_matmul.a_stride_ = a_stride;
rec_matmul.c_stride_ = row2 * FP32_STRASSEN_UINT;
StrassenMatmul(a11, b11, x_ptr, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* U2 = P1 + P6
U3 = U2 + P7
U4 = U2 + P5
U7 = U3 + P5
U5 = U4 + P3 */
MatrixMultiAdd(c11, c12, c21, c22, x_ptr, row2, col2, c_stride, x_stride);
/* T4 = T2 - B21 */
MatrixSub(y_ptr, b21, y_ptr, y_stride, b_stride, y_stride, deep2 * 4, col2);
/* P4 = A22T4 */
rec_matmul.b_stride_ = y_stride;
rec_matmul.c_stride_ = c_stride;
StrassenMatmul(a22, y_ptr, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* U6 = U3 - P4 */
MatrixSub(c21, c11, c21, c_stride, c_stride, c_stride, row2, col2);
/* P2 = A12B21 */
rec_matmul.b_stride_ = b_stride;
StrassenMatmul(a12, b21, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
/* U1 = P1 + P2 */
MatrixAdd(x_ptr, c11, c11, x_stride, c_stride, c_stride, row2, col2);
free(x_ptr);
free(y_ptr);
return NNACL_OK;
}
int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
float *tmp_a_ptr) {
MatrixPack(a_ptr, tmp_a_ptr, matmul_param->row_, matmul_param->deep_, matmul_param->a_stride_);
GemmMatMul(tmp_a_ptr, b_ptr, c_ptr, matmul_param->row_, matmul_param->col_, matmul_param->deep_,
matmul_param->b_stride_, matmul_param->c_stride_);
return NNACL_OK;
}
int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int cur_recursion, float *tmp_a_ptr) {
if (CheckRecursion(matmul_param->row_, matmul_param->col_, matmul_param->deep_, cur_recursion, max_recursion)) {
return RecursionMatmul(a_ptr, b_ptr, c_ptr, matmul_param, max_recursion, cur_recursion, tmp_a_ptr);
}
return CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_a_ptr);
}

View File

@ -1,45 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
#define MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
#include <memory.h>
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/strassen_matmul.h"
#include "nnacl/fp32/common_func.h"
#define FP32_STRASSEN_UINT C4NUM
#define FP32_STRASSEN_WEIGHT_UINT (C4NUM * C4NUM)
#define FP32_STRASSEN_MAX_RECURSION 5
#ifdef __cplusplus
extern "C" {
#endif
int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int, float *tmp_a_ptr);
int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *Matmul_param,
float *tmp_a_ptr);
int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
int max_recursion, int cur_recursion, float *tmp_a_ptr);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_

View File

@ -16,9 +16,27 @@
#include "nnacl/fp32/topk.h"
int DescendCmp(const void *a, const void *b) { return ((const TopkNode *)b)->element - ((const TopkNode *)a)->element; }
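/* returning the raw float difference truncated to int loses sub-1.0 differences; compare via explicit sign instead */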
int DescendCmp(const void *a, const void *b) {
float sub = ((const TopkNode *)b)->element - ((const TopkNode *)a)->element;
if (sub > 0) {
return 1;
} else if (sub < 0) {
return -1;
} else {
return 0;
}
}
int AscendCmp(const void *a, const void *b) { return ((const TopkNode *)a)->element - ((const TopkNode *)b)->element; }
int AscendCmp(const void *a, const void *b) {
float sub = ((const TopkNode *)a)->element - ((const TopkNode *)b)->element;
if (sub > 0) {
return 1;
} else if (sub < 0) {
return -1;
} else {
return 0;
}
}
void Topk(float *input_data, float *output_data, int32_t *output_index, TopkParameter *parameter) {
int last_dim_size = parameter->last_dim_size_;

View File

@ -20,9 +20,9 @@
static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); }
void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_;
const int pad_left = conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_;
const int pad_up = conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
const int stride_h = conv_param->stride_h_;
@ -72,9 +72,9 @@ void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param
// output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w)
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_;
const int pad_left = conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_;
const int pad_up = conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
const int stride_h = conv_param->stride_h_;

View File

@ -14,20 +14,15 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
#define MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_
#include "nnacl/op_base.h"
/* hw*inc4 X inc4*oc4 */
typedef struct StrassenMatMulParameter {
OpParameter op_parameter;
int row_; /* h * w */
int col_; /* oc4 / 4 */
int deep_; /* inc4 / 4 */
int a_stride_; /* h * w * 4 */
int b_stride_; /* inc4 * 4 */
int c_stride_; /* h * w * 4 */
} StrassenMatMulParameter;
typedef struct GatherParameter {
OpParameter op_parameter_;
int axis_;
int batchDims_;
} GatherParameter;
#endif // MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_GATHER_PARAMETER_H_

View File

@ -49,6 +49,10 @@ void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, co
size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel,
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int out_multiplier,
int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max);
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
int output_channel, int input_step, int8_t input_zp);
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
#endif
#ifdef __cplusplus

View File

@ -20,6 +20,99 @@
#include "nnacl/int8/common_func.h"
/*conv depthwise int8 begin*/
// only support perlayer
#ifndef ENABLE_ARM64
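/* scalar fallback: accumulate (input - input_zp) * weight into the int32 row buffer; an assembly version covers ARM64 */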
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
int output_channel, int input_step, int8_t input_zp) {
for (int i = 0; i < num_pixels; i++) {
for (int c = 0; c < output_channel; c++) {
const int16_t input = input_ptr[c] - input_zp;
*output_ptr++ += input * weight_ptr[c];
}
input_ptr += input_step;
}
}
#endif
void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max) {
int align_num = 0;
#ifdef ENABLE_ARM64
align_num = num_pixels / 4 * 4;
ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
#endif
for (int i = align_num; i < num_pixels; i++) {
buffer[i] = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift);
buffer[i] += output_zp;
buffer[i] = MSMAX(buffer[i], acc_min);
buffer[i] = MSMIN(buffer[i], acc_max);
dst[i] = (buffer[i]);
}
}
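
The scalar tail above is the usual gemmlowp-style fixed-point requantization. A minimal, self-contained sketch of the two helpers (local re-implementations for illustration, not the nnacl originals) with a made-up multiplier:

#include <stdint.h>
#include <stdio.h>

static int32_t SatRoundDoublingHighMul(int32_t a, int32_t b) {
  /* (a * b * 2) >> 31 with round-to-nearest; saturates only for a == b == INT32_MIN. */
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
  return (int32_t)((ab + nudge) / (1ll << 31));
}

static int32_t RoundDivByPOT(int32_t x, int exponent) {
  /* Divide by 2^exponent, rounding to nearest (ties away from zero). */
  int32_t mask = (1 << exponent) - 1;
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

int main(void) {
  int32_t acc = 23500;             /* hypothetical int32 accumulator */
  int32_t multiplier = 1412090771; /* made-up Q31 multiplier, roughly 0.6576 */
  int left_shift = 0, right_shift = 3;
  int32_t v = RoundDivByPOT(SatRoundDoublingHighMul(acc * (1 << left_shift), multiplier), right_shift);
  printf("requantized: %d\n", v);
  return 0;
}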
void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id) {
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
int h_start = h_step * task_id;
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
int out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
int left_shift = conv_param->conv_quant_arg_.left_shift_[0];
int right_shift = conv_param->conv_quant_arg_.right_shift_[0];
int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int acc_min = conv_param->conv_quant_arg_.out_act_min_[0];
int acc_max = conv_param->conv_quant_arg_.out_act_max_[0];
for (int b = 0; b < conv_param->output_batch_; b++) {
const int8_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
int8_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
for (int oh = h_start; oh < h_end; oh++) {
int8_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
// init acc
for (int ow = 0; ow < conv_param->output_w_; ow++) {
memcpy(row_buffer + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(int32_t));
}
for (int kh = start_kh; kh < end_kh; kh++) {
int ih = ih_origin + conv_param->dilation_h_ * kh;
const int8_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_;
const int16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_;
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
int out_w_start = MSMAX(
0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
conv_param->stride_w_);
int32_t *acc_w = row_buffer + out_w_start * conv_param->output_channel_;
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;
const int8_t *src_kw = src_kh + iw_origin * conv_param->input_channel_;
int num_pixels = out_w_end - out_w_start;
ConvDwInt8Row(acc_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step, input_zp);
weight_kh += conv_param->output_channel_;
}
}
// post func, acc int32 -> dst int8
ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_ * conv_param->output_channel_, output_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
}
}
}
/*conv depthwise int8 end*/
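
For orientation: ConvDwInt8 splits output rows across threads in UP_DIV-sized chunks. A standalone check of that partitioning, with the macros re-defined locally and made-up sizes:

#include <stdio.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  int output_h = 10, thread_num = 4; /* made-up sizes */
  int h_step = UP_DIV(output_h, thread_num);
  for (int task_id = 0; task_id < thread_num; task_id++) {
    int h_start = h_step * task_id;
    int h_end = MSMIN(h_start + h_step, output_h);
    printf("task %d handles rows [%d, %d)\n", task_id, h_start, h_end);
  }
  return 0;
}

The last task gets the remainder; with 10 rows and 4 threads the chunks are [0,3), [3,6), [6,9), [9,10).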
/*conv depthwise sliding window int8 begin*/
void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier,
int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max,
@ -68,14 +161,14 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
bool per_channel) {
int8_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const int16_t *src_h = src + ih * sliding->in_h_step_;
int8_t *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const int16_t *src_w = src_h + iw * sliding->block_channel_;
@ -153,8 +246,8 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
}
#endif
void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
const int16_t *src = input_data;
int8_t *dst = output_data;
bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
@ -186,8 +279,8 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
per_channel);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
@ -215,7 +308,7 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
} // batch loop
// output nhwc4
}
/*conv depthwise int8 end*/
/*conv depthwise sliding window int8 end*/
/*deconv depthwise int8 begin*/
void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width,
@ -241,14 +334,14 @@ void DeconvDepthwiseBorderInt8(int32_t *dst, const int16_t *src, const int16_t *
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
const int16_t *src_h = src + top * sliding->out_h_step_;
for (int ih = top; ih < bottom; ih++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
int32_t *dst_h = dst + oh * sliding->in_h_step_;
const int16_t *src_kernel = src_h + left * sliding->block_channel_;
for (int iw = left; iw < right; iw++) {
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
int32_t *dst_w = dst_h + ow * C4NUM;
@ -341,8 +434,8 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in
conv_param->input_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const int16_t *in_t =
src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

View File

@ -23,8 +23,12 @@
#ifdef __cplusplus
extern "C" {
#endif
void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,

View File

@ -28,7 +28,7 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
int oc4 = UP_DIV(output_channel, C4NUM);
#ifdef ENABLE_ARM64
size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
@ -36,6 +36,7 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
shift_before, shift_after, asymmetric, per_channel);
#else
int oc4 = UP_DIV(output_channel, C4NUM);
int tile_num = conv_param->tile_num_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
for (int oc = 0; oc < output_channel; oc++) {
@ -198,7 +199,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
}
}
void Conv3x3Uint8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
int oc4 = UP_DIV(oc, C4NUM);
#ifdef ENABLE_ARM64
IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t));
@ -263,7 +264,8 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
int output_tile_count = UP_DIV(output_count, tile_n);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int unit_size = kernel_plane * ic4 * C4NUM;
int plane_block = UP_DIV(kernel_plane, C4NUM);
int unit_size = plane_block * C4NUM * ic4 * C4NUM;
int packed_input_size = output_tile_count * tile_n * unit_size;
int input_sum_offset;
if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
@ -297,9 +299,10 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
out_channel, tmp_input_sum, conv_param);
} else {
// res part
IndirectGemmInt8(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
int8_t *tmp_out_ptr = tmp_out + task_id * tile_n * out_channel;
IndirectGemmInt8(tmp_out_ptr, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
out_channel, tmp_input_sum, conv_param);
memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel);
}
}
}
@ -359,14 +362,274 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
} else {
// res part
IndirectGemmInt8Opt(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
out_channel, tmp_input_sum, conv_param, gemm_func);
memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
int8_t *tmp_out_ptr = tmp_out + task_id * tile_n * out_channel;
IndirectGemmInt8Opt(tmp_out_ptr, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
kernel_plane, out_channel, tmp_input_sum, conv_param, gemm_func);
memcpy(output_data + out_offset, tmp_out_ptr, real_cal_num * out_channel);
}
}
}
}
void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, ConvParameter *conv_param) {
int ic4 = UP_ROUND(input_channel, C4NUM);
size_t hw_8div = plane_size / C8NUM * C8NUM;
size_t hw_8res = plane_size - hw_8div;
size_t ic_4div = input_channel / C4NUM * C4NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
/* per layer */
for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v10.4s, wzr \n"
"dup v11.4s, wzr \n"
"mov x20, %[input_sum_r] \n"
"dup v20.4s, %w[filter_zp] \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[ic_4div] \n"
"add x0, x0, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"ld1 {v1.s}[0], [x12], %[src_stride]\n"
"ld1 {v1.s}[1], [x12], %[src_stride]\n"
"ld1 {v1.s}[2], [x12], %[src_stride]\n"
"ld1 {v1.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 1b \n"
"3: \n" /* col res 1 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"ld1 {v1.b}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[8], [x12], %[src_stride]\n"
"ld1 {v1.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"
"4: \n" /* col res 2 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"
"5: \n" /* col res 3 */
"dup v0.4s, wzr \n"
"dup v1.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"ld1 {v1.h}[0], [x12], %[src_stride]\n"
"ld1 {v1.b}[2], [x13], %[src_stride]\n"
"ld1 {v1.h}[2], [x12], %[src_stride]\n"
"ld1 {v1.b}[6], [x13], %[src_stride]\n"
"ld1 {v1.h}[4], [x12], %[src_stride]\n"
"ld1 {v1.b}[10], [x13], %[src_stride]\n"
"ld1 {v1.h}[6], [x12], %[src_stride]\n"
"ld1 {v1.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v0.4s, v4.8h \n"
"saddlp v1.4s, v5.8h \n"
"add v10.4s, v10.4s, v0.4s \n"
"add v11.4s, v11.4s, v1.4s \n"
"b 6f \n"
"6: \n"
"mul v10.4s, v10.4s, v20.4s \n"
"mul v11.4s, v11.4s, v20.4s \n"
"st1 {v10.4s}, [x20], #16 \n"
"st1 {v11.4s}, [x20], #16 \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res),
[ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11",
"v20");
#else
int32_t tmp_sum_value[8] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C8NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}
for (int i = 0; i < C8NUM; i++) {
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
}
#endif
src_r += input_channel * C8NUM;
pack_r += ic4 * C8NUM;
}
if (hw_8div != plane_size) {
memset(pack_r, 0, C8NUM * ic4);
for (int hwi = hw_8div; hwi < plane_size; hwi += 1) {
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C8NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
input_sum[hwi] = tmp_sum_value * filter_zp;
src_r += input_channel;
pack_r += C4NUM;
}
for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) {
input_sum[hwi] = 0;
}
}
} else {
/* per channel */
RowMajor2Row4x8MajorInt8(src_input, packed_input, plane_size, input_channel);
PackInputSum8x4Int8(packed_input, input_sum, input_channel, output_channel, plane_size, conv_param);
}
return;
}
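
The per-layer branch above folds the weight zero point into input_sum so the inner matmul can use raw int8 products. A tiny numeric check of that identity, with made-up values:

#include <stdio.h>

int main(void) {
  signed char a[4] = {3, -7, 12, 5}; /* one packed input pixel */
  signed char w[4] = {20, -1, 4, 9}; /* one output channel's weights */
  int filter_zp = 2;                 /* hypothetical weight zero point */
  int raw = 0, sum_a = 0, corrected = 0;
  for (int c = 0; c < 4; c++) {
    raw += a[c] * w[c];
    sum_a += a[c];
    corrected += a[c] * (w[c] - filter_zp);
  }
  int input_sum = sum_a * filter_zp; /* what Conv1x1PreOpt stores per pixel */
  printf("%d == %d\n", raw - input_sum, corrected); /* both print 134 */
  return 0;
}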
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
MATMUL_OPT_R_FUNC matmul_func) {
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], false);
return;
}
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param) {
#ifdef ENABLE_ARM64
MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_);
#else
MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias,
conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_,
conv_param->conv_quant_arg_.quant_multiplier_,
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0], false);
#endif
return;
}
// int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
@ -391,15 +654,15 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
int start_index = thread_id * TILE_NUM;
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
Conv3x3Uint8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
Conv3x3Uint8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
transed_weight, output_channel, ic8, real_cal_num);
Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
transed_weight, output_channel, ic8, real_cal_num);
Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);
}
}
}

View File

@ -25,6 +25,8 @@
#include "nnacl/conv_parameter.h"
#include "nnacl/winograd_utils.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/int8/matmul_int8.h"
typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min,
@ -51,6 +53,15 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
int32_t *tmp_dst, int8_t *tmp_out, int8_t *output_data, int32_t *input_sum, int task_id,
ConvParameter *conv_param, GEMM_FUNC gemm_func);
// int8 convolution 1x1
void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, ConvParameter *conv_param);
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param);
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
MATMUL_OPT_R_FUNC matmul_func);
// int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,

View File

@ -33,8 +33,8 @@ int DeConvPostInt8C8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
@ -88,8 +88,8 @@ int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
@ -172,73 +172,7 @@ void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp,
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
bool support_opt) {
/* optimize normal -> same layout */
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src] \n"
"mov x11, %[dst] \n"
"dup v15.4s, %w[filter_zp] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[row4] \n"
"beq 4f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n"
"2: \n"
"cmp x2, %[col16] \n"
"beq 3f \n"
"add x2, x2, #16\n"
"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"
"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"
"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"
"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"
"add v10.4s, v10.4s, v0.4s \n"
"b 2b\n"
"3: \n"
"mul v10.4s, v10.4s, v15.4s \n"
"st1 {v10.4s}, [x11], #16 \n"
"beq 1b \n"
"4: \n"
:
: [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15");
#else
for (int r = 0; r < row4; r++) {
int32_t tmp_value = 0;
for (int c = 0; c < col16; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM;
int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
tmp_value += src[src_index];
}
}
#endif
PackInputSum16x4PerLayer(src, dst, filter_zp, row4, col16);
return;
}

View File

@ -29,8 +29,8 @@ int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64
}
int recip_shift;
const int32_t input1_inv = (input1_val > 0) ? ComputerReciproal(input1_val, 31, &recip_shift)
: -ComputerReciproal(-input1_val, 31, &recip_shift);
const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift)
: -ComputerReciprocal(-input1_val, 31, &recip_shift);
const int leading_bits = CountLeadingSignBits(input0_val);
const int32_t raw_data =
SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv);

View File

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/int8/gatherNd_int8.h"
#include <string.h>
#include "nnacl/errorcode.h"
int GatherNdInt8(int8_t *input, int8_t *output, int *in_offset, int area, int count, GatherQuantArg param) {
double alpha = param.alpha_;
int z1 = param.zp_in_;
int z2 = param.zp_out_;
for (int i = 0; i < count; ++i) {
for (int j = 0; j < area; ++j) {
int32_t tmp = round(alpha * (input[in_offset[i] + j] - z1)) + z2;
tmp = tmp > 127 ? 127 : tmp;
tmp = tmp < -128 ? -128 : tmp;
output[area * i + j] = (int8_t)tmp;
}
}
return NNACL_OK;
}
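
The inner statement is the standard int8 requantization out = clamp(round(alpha * (in - zp_in)) + zp_out) with alpha = in_scale / out_scale. A standalone sketch with made-up quant parameters:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

static int8_t RequantInt8(int8_t in, double alpha, int zp_in, int zp_out) {
  int32_t tmp = (int32_t)round(alpha * (in - zp_in)) + zp_out;
  tmp = tmp > 127 ? 127 : tmp;
  tmp = tmp < -128 ? -128 : tmp;
  return (int8_t)tmp;
}

int main(void) {
  /* in_scale 0.05, out_scale 0.1 -> alpha 0.5; zero points made up */
  printf("%d\n", RequantInt8(90, 0.5, 10, -4)); /* round(0.5 * 80) - 4 = 36 */
  return 0;
}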

View File

@ -0,0 +1,31 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#ifdef __cplusplus
extern "C" {
#endif
int GatherNdInt8(int8_t *in_data, int8_t *out_data, int *in_offset, int area, int count, GatherQuantArg param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHERND_INT8_H_

View File

@ -0,0 +1,44 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#include "nnacl/int8/gather_int8.h"
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/errorcode.h"
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, int *indices,
int indices_element_size, GatherQuantArg para) {
double alpha = para.alpha_;
int z1 = para.zp_in_;
int z2 = para.zp_out_;
int i, m, j;
for (m = 0; m < outer_size; ++m) {
const int8_t *inputm = in_data + inner_size * m * limit;
int8_t *outputm = out_data + inner_size * m * indices_element_size;
for (i = 0; i < indices_element_size; ++i) {
if (indices[i] < 0 || indices[i] >= limit) {
return NNACL_ERR;
}
for (j = 0; j < inner_size; ++j) {
int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2;
tmp = tmp > 127 ? 127 : tmp;
tmp = tmp < -128 ? -128 : tmp;
outputm[i * inner_size + j] = (int8_t)tmp;
}
}
}
return NNACL_OK;
}
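
A hedged usage sketch of the outer/inner decomposition: for a tensor of shape [2, 4, 3] gathered along axis 1, outer_size is 2 (dims before the axis), limit is 4 (the axis itself) and inner_size is 3 (dims after it). Assumes the nnacl header above is on the include path and that GatherQuantArg exposes only the fields used here:

#include <stdint.h>
#include <stdio.h>
#include "nnacl/int8/gather_int8.h" /* assumed include path */

int main(void) {
  int8_t in[24];
  for (int i = 0; i < 24; i++) in[i] = (int8_t)i;
  int8_t out[12];
  int indices[2] = {3, 0}; /* gather rows 3 and 0 along axis 1 */
  GatherQuantArg q = {0};
  q.alpha_ = 1.0; /* identity requantization: same scale and zero points */
  q.zp_in_ = 0;
  q.zp_out_ = 0;
  GatherInt8(in, out, 2, 3, 4, indices, 2, q);
  for (int i = 0; i < 12; i++) printf("%d ", out[i]); /* 9..11 0..2 21..23 12..14 */
  printf("\n");
  return 0;
}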

View File

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
#ifdef __cplusplus
extern "C" {
#endif
int GatherInt8(int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, int *indices,
int indices_element_size, GatherQuantArg para);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_GATHER_INT8_H_

View File

@ -28,6 +28,36 @@ void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col)
}
}
void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
int col16 = UP_ROUND(col, C16NUM);
for (int r = 0; r < row; r++) {
int rd4 = r / C4NUM;
int rm4 = r % C4NUM;
for (int c = 0; c < col; c++) {
int cd16 = c / C16NUM;
int cm16 = c % C16NUM;
int dst_index = rd4 * col16 * C4NUM + cd16 * C4NUM * C16NUM + rm4 * C16NUM + cm16;
int src_index = r * col + c;
dst_ptr[dst_index] = src_ptr[src_index];
}
}
}
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
int col4 = UP_ROUND(col, C4NUM);
for (int r = 0; r < row; r++) {
int rd8 = r / C8NUM;
int rm8 = r % C8NUM;
for (int c = 0; c < col; c++) {
int cd4 = c / C4NUM;
int cm4 = c % C4NUM;
int dst_index = rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4;
int src_index = r * col + c;
dst_ptr[dst_index] = src_ptr[src_index];
}
}
}
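
The index math above interleaves 8 rows by 4 columns per block, zero-padding rows to multiples of 8 and columns to multiples of 4. A standalone demo of the same math on a tiny 2x5 matrix:

#include <stdio.h>
#include <string.h>

#define C4NUM 4
#define C8NUM 8
#define UP_ROUND(x, y) (((x) + (y) - 1) / (y) * (y))

int main(void) {
  enum { kRow = 2, kCol = 5 };
  signed char src[kRow * kCol], dst[C8NUM * UP_ROUND(kCol, C4NUM)];
  for (int i = 0; i < kRow * kCol; i++) src[i] = (signed char)(i + 1);
  memset(dst, 0, sizeof(dst)); /* padding lanes stay zero */
  int col4 = UP_ROUND(kCol, C4NUM);
  for (int r = 0; r < kRow; r++) {
    for (int c = 0; c < kCol; c++) {
      int rd8 = r / C8NUM, rm8 = r % C8NUM;
      int cd4 = c / C4NUM, cm4 = c % C4NUM;
      dst[rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4] = src[r * kCol + c];
    }
  }
  for (int i = 0; i < (int)sizeof(dst); i++) printf("%d ", dst[i]);
  printf("\n");
  return 0;
}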
void MatrixPack4x16UnitInt8(int8_t *src, int8_t *dst, int row, int col, int stride) {
for (int r = 0; r < row; r++) {
int8_t *src_r = src + r * stride;
@ -37,6 +67,29 @@ void MatrixPack4x16UnitInt8(int8_t *src, int8_t *dst, int row, int col, int stri
return;
}
void MatrixEmptyInt8(int8_t *dst, int row, int col) {
for (int r = 0; r < row; r++) {
int8_t *dst_r = dst + r * C16NUM;
memset(dst_r, 0, col * sizeof(int8_t));
}
return;
}
void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
/* Row-major to row8x4-major (block row-major) */
int col4 = UP_ROUND(col, C4NUM);
for (int r = 0; r < row; r++) {
int rd8 = r / C8NUM, rm8 = r % C8NUM;
for (int c = 0; c < col; c++) {
int cd4 = c / C4NUM, cm4 = c % C4NUM;
int src_index = r * col + c;
int dst_index = rd8 * col4 * C8NUM + cd4 * C4NUM * C8NUM + rm8 * C4NUM + cm4;
dst_ptr[dst_index] = src_ptr[src_index];
}
}
return;
}
void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
/* Row-major to row16x4-major (block row-major) */
int col16 = UP_ROUND(col, C16NUM);
@ -50,16 +103,17 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
for (int ri = 0; ri < row_4div; ri += C4NUM) {
for (int ci = 0; ci < col_16div; ci += C16NUM) {
#ifdef ENABLE_ARM64
size_t col_offset = col;
int8_t *src_c = src_r + ci;
int8_t *dst_c = dst_r + ci * C4NUM;
asm volatile(
"mov x10, %[src_c] \n"
"mov x11, %[dst_c] \n"
"ld1 {v0.16b}, [x10], %[col]\n"
"ld1 {v1.16b}, [x10], %[col]\n"
"ld1 {v2.16b}, [x10], %[col]\n"
"ld1 {v3.16b}, [x10], %[col]\n"
"ld1 {v0.16b}, [x10], %[col_offset]\n"
"ld1 {v1.16b}, [x10], %[col_offset]\n"
"ld1 {v2.16b}, [x10], %[col_offset]\n"
"ld1 {v3.16b}, [x10], %[col_offset]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
@ -67,7 +121,7 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
"st1 {v3.16b}, [x11], #16\n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col ] "r"(col)
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset)
: "x10", "x11", "v0", "v1", "v2", "v3");
#else
MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, C4NUM, C16NUM, col);
@ -76,12 +130,15 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
if (col != col_16div) {
MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col);
MatrixEmptyInt8(dst_r + col_16div * C4NUM + col_16res, C4NUM, C16NUM - col_16res);
}
src_r += C4NUM * col;
dst_r += C4NUM * col16;
}
if (row != row_4div) {
memset(dst_r, 0, C4NUM * col16);
for (int ci = 0; ci < col_16div; ci += C16NUM) {
MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col);
}
@ -103,25 +160,6 @@ void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col)
}
}
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
const int32_t a_zp, const int32_t b_zp) {
/* col8-major * row8-major => row8x8-major */
for (int row = 0; row < row8; row++) {
for (int col = 0; col < col8; col++) {
int r8div = row / 8, r8mod = row % 8;
int c8div = col / 8, c8mod = col % 8;
size_t ci = c8div * row8 * 8 + row * 8 + c8mod;
int32_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + ((int32_t)a[ai] - a_zp) * ((int32_t)b[bi] - b_zp);
}
c[ci] = value;
}
}
}
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias) {
/* row4x16-major * row16x4-major => row4x4-major */
@ -145,7 +183,100 @@ void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int
return;
}
#ifdef ENABLE_ARM64
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel) {
/* row4x16-major * row16x4-major => (int8)row-major : per-channel */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c4div = c / C4NUM, c4mod = c % C4NUM;
size_t ci = r * stride + c;
int32_t value = 0;
for (int d = 0; d < deep_16; d++) {
int d16div = d / C16NUM, d16mod = d % C16NUM;
size_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = per_channel ? input_sum[c4div * UP_ROUND(row, C4NUM) + r * C4NUM + c4mod] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
value = MSMIN(maxi, value);
value = MSMAX(mini, value);
dst[ci] = (int8_t)value;
}
}
return;
}
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel) {
/* row8x4-major * row4x8-major => (int8)row-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r8div = r / C8NUM, r8mod = r % C8NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
size_t ci = r * stride + c;
int32_t value = 0;
for (int d = 0; d < deep_4; d++) {
int d4div = d / C4NUM, d4mod = d % C4NUM;
size_t ai = r8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + r8mod * C4NUM + d4mod;
size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) + r * C8NUM + c8mod] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
value = MSMIN(maxi, value);
value = MSMAX(mini, value);
dst[ci] = (int8_t)value;
}
}
return;
}
/* row4x16-major * col16x4-major => (int8)row-major */
void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
int stride) {
int8_t *output = dst;
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM;
int r4mod = r % C4NUM;
int c4div = c / C4NUM;
int c4mod = c % C4NUM;
int value = 0;
for (int d = 0; d < deep16; d++) {
int d16div = d / C16NUM;
int d16mod = d % C16NUM;
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value += a[ai] * b[bi];
}
value -= a_sums[r];
value += bias[c];
value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + out_zp;
value = MSMIN(INT8_MAX, value);
value = MSMAX(INT8_MIN, value);
output[c] = (int8_t)value;
}
output += stride;
}
}
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) {
int stride = sizeof(int8_t) * 16 * 4;
for (int r = 0; r < row; ++r) {
@ -168,23 +299,35 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_1
}
}
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst) {
// dst: weight_zp * input_row_sums
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
for (int r = 0; r < row; ++r) {
int sum = 0;
for (int c = 0; c < col; ++c) {
int src_idx = r * col + c;
dst[r] += a[src_idx];
if (order == RowMajor) {
sum += input[r * col + c];
} else {
sum += input[c * row + r];
}
}
dst[r] *= b_zp;
sum *= weight_zp;
dst[r] = sum;
}
}
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst) {
// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
DataOrder order) {
for (int c = 0; c < col; ++c) {
int sum = 0;
for (int r = 0; r < row; ++r) {
int src_idx = r * col + c;
dst[c] += b[src_idx];
if (order == RowMajor) {
sum += weight[r * col + c];
} else {
sum += weight[c * row + r];
}
}
dst[c] = row * a_zp * b_zp - a_zp * dst[c];
dst[c] = row * input_zp * weight_zp - input_zp * sum;
if (bias) {
dst[c] += bias[c];
}
@ -201,4 +344,3 @@ void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow)
}
}
}
#endif
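
A numeric check of the zero-point identity that CalcInputSums and CalcWeightBiasSums rely on, with made-up int8 data: sum((a - za)(b - zb)) = sum(a*b) - zb*sum(a) + (depth*za*zb - za*sum(b)). The middle term is the per-row input sum; the bracketed term is the per-column bias correction.

#include <stdio.h>

int main(void) {
  signed char a[3] = {5, -2, 7}, b[3] = {-1, 4, 3};
  int za = 3, zb = -2, depth = 3; /* made-up zero points and depth */
  int lhs = 0, raw = 0, sum_a = 0, sum_b = 0;
  for (int d = 0; d < depth; d++) {
    lhs += (a[d] - za) * (b[d] - zb);
    raw += a[d] * b[d];
    sum_a += a[d];
    sum_b += b[d];
  }
  int input_sum = zb * sum_a;                   /* CalcInputSums, per row */
  int bias_corr = depth * za * zb - za * sum_b; /* CalcWeightBiasSums, no bias */
  printf("%d == %d\n", lhs, raw - input_sum + bias_corr); /* both print -8 */
  return 0;
}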

View File

@ -24,25 +24,37 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep,
const int a_zp, const int b_zp);
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel);
void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
#ifdef ENABLE_ARM64
void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel);
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst);
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow);
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
DataOrder order);
void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
int stride);
// bias = bias + depth * a_zp * b_zp - a_zp * b_sums
#ifdef ENABLE_ARM64
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
int right_shift);
int right_shift, int row, int col, int stride);
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);

View File

@ -89,8 +89,13 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int c8 = UP_DIV(channel, C8NUM);
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
float input_scale = pooling_param->quant_args_[0][0].scale_;
int input_zp = pooling_param->quant_args_[0][0].zp_;
float output_scale = pooling_param->quant_args_[1][0].scale_;
int output_zp = pooling_param->quant_args_[1][0].zp_;
double real_multiplier = input_scale / output_scale;
int c16 = channel / C16NUM;
const int8_t out_min = INT8_MIN;
const int8_t out_max = INT8_MAX;
@ -107,89 +112,159 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int out_plane_offset = out_batch_offset + index * channel;
for (int j = 0; j < c8 - 1; j++) {
int in_channel_offset = in_batch_offset + j * C8NUM;
int out_channel_offset = out_plane_offset + j * C8NUM;
int16_t tmp_avg1 = 0;
int16_t tmp_avg2 = 0;
int16_t tmp_avg3 = 0;
int16_t tmp_avg4 = 0;
int16_t tmp_avg5 = 0;
int16_t tmp_avg6 = 0;
int16_t tmp_avg7 = 0;
int16_t tmp_avg8 = 0;
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_avg1 += *(input_ptr + in_offset);
tmp_avg2 += *(input_ptr + in_offset + 1);
tmp_avg3 += *(input_ptr + in_offset + 2);
tmp_avg4 += *(input_ptr + in_offset + 3);
tmp_avg5 += *(input_ptr + in_offset + 4);
tmp_avg6 += *(input_ptr + in_offset + 5);
tmp_avg7 += *(input_ptr + in_offset + 6);
tmp_avg8 += *(input_ptr + in_offset + 7);
++real_count;
int input_stride = (in_h_index * in_w + in_w_index) * channel;
int kw_s = MSMAX(0, -in_w_index);
int kw_e = MSMIN(win_w, in_w - in_w_index);
int kh_s = MSMAX(0, -in_h_index);
int kh_e = MSMIN(win_h, in_h - in_h_index);
int real_count = (kw_e - kw_s) * (kh_e - kh_s);
// 16 channels
for (int j = 0; j < c16; j++) {
#ifdef ENABLE_NEON
int16x8_t tmp_avg[2];
tmp_avg[0] = vmovq_n_s16(0);
tmp_avg[1] = vmovq_n_s16(0);
#else
int16_t tmp_avg[16];
int16_t real_out[16];
for (int m = 0; m < C16NUM; ++m) {
tmp_avg[m] = 0;
}
#endif
int in_channel_offset = in_batch_offset + j * C16NUM;
int out_channel_offset = out_plane_offset + j * C16NUM;
for (int h = kh_s; h < kh_e; h++) {
for (int w = kw_s; w < kw_e; w++) {
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
#ifdef ENABLE_NEON
int8x16_t in_ptr = vld1q_s8(input_ptr + in_offset);
int8x8_t in_data1 = vget_low_s8(in_ptr);
int8x8_t in_data2 = vget_high_s8(in_ptr);
int16x8_t data1 = vmovl_s8(in_data1);
int16x8_t data2 = vmovl_s8(in_data2);
tmp_avg[0] = vaddq_s16(tmp_avg[0], data1);
tmp_avg[1] = vaddq_s16(tmp_avg[1], data2);
#else
for (int k = 0; k < C16NUM; ++k) {
tmp_avg[k] += input_ptr[in_offset + k];
}
#endif
} // win_w loop
} // win_h loop
int16_t tmp_out1 = round((float)tmp_avg1 / (float)real_count);
int16_t tmp_out2 = round((float)tmp_avg2 / (float)real_count);
int16_t tmp_out3 = round((float)tmp_avg3 / (float)real_count);
int16_t tmp_out4 = round((float)tmp_avg4 / (float)real_count);
int16_t tmp_out5 = round((float)tmp_avg5 / (float)real_count);
int16_t tmp_out6 = round((float)tmp_avg6 / (float)real_count);
int16_t tmp_out7 = round((float)tmp_avg7 / (float)real_count);
int16_t tmp_out8 = round((float)tmp_avg8 / (float)real_count);
int16_t real_out1 = tmp_out1 < out_min ? out_min : tmp_out1;
int16_t real_out2 = tmp_out2 < out_min ? out_min : tmp_out2;
int16_t real_out3 = tmp_out3 < out_min ? out_min : tmp_out3;
int16_t real_out4 = tmp_out4 < out_min ? out_min : tmp_out4;
int16_t real_out5 = tmp_out5 < out_min ? out_min : tmp_out5;
int16_t real_out6 = tmp_out6 < out_min ? out_min : tmp_out6;
int16_t real_out7 = tmp_out7 < out_min ? out_min : tmp_out7;
int16_t real_out8 = tmp_out8 < out_min ? out_min : tmp_out8;
real_out1 = real_out1 > out_max ? out_max : real_out1;
real_out2 = real_out2 > out_max ? out_max : real_out2;
real_out3 = real_out3 > out_max ? out_max : real_out3;
real_out4 = real_out4 > out_max ? out_max : real_out4;
real_out5 = real_out5 > out_max ? out_max : real_out5;
real_out6 = real_out6 > out_max ? out_max : real_out6;
real_out7 = real_out7 > out_max ? out_max : real_out7;
real_out8 = real_out8 > out_max ? out_max : real_out8;
*(output_ptr + out_channel_offset) = (int8_t)real_out1;
*(output_ptr + out_channel_offset + 1) = (int8_t)real_out2;
*(output_ptr + out_channel_offset + 2) = (int8_t)real_out3;
*(output_ptr + out_channel_offset + 3) = (int8_t)real_out4;
*(output_ptr + out_channel_offset + 4) = (int8_t)real_out5;
*(output_ptr + out_channel_offset + 5) = (int8_t)real_out6;
*(output_ptr + out_channel_offset + 6) = (int8_t)real_out7;
*(output_ptr + out_channel_offset + 7) = (int8_t)real_out8;
} // in_channel loop
int channel_s = (c8 - 1) * C8NUM;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
#ifdef ENABLE_NEON
int16_t tmp_data[8];
int16_t tmp_out[8];
int16_t tmp_data1[8];
int16_t tmp_out1[8];
for (int l = 0; l < C8NUM; l++) {
tmp_data[l] = tmp_avg[0][l] + 128 * real_count;
tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count;
tmp_out[l] -= 128;
tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp;
}
for (int l = 0; l < C8NUM; l++) {
tmp_data1[l] = tmp_avg[1][l] + 128 * real_count;
tmp_out1[l] = (tmp_data1[l] + real_count / 2) / real_count;
tmp_out1[l] -= 128;
tmp_out1[l] = round((tmp_out1[l] - input_zp) * real_multiplier) + output_zp;
}
int8x8_t real_out[2];
int8x8_t output_min = vdup_n_s8(out_min);
int8x8_t output_max = vdup_n_s8(out_max);
real_out[0] = vqmovn_s16(vld1q_s16(tmp_out));
real_out[0] = vmin_s8(real_out[0], output_max);
real_out[0] = vmax_s8(real_out[0], output_min);
vst1_s8(output_ptr + out_channel_offset, real_out[0]);
real_out[1] = vqmovn_s16(vld1q_s16(tmp_out1));
real_out[1] = vmin_s8(real_out[1], output_max);
real_out[1] = vmax_s8(real_out[1], output_min);
vst1_s8(output_ptr + out_channel_offset + 8, real_out[1]);
#else
for (int l = 0; l < C16NUM; ++l) {
int16_t tmp_data = tmp_avg[l] + 128 * real_count;
real_out[l] = (tmp_data + real_count / 2) / real_count - 128;
real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp);
real_out[l] = real_out[l] < out_min ? out_min : real_out[l];
real_out[l] = real_out[l] > out_max ? out_max : real_out[l];
*(output_ptr + out_channel_offset + l) = (int8_t)real_out[l];
}
#endif
}
// 8 channels
int channel_16_res = channel - c16 * C16NUM;
int c8 = channel_16_res / C8NUM;
int in_c16_offset = in_batch_offset + c16 * C16NUM;
int out_c16_offset = out_plane_offset + c16 * C16NUM;
for (int j = 0; j < c8; j++) {
#ifdef ENABLE_NEON
int16x8_t tmp_avg = vmovq_n_s16(0);
#else
int16_t tmp_avg[8] = {0, 0, 0, 0, 0, 0, 0, 0};
int16_t real_out[8];
#endif
int in_channel_offset = in_c16_offset + j * C8NUM;
int out_channel_offset = out_c16_offset + j * C8NUM;
for (int h = kh_s; h < kh_e; h++) {
for (int w = kw_s; w < kw_e; w++) {
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
#ifdef ENABLE_NEON
int8x8_t in_ptr = vld1_s8(input_ptr + in_offset);
int16x8_t data = vmovl_s8(in_ptr);
tmp_avg = vaddq_s16(tmp_avg, data);
#else
for (int k = 0; k < C8NUM; ++k) {
tmp_avg[k] += input_ptr[in_offset + k];
}
#endif
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
int16_t tmp_data[8];
int16_t tmp_out[8];
for (int l = 0; l < C8NUM; l++) {
tmp_data[l] = tmp_avg[l] + 128 * real_count;
tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count;
tmp_out[l] -= 128;
tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp;
}
int8x8_t real_out;
int8x8_t output_min = vdup_n_s8(out_min);
int8x8_t output_max = vdup_n_s8(out_max);
real_out = vqmovn_s16(vld1q_s16(tmp_out));
real_out = vmin_s8(real_out, output_max);
real_out = vmax_s8(real_out, output_min);
vst1_s8(output_ptr + out_channel_offset, real_out);
#else
for (int l = 0; l < C8NUM; ++l) {
int16_t tmp_data = tmp_avg[l] + 128 * real_count;
real_out[l] = (tmp_data + real_count / 2) / real_count - 128;
real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp);
real_out[l] = real_out[l] < out_min ? out_min : real_out[l];
real_out[l] = real_out[l] > out_max ? out_max : real_out[l];
*(output_ptr + out_channel_offset + l) = (int8_t)real_out[l];
}
#endif
}
// less than 8 channel
int channel_8_res = channel_16_res - c8 * C8NUM;
int in_c8_offset = in_c16_offset + c8 * C8NUM;
int out_c8_offset = out_c16_offset + c8 * C8NUM;
for (int k = 0; k < channel_8_res; k++) {
int in_channel_offset = in_c8_offset + k;
int out_channel_offset = out_c8_offset + k;
int16_t tmp_avg = 0;
int real_count = 0;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_avg += *(input_ptr + in_offset);
++real_count;
}
for (int h = kh_s; h < kh_e; h++) {
for (int w = kw_s; w < kw_e; w++) {
int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel;
tmp_avg += input_ptr[in_offset];
} // win_w loop
} // win_h loop
int16_t tmp_out = round((float)tmp_avg / (float)real_count);
int16_t tmp_out = round((float)tmp_avg / (float)real_count + 128) - 128;
tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp);
int16_t real_out = tmp_out < out_min ? out_min : tmp_out;
real_out = real_out > out_max ? out_max : real_out;
*(output_ptr + out_channel_offset) = (int8_t)real_out;
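
On the +128*count offset used throughout this kernel: C integer division truncates toward zero, so (sum + count/2)/count only rounds to nearest for non-negative sums; biasing each int8 sample by 128 keeps the numerator non-negative, and the 128 is subtracted back afterwards. A tiny check with a negative window sum:

#include <stdio.h>

int main(void) {
  int sum = -13, count = 4; /* true average is -3.25 */
  int naive = (sum + count / 2) / count;                     /* -11/4 truncates to -2 */
  int fixed = (sum + 128 * count + count / 2) / count - 128; /* 501/4 = 125; 125 - 128 = -3 */
  printf("naive=%d fixed=%d\n", naive, fixed);
  return 0;
}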
@ -249,6 +324,109 @@ void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete
} // out_batch loop
}
void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param,
int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
int c16 = UP_DIV(channel, 16);
// input channel is equal to output channel
float input_scale = pooling_param->quant_args_[0][0].scale_;
int input_zp = pooling_param->quant_args_[0][0].zp_;
float output_scale = pooling_param->quant_args_[1][0].scale_;
int output_zp = pooling_param->quant_args_[1][0].zp_;
double real_multiplier = input_scale / output_scale;
for (int batch = 0; batch < output_batch; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
int out_batch_offset = batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int out_plane_offset = out_batch_offset + index * channel;
for (int j = 0; j < c16 - 1; j++) {
int in_channel_offset = in_batch_offset + j * 16;
int out_channel_offset = out_plane_offset + j * 16;
#ifdef ENABLE_NEON
int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
#else
int8_t tmp_max[16];
for (int m = 0; m < C16NUM; ++m) {
tmp_max[m] = INT8_MIN;
}
#endif
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
#ifdef ENABLE_NEON
tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
#else
for (int k = 0; k < C16NUM; ++k) {
tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k));
}
#endif
}
} // win_w loop
} // win_h loop
#ifdef ENABLE_NEON
for (int l = 0; l < C16NUM; ++l) {
tmp_max[l] = (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp);
}
vst1q_s8(output_ptr + out_channel_offset, tmp_max);
#else
for (int l = 0; l < C16NUM; ++l) {
*(output_ptr + out_channel_offset + l) =
(int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp);
}
#endif
} // in_channel loop
// res channel
int channel_s = (c16 - 1) * 16;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
int8_t tmp_max = INT8_MIN;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
} else {
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset));
}
} // win_w loop
} // win_h loop
*(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp);
} // channel_res loop
} // out_plane loop
} // out_tile_count loop
} // out_batch loop
}
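All of the quantized pooling kernels above share the same requantization step: subtract the input zero point, rescale by input_scale / output_scale, round, re-add the output zero point, and clamp to the activation range. A minimal scalar sketch of that step, with illustrative names (RequantInt8 is not an nnacl function):

#include <math.h>
#include <stdint.h>

// Minimal scalar sketch of the requantization used by the pooling kernels
// above; multiplier is input_scale / output_scale, as in real_multiplier.
static int8_t RequantInt8(int8_t v, int32_t in_zp, int32_t out_zp,
                          double multiplier, int8_t out_min, int8_t out_max) {
  int32_t out = (int32_t)round((v - in_zp) * multiplier) + out_zp;
  if (out < out_min) out = out_min;
  if (out > out_max) out = out_max;
  return (int8_t)out;
}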
void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
@ -264,7 +442,7 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
int c16 = UP_DIV(channel, 16);
for (int batch = 0; batch < output_batch; batch++) {
@ -286,22 +464,10 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
#ifdef ENABLE_NEON
int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
#else
int8_t tmp_max1 = INT8_MIN;
int8_t tmp_max2 = INT8_MIN;
int8_t tmp_max3 = INT8_MIN;
int8_t tmp_max4 = INT8_MIN;
int8_t tmp_max5 = INT8_MIN;
int8_t tmp_max6 = INT8_MIN;
int8_t tmp_max7 = INT8_MIN;
int8_t tmp_max8 = INT8_MIN;
int8_t tmp_max9 = INT8_MIN;
int8_t tmp_max10 = INT8_MIN;
int8_t tmp_max11 = INT8_MIN;
int8_t tmp_max12 = INT8_MIN;
int8_t tmp_max13 = INT8_MIN;
int8_t tmp_max14 = INT8_MIN;
int8_t tmp_max15 = INT8_MIN;
int8_t tmp_max16 = INT8_MIN;
int8_t tmp_max[16];
for (int m = 0; m < C16NUM; ++m) {
tmp_max[m] = INT8_MIN;
}
#endif
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
@ -313,22 +479,9 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
#ifdef ENABLE_NEON
tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
#else
tmp_max1 = MaxInt8(tmp_max1, *(input_ptr + in_offset));
tmp_max2 = MaxInt8(tmp_max2, *(input_ptr + in_offset + 1));
tmp_max3 = MaxInt8(tmp_max3, *(input_ptr + in_offset + 2));
tmp_max4 = MaxInt8(tmp_max4, *(input_ptr + in_offset + 3));
tmp_max5 = MaxInt8(tmp_max5, *(input_ptr + in_offset + 4));
tmp_max6 = MaxInt8(tmp_max6, *(input_ptr + in_offset + 5));
tmp_max7 = MaxInt8(tmp_max7, *(input_ptr + in_offset + 6));
tmp_max8 = MaxInt8(tmp_max8, *(input_ptr + in_offset + 7));
tmp_max9 = MaxInt8(tmp_max9, *(input_ptr + in_offset + 8));
tmp_max10 = MaxInt8(tmp_max10, *(input_ptr + in_offset + 9));
tmp_max11 = MaxInt8(tmp_max11, *(input_ptr + in_offset + 10));
tmp_max12 = MaxInt8(tmp_max12, *(input_ptr + in_offset + 11));
tmp_max13 = MaxInt8(tmp_max13, *(input_ptr + in_offset + 12));
tmp_max14 = MaxInt8(tmp_max14, *(input_ptr + in_offset + 13));
tmp_max15 = MaxInt8(tmp_max15, *(input_ptr + in_offset + 14));
tmp_max16 = MaxInt8(tmp_max16, *(input_ptr + in_offset + 15));
for (int k = 0; k < C16NUM; ++k) {
tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k));
}
#endif
}
} // win_w loop
@ -336,24 +489,13 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
#ifdef ENABLE_NEON
vst1q_s8(output_ptr + out_channel_offset, tmp_max);
#else
*(output_ptr + out_channel_offset) = tmp_max1;
*(output_ptr + out_channel_offset + 1) = tmp_max2;
*(output_ptr + out_channel_offset + 2) = tmp_max3;
*(output_ptr + out_channel_offset + 3) = tmp_max4;
*(output_ptr + out_channel_offset + 4) = tmp_max5;
*(output_ptr + out_channel_offset + 5) = tmp_max6;
*(output_ptr + out_channel_offset + 6) = tmp_max7;
*(output_ptr + out_channel_offset + 7) = tmp_max8;
*(output_ptr + out_channel_offset + 8) = tmp_max9;
*(output_ptr + out_channel_offset + 9) = tmp_max10;
*(output_ptr + out_channel_offset + 10) = tmp_max11;
*(output_ptr + out_channel_offset + 11) = tmp_max12;
*(output_ptr + out_channel_offset + 12) = tmp_max13;
*(output_ptr + out_channel_offset + 13) = tmp_max14;
*(output_ptr + out_channel_offset + 14) = tmp_max15;
*(output_ptr + out_channel_offset + 15) = tmp_max16;
for (int l = 0; l < C16NUM; ++l) {
*(output_ptr + out_channel_offset + l) = tmp_max[l];
}
#endif
} // in_channel loop
// res channel
int channel_s = (c16 - 1) * 16;
for (int k = channel_s; k < channel; k++) {
int in_channel_offset = in_batch_offset + k;

View File

@ -32,6 +32,8 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id);
#ifdef __cplusplus
}

View File

@ -86,6 +86,62 @@ int ResizeBilinearInt8(const int8_t *input_data, int8_t *output_data, const int
return NNACL_OK;
}
int ResizeBilinearInt8WithFloatWeight(const int8_t *input_data, int8_t *output_data, const int *input_shape,
const int *output_shape, const bool align_corners, QuantArg *quant_in,
QuantArg *quant_out, const QuantMulArg *mul_arg, int tid, int thread_num) {
if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL) {
return NNACL_NULL_PTR;
}
int32_t in_n = input_shape[0];
int32_t in_h = input_shape[1];
int32_t in_w = input_shape[2];
int32_t in_c = input_shape[3];
int32_t new_height = output_shape[1];
int32_t new_width = output_shape[2];
float height_scale, width_scale;
ComputeScaleFloat(in_h, new_height, align_corners, &height_scale);
ComputeScaleFloat(in_w, new_width, align_corners, &width_scale);
int n, h, w, c;
for (n = 0; n < in_n; n++) {
for (h = tid; h < new_height; h += thread_num) {
float actual_y;
int bottom, top;
float bottom_weight, top_weight;
ComputeInterpolationArgsFloatWeight(h, height_scale, in_h, &actual_y, &bottom, &bottom_weight, &top, &top_weight);
for (w = 0; w < new_width; w++) {
float actual_x;
int left, right;
float left_weight, right_weight;
ComputeInterpolationArgsFloatWeight(w, width_scale, in_w, &actual_x, &left, &left_weight, &right,
&right_weight);
for (c = 0; c < in_c; c++) {
float bottom_left_value = ((int32_t)input_data[offset(input_shape, n, bottom, left, c)] - quant_in->zp_) *
bottom_weight * left_weight;
float bottom_right_value = ((int32_t)input_data[offset(input_shape, n, bottom, right, c)] - quant_in->zp_) *
bottom_weight * right_weight;
float top_left_value =
((int32_t)input_data[offset(input_shape, n, top, left, c)] - quant_in->zp_) * top_weight * left_weight;
float top_right_value =
((int32_t)input_data[offset(input_shape, n, top, right, c)] - quant_in->zp_) * top_weight * right_weight;
float interp_value = bottom_left_value + bottom_right_value + top_left_value + top_right_value;
const int out_interp_value = MultiplyByQuantizedMultiplier((int32_t)interp_value, mul_arg->multiplier_,
mul_arg->left_shift_, mul_arg->right_shift_) +
quant_out->zp_;
int8_t out_value;
out_value = out_interp_value > INT8_MAX ? INT8_MAX : out_interp_value;
out_value = out_value < INT8_MIN ? INT8_MIN : out_value;
output_data[offset(output_shape, n, h, w, c)] = out_value;
}
}
}
}
return NNACL_OK;
}
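For reference, offset() above computes a flat NHWC index; a plain restatement under the assumption that shape = {N, H, W, C}, matching how input_shape is unpacked at the top of the function (offset_nhwc is an illustrative name, not the nnacl helper itself):

// Flat NHWC index assumed by offset(shape, n, h, w, c):
// ((n * H + h) * W + w) * C + c
static int offset_nhwc(const int *shape, int n, int h, int w, int c) {
  return ((n * shape[1] + h) * shape[2] + w) * shape[3] + c;
}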
int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int *input_shape,
const int *output_shape, const bool align_corners, int tid, int thread_num) {
int batch, y, x, c;
@ -133,6 +189,22 @@ void ComputeInterpolationArgs(const int32_t pos, const int32_t scale, const int3
*scaled_high_weight = *scaled_pos - (1 << 10) * (*low);
}
void ComputeScaleFloat(const int32_t in_value, const int32_t out_value, const bool align_corners, float *scale) {
*scale = (float)in_value / out_value;
if (align_corners && out_value > 1) {
*scale = (float)(in_value - 1) / (out_value - 1);
}
}
void ComputeInterpolationArgsFloatWeight(const int32_t pos, const float scale, const int32_t size, float *actual_pos,
int32_t *low, float *low_weight, int32_t *high, float *high_weight) {
*actual_pos = pos * scale;
*low = *actual_pos > 0 ? floor(*actual_pos) : 0;
*low_weight = 1.0 - (*actual_pos - *low);
*high = *low + 1 < size ? *low + 1 : size - 1;
*high_weight = *actual_pos - (*low);
}
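A small worked example of the two float-weight helpers, assuming the definitions above (the driver itself is illustrative): upscaling a 4-element axis to 7 without align_corners gives scale = 4/7, so output position 3 maps to 12/7, roughly 1.714, i.e. low = 1 with weight close to 0.286 and high = 2 with weight close to 0.714.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

void ComputeScaleFloat(const int32_t in_value, const int32_t out_value, const bool align_corners, float *scale);
void ComputeInterpolationArgsFloatWeight(const int32_t pos, const float scale, const int32_t size, float *actual_pos,
                                         int32_t *low, float *low_weight, int32_t *high, float *high_weight);

int main(void) {
  float scale, pos, low_w, high_w;
  int32_t low, high;
  ComputeScaleFloat(4, 7, false, &scale);  // scale = 4/7, about 0.571
  ComputeInterpolationArgsFloatWeight(3, scale, 4, &pos, &low, &low_w, &high, &high_w);
  printf("pos=%.3f low=%d(w=%.3f) high=%d(w=%.3f)\n", pos, low, low_w, high, high_w);
  return 0;  // prints pos=1.714 low=1(w=0.286) high=2(w=0.714)
}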
void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners,
int32_t *nearest) {
if (new_size == 0) {

View File

@ -31,6 +31,20 @@ int ResizeBilinearInt8(const int8_t *input_data, int8_t *output_data, const int
const bool align_corners, QuantArg *quant_in, QuantArg *quant_out, const QuantMulArg *mul_arg,
int tid, int thread_num);
int ResizeBilinearInt8WithFloatWeight(const int8_t *input_data, int8_t *output_data, const int *input_shape,
const int *output_shape, const bool align_corners, QuantArg *quant_in,
QuantArg *quant_out, const QuantMulArg *mul_arg, int tid, int thread_num);
void ComputeScale(const int32_t in_value, const int32_t out_value, const bool align_corners, int32_t *scale);
void ComputeInterpolationArgs(const int32_t pos, const int32_t scale, const int32_t size, int32_t *scaled_pos,
int32_t *low, int32_t *scaled_low_weight, int32_t *high, int32_t *scaled_high_weight);
void ComputeScaleFloat(const int32_t in_value, const int32_t out_value, const bool align_corners, float *scale);
void ComputeInterpolationArgsFloatWeight(const int32_t pos, const float scale, const int32_t size, float *actual_pos,
int32_t *low, float *low_weight, int32_t *high, float *high_weight);
int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int *input_shape,
const int *output_shape, const bool align_corners, int tid, int thread_num);
@ -38,11 +52,6 @@ int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, con
const int *output_shape, const bool align_corners, const QuantMulArg *multiplier,
QuantArg *quant_in, QuantArg *quant_out, int tid, int thread_num);
void ComputeScale(const int32_t in_value, const int32_t out_value, const bool align_corners, int32_t *scale);
void ComputeInterpolationArgs(const int32_t pos, const int32_t scale, const int32_t size, int32_t *scaled_pos,
int32_t *low, int32_t *scaled_low_weight, int32_t *high, int32_t *scaled_high_weight);
void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners,
int32_t *nearest);
#ifdef __cplusplus

View File

@ -58,7 +58,7 @@ int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp
int axis_offset = outter_offset + i * inner_size;
for (int c = 0; c < inner_size; ++c) {
int num_bits_over_unit;
int shifted_scale = ComputerReciproal(sum_data[c], 12, &num_bits_over_unit);
int shifted_scale = ComputerReciprocal(sum_data[c], 12, &num_bits_over_unit);
int unsat_output = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8);
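The two fixed-point helpers used here follow the usual gemmlowp definitions; a reference-style sketch under that assumption (nnacl's own implementations may differ in detail, and the names below are deliberately not the nnacl ones):

#include <stdint.h>

// Rounding doubling high multiply: (a * b * 2) >> 31 with round-to-nearest,
// saturating the single overflow case INT32_MIN * INT32_MIN.
static int32_t SatRoundDoublingHighMul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  return (int32_t)((ab + nudge) / ((int64_t)1 << 31));
}

// Rounding right shift by `exponent`, rounding half away from zero.
static int32_t RoundingDivByPOT(int32_t x, int exponent) {
  const int32_t mask = (int32_t)(((int64_t)1 << exponent) - 1);
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
  return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
}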

View File

@ -22,18 +22,28 @@
typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, bool per_channel);
typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType;
typedef struct MatMulParameter {
OpParameter op_parameter_;
int row_;
int col_;
int row_4_;
int row_8_;
int row_12_;
int row_16_;
int col_4_;
int col_8_;
int deep_;
int deep_4_;
int deep_16_;
bool has_bias_;
int batch;
bool a_transpose_; /* false : row-major */

View File

@ -23,8 +23,8 @@
#define C4NUM 4
#define C8NUM 8
#define C12NUM 12
#define C16NUM 16
#define BLOCK 4
#define TILE_NUM 8
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))
@ -55,10 +55,17 @@ typedef enum LiteDataType {
kDataTypeInt8,
} LiteDataType;
typedef enum DataOrder {
RowMajor,
ColMajor,
} DataOrder;
typedef struct OpParameter {
char name_[100];
int type_;
int thread_num_;
} OpParameter;
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6, ActType_Prelu } ActType;
#endif // MINDSPORE_LITE_NNACL_OP_BASE_H_
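The tile math throughout these kernels leans on the rounding helpers from op_base.h; for orientation, their usual ceil-divide and round-up form (paraphrased here, check the header for the authoritative definitions) with a small example:

// Ceil division and round-up-to-multiple, as used for channel/row tiling:
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define UP_ROUND(x, y) (UP_DIV(x, y) * (y))

// Example: channel = 21 with 16-wide blocks:
//   UP_DIV(21, 16)   == 2   (two C16NUM blocks, the second partially filled)
//   UP_ROUND(21, 16) == 32  (padded channel count)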

View File

@ -15,6 +15,8 @@
*/
#include <stdlib.h>
#include <stdbool.h>
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
@ -27,6 +29,10 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4,
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier,
int left_shift, int right_shift, int row, int col, int stride);
#ifdef __cplusplus
}
#endif
@ -36,7 +42,7 @@ void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
size_t asymmetric, size_t per_channel) {
size_t asymmetric, size_t per_channel) {
return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
}
@ -45,4 +51,12 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i
const int *input_sum, const int *bias) {
return MatMulOptR4Int8Neon64(a, b, dst, row4, col4, deep16, input_sum, bias);
}
void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, bool per_channel) {
return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi,
output_zp, multiplier[0], left_shift[0], right_shift[0], row, col, col);
}
#endif

View File

@ -62,6 +62,10 @@ void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed
} // kernel plane loop
}
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
return PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
}
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) {
// original weight format : ohwi
int kernel_h = conv_param->kernel_h_;
@ -153,22 +157,24 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p
} // kernel plane loop
}
void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param) {
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) {
/* support nhwc */
char *src = (char *)src_ptr;
char *dst = (char *)dst_ptr;
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
if (src_h < 0 || src_h >= conv_param->input_h_) {
continue;
}
const float *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_;
float *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_;
const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size;
char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size;
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
if (src_w < 0 || src_w >= conv_param->input_w_) {
continue;
}
memcpy(dst_h_ptr + dst_w * conv_param->input_channel_, src_h_ptr + src_w * conv_param->input_channel_,
conv_param->input_channel_ * sizeof(float));
memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size,
src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size);
}
}
return;
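The float-specific Conv1x1InputPackFp32 becomes the type-agnostic Conv1x1InputPack by striding in bytes: callers pass the element size and the memcpy width scales with it. Illustrative call sites for the two paths (the pointer names are hypothetical):

Conv1x1InputPack(fp32_input, fp32_packed, conv_param, sizeof(float));   // fp32 path
Conv1x1InputPack(int8_input, int8_packed, conv_param, sizeof(int8_t));  // int8 path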
@ -188,6 +194,139 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam
return;
}
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
/* optimized path produces the same layout as the normal path */
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src] \n"
"mov x11, %[dst] \n"
"dup v15.4s, %w[filter_zp] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[row4] \n"
"beq 4f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n"
"2: \n"
"cmp x2, %[col16] \n"
"beq 3f \n"
"add x2, x2, #16\n"
"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"
"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"
"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"
"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"
"add v10.4s, v10.4s, v0.4s \n"
"b 2b\n"
"3: \n"
"mul v10.4s, v10.4s, v15.4s \n"
"st1 {v10.4s}, [x11], #16 \n"
"beq 1b \n"
"4: \n"
:
: [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15");
#else
for (int r = 0; r < row4; r++) {
int32_t tmp_value = 0;
for (int c = 0; c < col16; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM;
int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
tmp_value += src[src_index];
}
dst[r] = tmp_value * filter_zp;
}
#endif
return;
}
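Why these input-sum tables exist: for a quantized dot product, expanding sum_k (a_k - a_zp) * (w_k - w_zp) isolates a term filter_zp * sum_k a_k that depends only on the input row, so it can be computed once per row (as above) and folded in like a bias. A scalar sketch of the identity (illustrative, not an nnacl API):

#include <stdint.h>

// Computes sum_i (a[i] - a_zp) * (w[i] - w_zp) via the expanded form that the
// packed input sums make cheap: acc - w_zp*sum(a) - a_zp*sum(w) + k*a_zp*w_zp.
static int32_t QuantizedDot(const int8_t *a, const int8_t *w, int k,
                            int32_t a_zp, int32_t w_zp) {
  int32_t acc = 0, a_sum = 0, w_sum = 0;
  for (int i = 0; i < k; ++i) {
    acc += (int32_t)a[i] * (int32_t)w[i];
    a_sum += a[i];
    w_sum += w[i];
  }
  return acc - w_zp * a_sum - a_zp * w_sum + k * a_zp * w_zp;
}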
void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param) {
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
PackInputSum16x4PerLayer(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
} else {
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
}
return;
}
void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param) {
size_t hw8 = UP_ROUND(plane_size, C8NUM);
size_t ic4 = UP_ROUND(input_channel, C4NUM);
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
for (int r = 0; r < hw8; r++) {
int32_t tmp_value = 0;
for (int c = 0; c < ic4; c++) {
int r8div = r / C8NUM, r8mod = r % C8NUM, c4div = c / C4NUM, c4mod = c % C4NUM;
int src_index = r8div * C8NUM * ic4 + c4div * C8NUM * C4NUM + r8mod * C4NUM + c4mod;
tmp_value += input_value[src_index];
}
input_sum[r] = tmp_value * conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
}
} else {
for (int ri = 0; ri < plane_size; ri++) {
int ri8div = ri / C8NUM, ri8mod = ri % C8NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci8div = ci / C8NUM, ci8mod = ci % C8NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
for (int di = 0; di < input_channel; di++) {
size_t di4div = di / C4NUM, di4mod = di % C4NUM;
int src_index = ri8div * C8NUM * ic4 + di4div * C8NUM * C4NUM + ri8mod * C4NUM + di4mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci8div * C8NUM * hw8 + ri * C8NUM + ci8mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
}
return;
}
void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num,
int block_index) {
// input format : nhwc
@ -195,8 +334,8 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
@ -204,23 +343,21 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int ic4 = UP_DIV(in_channel, C4NUM);
memset(packed_input, 0, kernel_h * kernel_w * ic4 * C4NUM * TILE_NUM * sizeof(float));
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
for (int j = 0; j < kernel_h; j++) {
int input_y = input_h + j * dilation_h;
if (input_y < 0 || input_y >= in_h) {
continue;
}
int input_y_stride = input_y * in_w * ic4 * C4NUM;
for (int n = 0; n < kernel_w; n++) {
int input_x = input_w + n * dilation_w;
if (input_x < 0 || input_x >= in_w) {
continue;
}
int input_x_stride = input_y_stride + input_x * ic4 * C4NUM;
int input_stride = input_h * in_w * ic4 * C4NUM + input_w * ic4 * C4NUM;
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * ic4 * C4NUM + input_stride;
for (int n = kw_s; n < kw_e; n++) {
int input_x_stride = input_y_stride + n * dilation_w * ic4 * C4NUM;
int input_plane_offset = (j * kernel_w + n) * C8NUM * C4NUM * ic4 + i * C4NUM;
for (int m = 0; m < ic4; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
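The rewrite above replaces the per-tap boundary test with closed-form kernel bounds: the first in-bounds tap is ceil(-input_h / dilation_h) and the first out-of-bounds tap is ceil((in_h - input_h) / dilation_h). A small self-check of that equivalence (hypothetical driver; the macros are restated so the snippet is self-contained):

#include <assert.h>

#define MSMAX(a, b) ((a) > (b) ? (a) : (b))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main(void) {
  const int in_h = 5, kernel_h = 3, dilation_h = 2;
  for (int input_h = -3; input_h <= 3; ++input_h) {
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
    for (int j = 0; j < kernel_h; ++j) {
      int y = input_h + j * dilation_h;
      int inside = (y >= 0 && y < in_h);
      assert(inside == (j >= kh_s && j < kh_e));  // bounds match the old test
    }
  }
  return 0;
}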
@ -247,8 +384,8 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
@ -318,8 +455,8 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
@ -350,9 +487,7 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
for (int m = 0; m < ic4; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
int channel_block_offset = input_plane_offset + m * tile_num * C4NUM;
for (int k = 0; k < C4NUM; k++) {
(packed_input + channel_block_offset)[k] = (input_data + channel_block_stride)[k];
}
memcpy(packed_input + channel_block_offset, input_data + channel_block_stride, 4);
} // channel_block loop
} // kernel_w loop
} // kernel_h loop
@ -660,6 +795,8 @@ void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane,
}
}
void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel) {}
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) {
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {

View File

@ -35,10 +35,18 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
int32_t *input_sum, ConvParameter *conv_param);
void Conv1x1InputPackFp32(const float *src, float *dst, ConvParameter *conv_param);
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);
void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);
void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param);
void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param);
void MatrixPack(const float *src, float *dst, int row, int ic4, int stride);
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
@ -46,6 +54,8 @@ void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvPara
void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight, int oc_block,
int oc_block_num);
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);
@ -76,6 +86,8 @@ void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane,
void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel);
void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);

View File

@ -19,14 +19,16 @@
#include "nnacl/op_base.h"
#include "nnacl/quantization/quantize.h"
typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode;
typedef enum RoundMode { RoundMode_No, RoundMode_Ceil, RoundMode_Floor } RoundMode;
typedef struct PoolingParameter {
OpParameter op_parameter_;
PoolMode pool_mode_;
RoundMode round_mode_;
ActType act_type_;
QuantArg **quant_args_;
bool global_;
bool max_pooling_;
bool avg_pooling_;
bool round_ceil_;
bool round_floor_;
int window_w_;
int window_h_;
int input_w_;
@ -44,6 +46,8 @@ typedef struct PoolingParameter {
int stride_w_;
int stride_h_;
int thread_num_;
bool global_;
bool quantize_;
} PoolingParameter;
#endif // MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_
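The commit replaces the four mutually exclusive bools (max_pooling_, avg_pooling_, round_ceil_, round_floor_) with the PoolMode and RoundMode enums above, so dispatch becomes a single switch. A hypothetical caller sketch, assuming the declarations from the pooling headers shown earlier in this diff:

// Hypothetical dispatch over the new enum; MaxPoolingWithQuantInt8 and
// AvgPoolingOptInt8 are the kernels from the diff above.
void DoPooling(const int8_t *in, int8_t *out, PoolingParameter *param, int task_id) {
  switch (param->pool_mode_) {
    case PoolMode_MaxPool:
      MaxPoolingWithQuantInt8(in, out, param, task_id);
      break;
    case PoolMode_AvgPool:
      AvgPoolingOptInt8(in, out, param, task_id);
      break;
    default:
      break;  // PoolMode_No: nothing to do
  }
}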

Some files were not shown because too many files have changed in this diff