forked from mindspore-Ecosystem/mindspore
[feat] [assistant] [I3T96X] add new Dataset operator LibriSpeechDataset
This commit is contained in:
parent
9f08cdc4ab
commit
4e6f7dc97d
|
@ -24,9 +24,6 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||||
-Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare \
|
-Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare \
|
||||||
-Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move \
|
-Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move \
|
||||||
-Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
-Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
||||||
elseif(ENABLE_SYM_FILE)
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -g -ggdb -Wl,--allow-shlib-undefined \
|
|
||||||
-DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined \
|
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined \
|
||||||
-DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
-DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
SET BASE_PATH=%CD%
|
SET BASE_PATH=%CD%
|
||||||
SET BUILD_PATH=%BASE_PATH%/build
|
SET BUILD_PATH=%BASE_PATH%/build
|
||||||
|
|
||||||
SET threads=8
|
SET threads=6
|
||||||
SET ENABLE_GITEE=OFF
|
SET ENABLE_GITEE=OFF
|
||||||
|
|
||||||
set VERSION_MAJOR=''
|
set VERSION_MAJOR=''
|
||||||
|
|
14
build.sh
14
build.sh
|
@ -27,7 +27,7 @@ usage()
|
||||||
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 10.1|11.1|310|910] [-I arm64|arm32|x86_64] [-K] \\"
|
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 10.1|11.1|310|910] [-I arm64|arm32|x86_64] [-K] \\"
|
||||||
echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-H on|off] \\"
|
echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-H on|off] \\"
|
||||||
echo " [-A on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|avx512|off] \\"
|
echo " [-A on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|avx512|off] \\"
|
||||||
echo " [-L Tensor-RT path] [-y on|off] \\"
|
echo " [-L Tensor-RT path] \\"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Options:"
|
echo "Options:"
|
||||||
echo " -d Debug mode"
|
echo " -d Debug mode"
|
||||||
|
@ -61,10 +61,9 @@ usage()
|
||||||
echo " -l Compile with python dependency, default on"
|
echo " -l Compile with python dependency, default on"
|
||||||
echo " -S Enable enable download cmake compile dependency from gitee , default off"
|
echo " -S Enable enable download cmake compile dependency from gitee , default off"
|
||||||
echo " -k Enable make clean, clean up compilation generated cache "
|
echo " -k Enable make clean, clean up compilation generated cache "
|
||||||
echo " -W Enable SIMD instruction set, use [sse|neon|avx|avx512|off], default avx for cloud CPU backend"
|
echo " -W Enable x86_64 SSE or AVX instruction set, use [sse|neon|avx|avx512|off], default off for lite and avx for CPU"
|
||||||
echo " -H Enable hidden"
|
echo " -H Enable hidden"
|
||||||
echo " -L Link and specify Tensor-RT library path, default disable Tensor-RT lib linking"
|
echo " -L Link and specify Tensor-RT library path, default disable Tensor-RT lib linking"
|
||||||
echo " -y Compile the symbol table switch and save the symbol table to the directory output"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# check value of input is 'on' or 'off'
|
# check value of input is 'on' or 'off'
|
||||||
|
@ -123,9 +122,8 @@ checkopts()
|
||||||
TENSORRT_HOME=""
|
TENSORRT_HOME=""
|
||||||
USER_ENABLE_DUMP_IR=false
|
USER_ENABLE_DUMP_IR=false
|
||||||
USER_ENABLE_DEBUGGER=false
|
USER_ENABLE_DEBUGGER=false
|
||||||
ENABLE_SYM_FILE="off"
|
|
||||||
# Process the options
|
# Process the options
|
||||||
while getopts 'drvj:c:t:hb:s:a:g:p:ie:m:l:I:RP:D:zM:V:K:B:En:A:S:k:W:H:L:y' opt
|
while getopts 'drvj:c:t:hb:s:a:g:p:ie:m:l:I:RP:D:zM:V:K:B:En:A:S:k:W:H:L:' opt
|
||||||
do
|
do
|
||||||
CASE_SENSIVE_ARG=${OPTARG}
|
CASE_SENSIVE_ARG=${OPTARG}
|
||||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
|
||||||
|
@ -142,9 +140,6 @@ checkopts()
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
y)
|
|
||||||
ENABLE_SYM_FILE="on"
|
|
||||||
;;
|
|
||||||
r)
|
r)
|
||||||
DEBUG_MODE="off"
|
DEBUG_MODE="off"
|
||||||
;;
|
;;
|
||||||
|
@ -447,9 +442,6 @@ build_mindspore()
|
||||||
if [[ -n "$TRAIN_MODE" ]]; then
|
if [[ -n "$TRAIN_MODE" ]]; then
|
||||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${TRAIN_MODE}=ON"
|
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_${TRAIN_MODE}=ON"
|
||||||
fi
|
fi
|
||||||
if [[ "X$ENABLE_SYM_FILE" = "Xon" ]]; then
|
|
||||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_SYM_FILE=ON"
|
|
||||||
fi
|
|
||||||
if [[ "X$ENABLE_ASAN" = "Xon" ]]; then
|
if [[ "X$ENABLE_ASAN" = "Xon" ]]; then
|
||||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ASAN=ON"
|
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_ASAN=ON"
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
set(FFMPEG_FLAGS
|
||||||
|
--disable-programs
|
||||||
|
--disable-doc
|
||||||
|
--disable-debug
|
||||||
|
--disable-avdevice
|
||||||
|
--disable-postproc
|
||||||
|
--disable-avfilter
|
||||||
|
--disable-network
|
||||||
|
--disable-encoders
|
||||||
|
--disable-hwaccels
|
||||||
|
--disable-muxers
|
||||||
|
--disable-bsfs
|
||||||
|
--disable-protocols
|
||||||
|
--enable-protocol=file
|
||||||
|
--enable-protocol=pipe
|
||||||
|
--disable-indevs
|
||||||
|
--disable-outdevs
|
||||||
|
--disable-devices
|
||||||
|
--disable-filters
|
||||||
|
--disable-bzlib
|
||||||
|
--disable-iconv
|
||||||
|
--disable-libxcb
|
||||||
|
--disable-lzma
|
||||||
|
--disable-sdl2
|
||||||
|
--disable-xlib
|
||||||
|
--disable-zlib)
|
||||||
|
|
||||||
|
set(REQ_URL "https://github.com/FFmpeg/FFmpeg/archive/n4.3.1.tar.gz")
|
||||||
|
set(MD5 "426ca412ca61634a248c787e29507206")
|
||||||
|
|
||||||
|
mindspore_add_pkg(ffmpeg
|
||||||
|
VER 4.3.1
|
||||||
|
LIBS avcodec avformat avutil swresample swscale
|
||||||
|
URL ${REQ_URL}
|
||||||
|
MD5 ${MD5}
|
||||||
|
CONFIGURE_COMMAND ./configure --disable-static --enable-shared --disable-x86asm ${FFMPEG_FLAGS}
|
||||||
|
)
|
||||||
|
|
||||||
|
include_directories(${ffmpeg_INC})
|
||||||
|
add_library(mindspore::avcodec ALIAS ffmpeg::avcodec)
|
||||||
|
add_library(mindspore::avformat ALIAS ffmpeg::avformat)
|
||||||
|
add_library(mindspore::avutil ALIAS ffmpeg::avutil)
|
||||||
|
add_library(mindspore::swresample ALIAS ffmpeg::swresample)
|
||||||
|
add_library(mindspore::swscale ALIAS ffmpeg::swscale)
|
|
@ -1,10 +1,10 @@
|
||||||
if(MSVC)
|
if(MSVC)
|
||||||
set(flatbuffers_CXXFLAGS "${CMAKE_CXX_FLAGS}")
|
set(flatbuffers_CXXFLAGS "${CMAKE_CXX_FLAGS}")
|
||||||
set(flatbuffers_CFLAGS "${CMAKE_C_FLAGS}")
|
set(flatbuffers_CFLAGS "${CMAKE_CXX_FLAGS}")
|
||||||
set(flatbuffers_LDFLAGS "${CMAKE_SHARED_LINKER_FLAGS}")
|
set(flatbuffers_LDFLAGS "${CMAKE_SHARED_LINKER_FLAGS}")
|
||||||
else()
|
else()
|
||||||
set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong")
|
set(flatbuffers_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
||||||
set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong")
|
set(flatbuffers_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
|
|
|
@ -1,15 +1,13 @@
|
||||||
if(BUILD_LITE)
|
|
||||||
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private")
|
|
||||||
set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_C_FLAGS}")
|
|
||||||
set(glog_LDFLAGS "${SECURE_SHARED_LINKER_FLAGS}")
|
|
||||||
set(glog_patch "")
|
|
||||||
set(glog_lib glog)
|
|
||||||
else()
|
|
||||||
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private")
|
set(glog_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2 ${SECURE_CXX_FLAGS} -Dgoogle=mindspore_private")
|
||||||
set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
set(glog_CFLAGS "-D_FORTIFY_SOURCE=2 -O2")
|
||||||
if(NOT ENABLE_GLIBCXX)
|
if(NOT ENABLE_GLIBCXX)
|
||||||
set(glog_CXXFLAGS "${glog_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
set(glog_CXXFLAGS "${glog_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(BUILD_LITE)
|
||||||
|
set(glog_patch "")
|
||||||
|
set(glog_lib glog)
|
||||||
|
else()
|
||||||
set(glog_patch ${CMAKE_SOURCE_DIR}/third_party/patch/glog/glog.patch001)
|
set(glog_patch ${CMAKE_SOURCE_DIR}/third_party/patch/glog/glog.patch001)
|
||||||
set(glog_lib mindspore_glog)
|
set(glog_lib mindspore_glog)
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -9,7 +9,7 @@ endif()
|
||||||
|
|
||||||
if(ENABLE_GITEE)
|
if(ENABLE_GITEE)
|
||||||
set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
|
set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
|
||||||
set(MD5 "36ea0d9a709c6667b2798a62f6b197ae")
|
set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
|
||||||
set(INCLUDE "./include")
|
set(INCLUDE "./include")
|
||||||
else()
|
else()
|
||||||
set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
|
set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
|
||||||
|
|
|
@ -89,6 +89,7 @@ if(ENABLE_MINDDATA)
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tinyxml2.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/tinyxml2.cmake)
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cppjieba.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cppjieba.cmake)
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sentencepiece.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/sentencepiece.cmake)
|
||||||
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ffmpeg.cmake)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(ENABLE_MINDDATA)
|
if(ENABLE_MINDDATA)
|
||||||
|
|
|
@ -25,7 +25,6 @@ option(ENABLE_ACL "enable acl" OFF)
|
||||||
option(ENABLE_GLIBCXX "enable_glibcxx" OFF)
|
option(ENABLE_GLIBCXX "enable_glibcxx" OFF)
|
||||||
option(MODE_ASCEND_ALL "supports all ascend platform" OFF)
|
option(MODE_ASCEND_ALL "supports all ascend platform" OFF)
|
||||||
option(MODE_ASCEND_ACL "supports ascend acl mode only" OFF)
|
option(MODE_ASCEND_ACL "supports ascend acl mode only" OFF)
|
||||||
option(ENABLE_SYM_FILE "enable sym file" OFF)
|
|
||||||
|
|
||||||
if(NOT ENABLE_D AND NOT ENABLE_TESTCASES AND NOT ENABLE_ACL AND NOT ENABLE_GE)
|
if(NOT ENABLE_D AND NOT ENABLE_TESTCASES AND NOT ENABLE_ACL AND NOT ENABLE_GE)
|
||||||
set(ENABLE_GLIBCXX ON)
|
set(ENABLE_GLIBCXX ON)
|
||||||
|
|
|
@ -12,8 +12,6 @@ set(CPACK_TEMPORARY_PACKAGE_FILE_NAME ${BUILD_PATH}/package/mindspore)
|
||||||
set(CPACK_TEMPORARY_INSTALL_DIRECTORY ${BUILD_PATH}/package/mindspore)
|
set(CPACK_TEMPORARY_INSTALL_DIRECTORY ${BUILD_PATH}/package/mindspore)
|
||||||
set(CPACK_PACK_ROOT_DIR ${BUILD_PATH}/package/)
|
set(CPACK_PACK_ROOT_DIR ${BUILD_PATH}/package/)
|
||||||
set(CPACK_CMAKE_SOURCE_DIR ${CMAKE_SOURCE_DIR})
|
set(CPACK_CMAKE_SOURCE_DIR ${CMAKE_SOURCE_DIR})
|
||||||
set(CPACK_ENABLE_SYM_FILE ${ENABLE_SYM_FILE})
|
|
||||||
set(CPACK_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
|
||||||
if(ENABLE_GE)
|
if(ENABLE_GE)
|
||||||
set(CPACK_MS_BACKEND "ge")
|
set(CPACK_MS_BACKEND "ge")
|
||||||
set(CPACK_MS_TARGET "ascend or cpu")
|
set(CPACK_MS_TARGET "ascend or cpu")
|
||||||
|
@ -127,6 +125,17 @@ if(ENABLE_MINDDATA)
|
||||||
DESTINATION ${INSTALL_LIB_DIR} RENAME libicudata.so.67 COMPONENT mindspore)
|
DESTINATION ${INSTALL_LIB_DIR} RENAME libicudata.so.67 COMPONENT mindspore)
|
||||||
install(FILES ${icu4c_LIBPATH}/libicui18n.so.67.1
|
install(FILES ${icu4c_LIBPATH}/libicui18n.so.67.1
|
||||||
DESTINATION ${INSTALL_LIB_DIR} RENAME libicui18n.so.67 COMPONENT mindspore)
|
DESTINATION ${INSTALL_LIB_DIR} RENAME libicui18n.so.67 COMPONENT mindspore)
|
||||||
|
|
||||||
|
install(FILES ${ffmpeg_LIBPATH}/libavcodec.so.58.91.100
|
||||||
|
DESTINATION ${INSTALL_LIB_DIR} RENAME libavcodec.so.58 COMPONENT mindspore)
|
||||||
|
install(FILES ${ffmpeg_LIBPATH}/libavformat.so.58.45.100
|
||||||
|
DESTINATION ${INSTALL_LIB_DIR} RENAME libavformat.so.58 COMPONENT mindspore)
|
||||||
|
install(FILES ${ffmpeg_LIBPATH}/libavutil.so.56.51.100
|
||||||
|
DESTINATION ${INSTALL_LIB_DIR} RENAME libavutil.so.56 COMPONENT mindspore)
|
||||||
|
install(FILES ${ffmpeg_LIBPATH}/libswresample.so.3.7.100
|
||||||
|
DESTINATION ${INSTALL_LIB_DIR} RENAME libswresample.so.3 COMPONENT mindspore)
|
||||||
|
install(FILES ${ffmpeg_LIBPATH}/libswscale.so.5.7.100
|
||||||
|
DESTINATION ${INSTALL_LIB_DIR} RENAME libswscale.so.5 COMPONENT mindspore)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(ENABLE_CPU)
|
if(ENABLE_CPU)
|
||||||
|
@ -198,6 +207,12 @@ if(NOT ENABLE_GE)
|
||||||
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
|
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
|
||||||
|
|
||||||
if(ENABLE_D)
|
if(ENABLE_D)
|
||||||
|
install(
|
||||||
|
TARGETS ms_profile
|
||||||
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
|
COMPONENT mindspore
|
||||||
|
)
|
||||||
|
|
||||||
install(
|
install(
|
||||||
TARGETS hccl_plugin
|
TARGETS hccl_plugin
|
||||||
DESTINATION ${INSTALL_LIB_DIR}
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
|
|
|
@ -330,6 +330,8 @@ elseif(WIN32)
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h
|
||||||
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||||
|
@ -460,6 +462,8 @@ else()
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/dump_graph.h
|
||||||
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||||
|
|
|
@ -77,48 +77,6 @@ set(ENV{BACKEND_TARGET} ${CPACK_MS_TARGET})
|
||||||
set(ENV{MS_PACKAGE_NAME} ${CPACK_MS_PACKAGE_NAME})
|
set(ENV{MS_PACKAGE_NAME} ${CPACK_MS_PACKAGE_NAME})
|
||||||
set(ENV{COMMIT_ID} ${GIT_COMMIT_ID})
|
set(ENV{COMMIT_ID} ${GIT_COMMIT_ID})
|
||||||
|
|
||||||
file(GLOB DEBUG_SYM
|
|
||||||
${MS_PACK_ROOT_DIR}/mindspore/*.so
|
|
||||||
${MS_PACK_ROOT_DIR}/mindspore/lib/*.so
|
|
||||||
)
|
|
||||||
|
|
||||||
file(GLOB DEBUG_STRIP_SYM
|
|
||||||
${MS_PACK_ROOT_DIR}/mindspore/*.so
|
|
||||||
${MS_PACK_ROOT_DIR}/mindspore/lib/*.so*
|
|
||||||
)
|
|
||||||
|
|
||||||
set(CMAKE_OBJCOPY $ENV{CROSS_COMPILE}objcopy)
|
|
||||||
set(CMAKE_STRIP $ENV{CROSS_COMPILE}strip)
|
|
||||||
|
|
||||||
if(CPACK_ENABLE_SYM_FILE)
|
|
||||||
foreach(schema ${DEBUG_SYM})
|
|
||||||
execute_process(
|
|
||||||
COMMAND ${CMAKE_OBJCOPY} "--only-keep-debug" ${schema} ${schema}.sym
|
|
||||||
WORKING_DIRECTORY ${MS_PACK_ROOT_DIR}
|
|
||||||
)
|
|
||||||
endforeach()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if("${CPACK_CMAKE_BUILD_TYPE}" STREQUAL "Release")
|
|
||||||
foreach(schema ${DEBUG_STRIP_SYM})
|
|
||||||
execute_process(
|
|
||||||
COMMAND ${CMAKE_STRIP} ${schema}
|
|
||||||
WORKING_DIRECTORY ${MS_PACK_ROOT_DIR}
|
|
||||||
)
|
|
||||||
endforeach()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
file(GLOB DEBUG_SYM_FILE
|
|
||||||
${MS_PACK_ROOT_DIR}/mindspore/*.sym
|
|
||||||
${MS_PACK_ROOT_DIR}/mindspore/lib/*.sym
|
|
||||||
)
|
|
||||||
|
|
||||||
if(CPACK_ENABLE_SYM_FILE)
|
|
||||||
file(MAKE_DIRECTORY ${MS_ROOT_DIR}/debug_info)
|
|
||||||
file(COPY ${DEBUG_SYM_FILE} DESTINATION ${MS_ROOT_DIR}/debug_info/)
|
|
||||||
file(REMOVE_RECURSE ${DEBUG_SYM_FILE})
|
|
||||||
endif()
|
|
||||||
|
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND ${PYTHON} ${MS_ROOT_DIR}/setup.py "bdist_wheel"
|
COMMAND ${PYTHON} ${MS_ROOT_DIR}/setup.py "bdist_wheel"
|
||||||
WORKING_DIRECTORY ${MS_PACK_ROOT_DIR}
|
WORKING_DIRECTORY ${MS_PACK_ROOT_DIR}
|
||||||
|
@ -146,16 +104,3 @@ file(COPY ${MS_PACK_ROOT_DIR}/${NEW_FILE_NAME} DESTINATION ${MS_ROOT_DIR}/output
|
||||||
|
|
||||||
file(SHA256 ${MS_ROOT_DIR}/output/${NEW_FILE_NAME} SHA256_VAR)
|
file(SHA256 ${MS_ROOT_DIR}/output/${NEW_FILE_NAME} SHA256_VAR)
|
||||||
file(WRITE ${MS_ROOT_DIR}/output/${NEW_FILE_NAME}.sha256 ${SHA256_VAR} " " ${NEW_FILE_NAME})
|
file(WRITE ${MS_ROOT_DIR}/output/${NEW_FILE_NAME}.sha256 ${SHA256_VAR} " " ${NEW_FILE_NAME})
|
||||||
set(CMAKE_TAR $ENV{CROSS_COMPILE}tar)
|
|
||||||
if(CPACK_ENABLE_SYM_FILE)
|
|
||||||
file(MAKE_DIRECTORY ${MS_ROOT_DIR}/output/${PACKAGE_NAME}-${VERSION}-${PY_TAGS}-${PLATFORM_TAG})
|
|
||||||
file(COPY ${MS_ROOT_DIR}/debug_info/ DESTINATION
|
|
||||||
${MS_ROOT_DIR}/output/${PACKAGE_NAME}-${VERSION}-${PY_TAGS}-${PLATFORM_TAG}/)
|
|
||||||
execute_process(COMMAND
|
|
||||||
${CMAKE_COMMAND} -E ${CMAKE_TAR} cfv
|
|
||||||
${MS_ROOT_DIR}/output/${PACKAGE_NAME}-${VERSION}-${PY_TAGS}-${PLATFORM_TAG}.zip
|
|
||||||
${MS_ROOT_DIR}/output/${PACKAGE_NAME}-${VERSION}-${PY_TAGS}-${PLATFORM_TAG}/ --format=zip
|
|
||||||
WORKING_DIRECTORY ${MS_ROOT_DIR})
|
|
||||||
file(REMOVE_RECURSE ${MS_ROOT_DIR}/debug_info)
|
|
||||||
file(REMOVE_RECURSE ${MS_ROOT_DIR}/output/${PACKAGE_NAME}-${VERSION}-${PY_TAGS}-${PLATFORM_TAG})
|
|
||||||
endif()
|
|
||||||
|
|
|
@ -91,6 +91,18 @@ if(ENABLE_MINDDATA)
|
||||||
DESTINATION ${INSTALL_LIB_DIR}
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
COMPONENT mindspore
|
COMPONENT mindspore
|
||||||
)
|
)
|
||||||
|
file(GLOB_RECURSE FFMPEG_LIB_LIST
|
||||||
|
${ffmpeg_LIBPATH}/libavcodec*
|
||||||
|
${ffmpeg_LIBPATH}/libavformat*
|
||||||
|
${ffmpeg_LIBPATH}/libavutil*
|
||||||
|
${ffmpeg_LIBPATH}/libswresample*
|
||||||
|
${ffmpeg_LIBPATH}/libswscale*
|
||||||
|
)
|
||||||
|
install(
|
||||||
|
FILES ${FFMPEG_LIB_LIST}
|
||||||
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
|
COMPONENT mindspore
|
||||||
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# CPU mode
|
# CPU mode
|
||||||
|
|
|
@ -42,6 +42,7 @@ set(opencv_LIBPATH ${opencv_LIBPATH}/../bin/)
|
||||||
set(jpeg_turbo_LIBPATH ${jpeg_turbo_LIBPATH}/../bin/)
|
set(jpeg_turbo_LIBPATH ${jpeg_turbo_LIBPATH}/../bin/)
|
||||||
set(sqlite_LIBPATH ${sqlite_LIBPATH}/../bin/)
|
set(sqlite_LIBPATH ${sqlite_LIBPATH}/../bin/)
|
||||||
set(tinyxml2_LIBPATH ${tinyxml2_LIBPATH}/../bin/)
|
set(tinyxml2_LIBPATH ${tinyxml2_LIBPATH}/../bin/)
|
||||||
|
set(ffmpeg_LIBPATH ${ffmpeg_LIBPATH}/../bin/)
|
||||||
|
|
||||||
message("offline debugger does not support windows system temporarily")
|
message("offline debugger does not support windows system temporarily")
|
||||||
|
|
||||||
|
@ -97,6 +98,18 @@ if(ENABLE_MINDDATA)
|
||||||
DESTINATION ${INSTALL_LIB_DIR}
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
COMPONENT mindspore
|
COMPONENT mindspore
|
||||||
)
|
)
|
||||||
|
file(GLOB_RECURSE FFMPEG_LIB_LIST
|
||||||
|
${ffmpeg_LIBPATH}/libavcodec*
|
||||||
|
${ffmpeg_LIBPATH}/libavformat*
|
||||||
|
${ffmpeg_LIBPATH}/libavutil*
|
||||||
|
${ffmpeg_LIBPATH}/libswresample*
|
||||||
|
${ffmpeg_LIBPATH}/libswscale*
|
||||||
|
)
|
||||||
|
install(
|
||||||
|
FILES ${FFMPEG_LIB_LIST}
|
||||||
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
|
COMPONENT mindspore
|
||||||
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(ENABLE_CPU)
|
if(ENABLE_CPU)
|
||||||
|
|
|
@ -1,4 +1,2 @@
|
||||||
approvers:
|
|
||||||
- zhoufeng54
|
|
||||||
reviewers:
|
reviewers:
|
||||||
- HW_KK
|
- HW_KK
|
|
@ -58,11 +58,8 @@ RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \
|
||||||
&& make install -j4 \
|
&& make install -j4 \
|
||||||
&& rm -f /usr/local/bin/python \
|
&& rm -f /usr/local/bin/python \
|
||||||
&& rm -f /usr/local/bin/pip \
|
&& rm -f /usr/local/bin/pip \
|
||||||
&& rm -f /usr/local/lib/libpython3.7m.so.1.0 \
|
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \
|
&& ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \
|
&& ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/lib/libpython3.7m.so.1.0 /usr/local/lib/libpython3.7m.so.1.0 \
|
|
||||||
&& ldconfig \
|
|
||||||
&& rm -rf /tmp/cpython-3.7.5 \
|
&& rm -rf /tmp/cpython-3.7.5 \
|
||||||
&& rm -f /tmp/v3.7.5.tar.gz
|
&& rm -f /tmp/v3.7.5.tar.gz
|
||||||
|
|
||||||
|
|
|
@ -51,16 +51,13 @@ RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \
|
||||||
&& tar -xvf v3.7.5.tar.gz \
|
&& tar -xvf v3.7.5.tar.gz \
|
||||||
&& cd /tmp/cpython-3.7.5 \
|
&& cd /tmp/cpython-3.7.5 \
|
||||||
&& mkdir -p ${PYTHON_ROOT_PATH} \
|
&& mkdir -p ${PYTHON_ROOT_PATH} \
|
||||||
&& ./configure --prefix=${PYTHON_ROOT_PATH} --enable-shared \
|
&& ./configure --prefix=${PYTHON_ROOT_PATH} \
|
||||||
&& make -j4 \
|
&& make -j4 \
|
||||||
&& make install -j4 \
|
&& make install -j4 \
|
||||||
&& rm -f /usr/local/bin/python \
|
&& rm -f /usr/local/bin/python \
|
||||||
&& rm -f /usr/local/bin/pip \
|
&& rm -f /usr/local/bin/pip \
|
||||||
&& rm -f /usr/local/lib/libpython3.7m.so.1.0 \
|
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \
|
&& ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \
|
&& ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/lib/libpython3.7m.so.1.0 /usr/local/lib/libpython3.7m.so.1.0 \
|
|
||||||
&& ldconfig \
|
|
||||||
&& rm -rf /tmp/cpython-3.7.5 \
|
&& rm -rf /tmp/cpython-3.7.5 \
|
||||||
&& rm -f /tmp/v3.7.5.tar.gz
|
&& rm -f /tmp/v3.7.5.tar.gz
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu18.04
|
FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
|
||||||
|
|
||||||
MAINTAINER leonwanghui <leon.wanghui@huawei.com>
|
MAINTAINER leonwanghui <leon.wanghui@huawei.com>
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ RUN DEBIAN_FRONTEND=noninteractive apt install -y \
|
||||||
libnuma-dev
|
libnuma-dev
|
||||||
|
|
||||||
# Configure cuDNN (v7.6.5)
|
# Configure cuDNN (v7.6.5)
|
||||||
RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5 /usr/local/cuda/lib64/libcudnn.so
|
RUN ln -s /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5 /usr/local/cuda/lib64/libcudnn.so
|
||||||
|
|
||||||
# Set bash
|
# Set bash
|
||||||
RUN echo "dash dash/sh boolean false" | debconf-set-selections
|
RUN echo "dash dash/sh boolean false" | debconf-set-selections
|
||||||
|
@ -62,11 +62,8 @@ RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \
|
||||||
&& make install -j4 \
|
&& make install -j4 \
|
||||||
&& rm -f /usr/local/bin/python \
|
&& rm -f /usr/local/bin/python \
|
||||||
&& rm -f /usr/local/bin/pip \
|
&& rm -f /usr/local/bin/pip \
|
||||||
&& rm -f /usr/local/lib/libpython3.7m.so.1.0 \
|
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \
|
&& ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \
|
&& ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/lib/libpython3.7m.so.1.0 /usr/local/lib/libpython3.7m.so.1.0 \
|
|
||||||
&& ldconfig \
|
|
||||||
&& rm -rf /tmp/cpython-3.7.5 \
|
&& rm -rf /tmp/cpython-3.7.5 \
|
||||||
&& rm -f /tmp/v3.7.5.tar.gz
|
&& rm -f /tmp/v3.7.5.tar.gz
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu18.04
|
FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
|
||||||
|
|
||||||
MAINTAINER leonwanghui <leon.wanghui@huawei.com>
|
MAINTAINER leonwanghui <leon.wanghui@huawei.com>
|
||||||
|
|
||||||
|
@ -53,16 +53,13 @@ RUN apt install -y libffi-dev libssl-dev zlib1g-dev libbz2-dev libncurses5-dev \
|
||||||
&& tar -xvf v3.7.5.tar.gz \
|
&& tar -xvf v3.7.5.tar.gz \
|
||||||
&& cd /tmp/cpython-3.7.5 \
|
&& cd /tmp/cpython-3.7.5 \
|
||||||
&& mkdir -p ${PYTHON_ROOT_PATH} \
|
&& mkdir -p ${PYTHON_ROOT_PATH} \
|
||||||
&& ./configure --prefix=${PYTHON_ROOT_PATH} --enable-shared \
|
&& ./configure --prefix=${PYTHON_ROOT_PATH} \
|
||||||
&& make -j4 \
|
&& make -j4 \
|
||||||
&& make install -j4 \
|
&& make install -j4 \
|
||||||
&& rm -f /usr/local/bin/python \
|
&& rm -f /usr/local/bin/python \
|
||||||
&& rm -f /usr/local/bin/pip \
|
&& rm -f /usr/local/bin/pip \
|
||||||
&& rm -f /usr/local/lib/libpython3.7m.so.1.0 \
|
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \
|
&& ln -s ${PYTHON_ROOT_PATH}/bin/python3.7 /usr/local/bin/python \
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \
|
&& ln -s ${PYTHON_ROOT_PATH}/bin/pip3.7 /usr/local/bin/pip \
|
||||||
&& ln -s ${PYTHON_ROOT_PATH}/lib/libpython3.7m.so.1.0 /usr/local/lib/libpython3.7m.so.1.0 \
|
|
||||||
&& ldconfig \
|
|
||||||
&& rm -rf /tmp/cpython-3.7.5 \
|
&& rm -rf /tmp/cpython-3.7.5 \
|
||||||
&& rm -f /tmp/v3.7.5.tar.gz
|
&& rm -f /tmp/v3.7.5.tar.gz
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,12 @@
|
||||||
#include "include/api/data_type.h"
|
#include "include/api/data_type.h"
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class Model;
|
class Model;
|
||||||
class ModelImpl;
|
class ModelImpl;
|
||||||
|
|
|
@ -22,6 +22,12 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "include/api/callback/callback.h"
|
#include "include/api/callback/callback.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
class CkptSaver: public TrainCallBack {
|
class CkptSaver: public TrainCallBack {
|
||||||
|
|
|
@ -21,6 +21,12 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "include/api/callback/callback.h"
|
#include "include/api/callback/callback.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
using GraphPoint = std::pair<int, float>;
|
using GraphPoint = std::pair<int, float>;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
|
@ -22,6 +22,12 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "include/api/callback/callback.h"
|
#include "include/api/callback/callback.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
constexpr int DONT_UPDATE_LR = 0;
|
constexpr int DONT_UPDATE_LR = 0;
|
||||||
|
|
|
@ -22,6 +22,12 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "include/api/callback/callback.h"
|
#include "include/api/callback/callback.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
class TimeMonitor: public TrainCallBack {
|
class TimeMonitor: public TrainCallBack {
|
||||||
|
|
|
@ -24,6 +24,12 @@
|
||||||
#include "include/api/callback/callback.h"
|
#include "include/api/callback/callback.h"
|
||||||
#include "include/api/metrics/accuracy.h"
|
#include "include/api/metrics/accuracy.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
using GraphPoint = std::pair<int, float>;
|
using GraphPoint = std::pair<int, float>;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
|
@ -23,6 +23,12 @@
|
||||||
#include "include/api/data_type.h"
|
#include "include/api/data_type.h"
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define MS_API __declspec(dllexport)
|
||||||
|
#else
|
||||||
|
#define MS_API __attribute__((visibility("default")))
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
class MixPrecisionCfg {
|
class MixPrecisionCfg {
|
||||||
|
|
|
@ -38,19 +38,12 @@ class Allocator;
|
||||||
class Delegate;
|
class Delegate;
|
||||||
class DeviceInfoContext;
|
class DeviceInfoContext;
|
||||||
|
|
||||||
/// \brief Context is used to store environment variables during execution.
|
|
||||||
class MS_API Context {
|
class MS_API Context {
|
||||||
public:
|
public:
|
||||||
Context();
|
Context();
|
||||||
~Context() = default;
|
~Context() = default;
|
||||||
|
|
||||||
/// \brief Set the number of threads at runtime. This option is only valid for MindSpore Lite.
|
|
||||||
///
|
|
||||||
/// \param[in] thread_num the number of threads at runtime.
|
|
||||||
void SetThreadNum(int32_t thread_num);
|
void SetThreadNum(int32_t thread_num);
|
||||||
/// \brief Get the current thread number setting.
|
|
||||||
///
|
|
||||||
/// \return The current thread number setting.
|
|
||||||
int32_t GetThreadNum() const;
|
int32_t GetThreadNum() const;
|
||||||
|
|
||||||
/// \brief Set the thread affinity to CPU cores.
|
/// \brief Set the thread affinity to CPU cores.
|
||||||
|
@ -67,10 +60,6 @@ class MS_API Context {
|
||||||
void SetDelegate(const std::shared_ptr<Delegate> &delegate);
|
void SetDelegate(const std::shared_ptr<Delegate> &delegate);
|
||||||
std::shared_ptr<Delegate> GetDelegate() const;
|
std::shared_ptr<Delegate> GetDelegate() const;
|
||||||
|
|
||||||
/// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
|
|
||||||
/// heterogeneous scenarios with multiple members in the vector.
|
|
||||||
///
|
|
||||||
/// \return Mutable reference of DeviceInfoContext vector in this context.
|
|
||||||
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
|
std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -78,24 +67,14 @@ class MS_API Context {
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief DeviceInfoContext defines different device contexts.
|
|
||||||
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
|
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
|
||||||
public:
|
public:
|
||||||
struct Data;
|
struct Data;
|
||||||
|
|
||||||
DeviceInfoContext();
|
DeviceInfoContext();
|
||||||
virtual ~DeviceInfoContext() = default;
|
virtual ~DeviceInfoContext() = default;
|
||||||
|
|
||||||
/// \brief Get the type of this DeviceInfoContext.
|
|
||||||
///
|
|
||||||
/// \return Type of this DeviceInfoContext.
|
|
||||||
virtual enum DeviceType GetDeviceType() const = 0;
|
virtual enum DeviceType GetDeviceType() const = 0;
|
||||||
|
|
||||||
/// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts
|
|
||||||
/// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
|
|
||||||
///
|
|
||||||
/// \param T Type
|
|
||||||
/// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
|
|
||||||
template <class T>
|
template <class T>
|
||||||
std::shared_ptr<T> Cast() {
|
std::shared_ptr<T> Cast() {
|
||||||
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
|
static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
|
||||||
|
@ -105,89 +84,41 @@ class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoC
|
||||||
|
|
||||||
return std::static_pointer_cast<T>(shared_from_this());
|
return std::static_pointer_cast<T>(shared_from_this());
|
||||||
}
|
}
|
||||||
/// \brief obtain provider's name
|
|
||||||
///
|
|
||||||
/// \return provider's name.
|
|
||||||
std::string GetProvider() const;
|
std::string GetProvider() const;
|
||||||
/// \brief set provider's name.
|
|
||||||
///
|
|
||||||
/// \param[in] provider define the provider's name.
|
|
||||||
void SetProvider(const std::string &provider);
|
void SetProvider(const std::string &provider);
|
||||||
/// \brief obtain provider's device type.
|
|
||||||
///
|
|
||||||
/// \return provider's device type.
|
|
||||||
std::string GetProviderDevice() const;
|
std::string GetProviderDevice() const;
|
||||||
/// \brief set provider's device type.
|
|
||||||
///
|
|
||||||
/// \param[in] device define the provider's device type.EG: CPU.
|
|
||||||
void SetProviderDevice(const std::string &device);
|
void SetProviderDevice(const std::string &device);
|
||||||
/// \brief set memory allocator.
|
|
||||||
///
|
|
||||||
/// \param[in] allocator define the memory allocator which can be defined by user.
|
|
||||||
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
|
void SetAllocator(const std::shared_ptr<Allocator> &allocator);
|
||||||
/// \brief obtain memory allocator.
|
|
||||||
///
|
|
||||||
/// \return memory allocator.
|
|
||||||
std::shared_ptr<Allocator> GetAllocator() const;
|
std::shared_ptr<Allocator> GetAllocator() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::shared_ptr<Data> data_;
|
std::shared_ptr<Data> data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
|
|
||||||
/// for MindSpore Lite.
|
|
||||||
class MS_API CPUDeviceInfo : public DeviceInfoContext {
|
class MS_API CPUDeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
/// \brief Get the type of this DeviceInfoContext.
|
|
||||||
///
|
|
||||||
/// \return Type of this DeviceInfoContext.
|
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
|
enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
|
||||||
|
|
||||||
/// \brief Set enables to perform the float16 inference
|
|
||||||
///
|
|
||||||
/// \param[in] is_fp16 Enable float16 inference or not.
|
|
||||||
void SetEnableFP16(bool is_fp16);
|
void SetEnableFP16(bool is_fp16);
|
||||||
/// \brief Get enables to perform the float16 inference
|
|
||||||
///
|
|
||||||
/// \return Whether enable float16 inference.
|
|
||||||
bool GetEnableFP16() const;
|
bool GetEnableFP16() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid
|
|
||||||
/// for MindSpore Lite.
|
|
||||||
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
|
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
/// \brief Get the type of this DeviceInfoContext.
|
|
||||||
///
|
|
||||||
/// \return Type of this DeviceInfoContext.
|
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
|
enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
|
||||||
|
|
||||||
/// \brief Set the NPU frequency.
|
|
||||||
///
|
|
||||||
/// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
|
|
||||||
/// performance), default as 3.
|
|
||||||
void SetFrequency(int frequency);
|
void SetFrequency(int frequency);
|
||||||
/// \brief Get the NPU frequency.
|
|
||||||
///
|
|
||||||
/// \return NPU frequency
|
|
||||||
int GetFrequency() const;
|
int GetFrequency() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU.
|
|
||||||
class MS_API GPUDeviceInfo : public DeviceInfoContext {
|
class MS_API GPUDeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
/// \brief Get the type of this DeviceInfoContext.
|
|
||||||
///
|
|
||||||
/// \return Type of this DeviceInfoContext.
|
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
|
enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
|
||||||
|
|
||||||
/// \brief Set device id.
|
|
||||||
///
|
|
||||||
/// \param[in] device_id The device id.
|
|
||||||
void SetDeviceID(uint32_t device_id);
|
void SetDeviceID(uint32_t device_id);
|
||||||
/// \brief Get the device id.
|
|
||||||
///
|
|
||||||
/// \return The device id.
|
|
||||||
uint32_t GetDeviceID() const;
|
uint32_t GetDeviceID() const;
|
||||||
|
|
||||||
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
|
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
|
||||||
|
@ -196,15 +127,8 @@ class MS_API GPUDeviceInfo : public DeviceInfoContext {
|
||||||
inline void SetPrecisionMode(const std::string &precison_mode);
|
inline void SetPrecisionMode(const std::string &precison_mode);
|
||||||
inline std::string GetPrecisionMode() const;
|
inline std::string GetPrecisionMode() const;
|
||||||
|
|
||||||
/// \brief Set enables to perform the float16 inference
|
|
||||||
///
|
|
||||||
/// \param[in] is_fp16 Enable float16 inference or not.
|
|
||||||
void SetEnableFP16(bool is_fp16);
|
void SetEnableFP16(bool is_fp16);
|
||||||
/// \brief Get enables to perform the float16 inference
|
|
||||||
///
|
|
||||||
/// \return Whether enable float16 inference.
|
|
||||||
bool GetEnableFP16() const;
|
bool GetEnableFP16() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void SetPrecisionMode(const std::vector<char> &precision_mode);
|
void SetPrecisionMode(const std::vector<char> &precision_mode);
|
||||||
std::vector<char> GetPrecisionModeChar() const;
|
std::vector<char> GetPrecisionModeChar() const;
|
||||||
|
@ -215,113 +139,52 @@ void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
||||||
}
|
}
|
||||||
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
|
||||||
|
|
||||||
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is
|
|
||||||
/// invalid for MindSpore Lite.
|
|
||||||
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
|
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
/// \brief Get the type of this DeviceInfoContext.
|
|
||||||
///
|
|
||||||
/// \return Type of this DeviceInfoContext.
|
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
|
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
|
||||||
|
|
||||||
/// \brief Set device id.
|
|
||||||
///
|
|
||||||
/// \param[in] device_id The device id.
|
|
||||||
void SetDeviceID(uint32_t device_id);
|
void SetDeviceID(uint32_t device_id);
|
||||||
/// \brief Get the device id.
|
|
||||||
///
|
|
||||||
/// \return The device id.
|
|
||||||
uint32_t GetDeviceID() const;
|
uint32_t GetDeviceID() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
|
|
||||||
/// invalid for MindSpore Lite.
|
|
||||||
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
/// \brief Get the type of this DeviceInfoContext.
|
|
||||||
///
|
|
||||||
/// \return Type of this DeviceInfoContext.
|
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
|
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
|
||||||
|
|
||||||
/// \brief Set device id.
|
|
||||||
///
|
|
||||||
/// \param[in] device_id The device id.
|
|
||||||
void SetDeviceID(uint32_t device_id);
|
void SetDeviceID(uint32_t device_id);
|
||||||
/// \brief Get the device id.
|
|
||||||
///
|
|
||||||
/// \return The device id.
|
|
||||||
uint32_t GetDeviceID() const;
|
uint32_t GetDeviceID() const;
|
||||||
|
|
||||||
inline void SetDumpConfigPath(const std::string &cfg_path);
|
inline void SetDumpConfigPath(const std::string &cfg_path);
|
||||||
inline std::string GetDumpConfigPath() const;
|
inline std::string GetDumpConfigPath() const;
|
||||||
|
|
||||||
/// \brief Set AIPP configuration file path.
|
// aipp config file
|
||||||
///
|
|
||||||
/// \param[in] cfg_path AIPP configuration file path.
|
|
||||||
inline void SetInsertOpConfigPath(const std::string &cfg_path);
|
inline void SetInsertOpConfigPath(const std::string &cfg_path);
|
||||||
/// \brief Get AIPP configuration file path.
|
|
||||||
///
|
|
||||||
/// \return AIPP configuration file path.
|
|
||||||
inline std::string GetInsertOpConfigPath() const;
|
inline std::string GetInsertOpConfigPath() const;
|
||||||
|
|
||||||
/// \brief Set format of model inputs.
|
// nchw or nhwc
|
||||||
///
|
|
||||||
/// \param[in] format Optional "NCHW", "NHWC", etc.
|
|
||||||
inline void SetInputFormat(const std::string &format);
|
inline void SetInputFormat(const std::string &format);
|
||||||
/// \brief Get format of model inputs.
|
|
||||||
///
|
|
||||||
/// \return The format of model inputs.
|
|
||||||
inline std::string GetInputFormat() const;
|
inline std::string GetInputFormat() const;
|
||||||
|
|
||||||
/// \brief Set shape of model inputs.
|
// Mandatory while dynamic batch: e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1"
|
||||||
///
|
|
||||||
/// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
|
|
||||||
inline void SetInputShape(const std::string &shape);
|
inline void SetInputShape(const std::string &shape);
|
||||||
/// \brief Get shape of model inputs.
|
|
||||||
///
|
|
||||||
/// \return The shape of model inputs.
|
|
||||||
inline std::string GetInputShape() const;
|
inline std::string GetInputShape() const;
|
||||||
|
|
||||||
/// \brief Set shape of model inputs.
|
|
||||||
///
|
|
||||||
/// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
|
|
||||||
/// shape 4,3,2,1.
|
|
||||||
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
|
void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
|
||||||
/// \brief Get shape of model inputs.
|
|
||||||
///
|
|
||||||
/// \return The shape of model inputs.
|
|
||||||
std::map<int, std::vector<int>> GetInputShapeMap() const;
|
std::map<int, std::vector<int>> GetInputShapeMap() const;
|
||||||
|
|
||||||
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
|
void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
|
||||||
inline std::string GetDynamicBatchSize() const;
|
inline std::string GetDynamicBatchSize() const;
|
||||||
|
|
||||||
/// \brief Set type of model outputs.
|
// FP32, UINT8 or FP16, default as FP32
|
||||||
///
|
|
||||||
/// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
|
|
||||||
void SetOutputType(enum DataType output_type);
|
void SetOutputType(enum DataType output_type);
|
||||||
/// \brief Get type of model outputs.
|
|
||||||
///
|
|
||||||
/// \return The set type of model outputs.
|
|
||||||
enum DataType GetOutputType() const;
|
enum DataType GetOutputType() const;
|
||||||
|
|
||||||
/// \brief Set precision mode of model.
|
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
|
||||||
///
|
|
||||||
/// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
|
|
||||||
/// "allow_mix_precision", "force_fp16" is set as default
|
|
||||||
inline void SetPrecisionMode(const std::string &precision_mode);
|
inline void SetPrecisionMode(const std::string &precision_mode);
|
||||||
/// \brief Get precision mode of model.
|
|
||||||
///
|
|
||||||
/// \return The set type of model outputs
|
|
||||||
inline std::string GetPrecisionMode() const;
|
inline std::string GetPrecisionMode() const;
|
||||||
|
|
||||||
/// \brief Set op select implementation mode.
|
// Optional "high_performance" and "high_precision", "high_performance" is set as default
|
||||||
///
|
|
||||||
/// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
|
|
||||||
/// default.
|
|
||||||
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
|
inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
|
||||||
/// \brief Get op select implementation mode.
|
|
||||||
///
|
|
||||||
/// \return The set op select implementation mode.
|
|
||||||
inline std::string GetOpSelectImplMode() const;
|
inline std::string GetOpSelectImplMode() const;
|
||||||
|
|
||||||
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
|
inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
|
||||||
|
|
|
@ -24,16 +24,9 @@
|
||||||
#include "include/api/context.h"
|
#include "include/api/context.h"
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
/// \brief The Kernel class is used to define a MindSpore Kernel.
|
|
||||||
class Kernel {
|
class Kernel {
|
||||||
public:
|
public:
|
||||||
Kernel() = default;
|
Kernel() = default;
|
||||||
/// \brief Constructor.
|
|
||||||
///
|
|
||||||
/// \param[in] inputs define the input tensors for kernel.
|
|
||||||
/// \param[in] outputs define the output tensors for kernel.
|
|
||||||
/// \param[in] primitive define the primitive of kernel generated by flatbuffers.
|
|
||||||
/// \param[in] ctx define the context for kernel.
|
|
||||||
Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
|
Kernel(const std::vector<mindspore::MSTensor> &inputs, const std::vector<mindspore::MSTensor> &outputs,
|
||||||
const schema::Primitive *primitive, const mindspore::Context *ctx)
|
const schema::Primitive *primitive, const mindspore::Context *ctx)
|
||||||
: context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) {
|
: context_(ctx), inputs_(std::move(inputs)), outputs_(std::move(outputs)), primitive_(primitive) {
|
||||||
|
@ -41,65 +34,32 @@ class Kernel {
|
||||||
type_ = primitive->value_type();
|
type_ = primitive->value_type();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// \brief Destructor.
|
|
||||||
virtual ~Kernel() = default;
|
virtual ~Kernel() = default;
|
||||||
/// \brief prepare for executing kernel.
|
|
||||||
///
|
|
||||||
/// \return result code.
|
|
||||||
virtual int Prepare() = 0;
|
virtual int Prepare() = 0;
|
||||||
/// \brief execute the kernel.
|
|
||||||
///
|
|
||||||
/// \return result code.
|
|
||||||
virtual int Execute() = 0;
|
virtual int Execute() = 0;
|
||||||
/// \brief resize the kernel input shape, memory need to refresh.
|
|
||||||
///
|
|
||||||
/// \return result code.
|
|
||||||
virtual int ReSize() = 0;
|
virtual int ReSize() = 0;
|
||||||
/// \brief set kernel's input tensors.
|
|
||||||
///
|
|
||||||
/// \param[in] in_tensors define the input tensors.
|
|
||||||
virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors) { this->inputs_ = in_tensors; }
|
virtual void set_inputs(const std::vector<mindspore::MSTensor> &in_tensors) { this->inputs_ = in_tensors; }
|
||||||
/// \brief set kernel's input tensor.
|
|
||||||
///
|
|
||||||
/// \param[in] in_tensor define the input tensor.
|
|
||||||
/// \param[in] index define the index of the input tensor.
|
|
||||||
virtual void set_input(mindspore::MSTensor in_tensor, int index) { this->inputs_[index] = in_tensor; }
|
virtual void set_input(mindspore::MSTensor in_tensor, int index) { this->inputs_[index] = in_tensor; }
|
||||||
/// \brief set kernel's output tensors.
|
|
||||||
///
|
|
||||||
/// \param[in] out_tensors define the output tensors.
|
|
||||||
virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors) { this->outputs_ = out_tensors; }
|
virtual void set_outputs(const std::vector<mindspore::MSTensor> &out_tensors) { this->outputs_ = out_tensors; }
|
||||||
/// \brief set kernel's output tensor.
|
|
||||||
///
|
|
||||||
/// \param[in] out_tensor define the output tensor.
|
|
||||||
/// \param[in] index define the index of the output tensor.
|
|
||||||
virtual void set_output(mindspore::MSTensor out_tensor, int index) { this->outputs_[index] = out_tensor; }
|
virtual void set_output(mindspore::MSTensor out_tensor, int index) { this->outputs_[index] = out_tensor; }
|
||||||
/// \brief obtain kernel's input tensors.
|
|
||||||
///
|
|
||||||
/// \return input tensors.
|
|
||||||
virtual const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }
|
virtual const std::vector<mindspore::MSTensor> &inputs() { return this->inputs_; }
|
||||||
/// \brief obtain kernel's output tensors.
|
|
||||||
///
|
|
||||||
/// \return output tensors.
|
|
||||||
virtual const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }
|
virtual const std::vector<mindspore::MSTensor> &outputs() { return this->outputs_; }
|
||||||
/// \brief obtain kernel's name.
|
|
||||||
///
|
|
||||||
/// \return kernel's name.
|
|
||||||
std::string name() const { return this->name_; }
|
std::string name() const { return this->name_; }
|
||||||
/// \brief set kernel's name.
|
|
||||||
///
|
|
||||||
/// \param[in] name define the kernel's name.
|
|
||||||
void set_name(const std::string &name) { this->name_ = name; }
|
void set_name(const std::string &name) { this->name_ = name; }
|
||||||
/// \brief obtain kernel's context.
|
|
||||||
///
|
|
||||||
/// \return kernel's context.
|
|
||||||
const mindspore::Context *context() const { return this->context_; }
|
const mindspore::Context *context() const { return this->context_; }
|
||||||
/// \brief obtain kernel's type.
|
|
||||||
///
|
|
||||||
/// \return kernel's type.
|
|
||||||
virtual schema::PrimitiveType type() const { return type_; }
|
virtual schema::PrimitiveType type() const { return type_; }
|
||||||
/// \brief obtain the primitive of kernel generated by flatbuffers.
|
|
||||||
///
|
|
||||||
/// \return the primitive of kernel generated by flatbuffers.
|
|
||||||
const schema::Primitive *primitive() const { return this->primitive_; }
|
const schema::Primitive *primitive() const { return this->primitive_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
@ -37,75 +37,32 @@ class Metrics;
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
class Dataset;
|
class Dataset;
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
/// \brief The Model class is used to define a MindSpore model, facilitating computational graph management.
|
|
||||||
class MS_API Model {
|
class MS_API Model {
|
||||||
public:
|
public:
|
||||||
Model();
|
Model();
|
||||||
~Model();
|
~Model();
|
||||||
Model(const Model &) = delete;
|
Model(const Model &) = delete;
|
||||||
void operator=(const Model &) = delete;
|
void operator=(const Model &) = delete;
|
||||||
/// \brief Builds a model so that it can run on a device.
|
|
||||||
///
|
|
||||||
/// \param[in] graph GraphCell is a derivative of Cell. Cell is not available currently. GraphCell can be constructed
|
|
||||||
/// from Graph, for example, model.Build(GraphCell(graph), context).
|
|
||||||
/// \param[in] model_context A context used to store options during execution.
|
|
||||||
/// \param[in] train_cfg A config used by training.
|
|
||||||
///
|
|
||||||
/// \return Status.
|
|
||||||
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
|
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr,
|
||||||
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
|
const std::shared_ptr<TrainCfg> &train_cfg = nullptr);
|
||||||
|
|
||||||
/// \brief Resizes the shapes of inputs.
|
|
||||||
///
|
|
||||||
/// \param[in] inputs A vector that includes all input tensors in order.
|
|
||||||
/// \param[in] dims Defines the new shapes of inputs, should be consistent with inputs.
|
|
||||||
///
|
|
||||||
/// \return Status.
|
|
||||||
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
||||||
|
|
||||||
/// \brief Inference model.
|
|
||||||
///
|
|
||||||
/// \param[in] inputs A vector where model inputs are arranged in sequence.
|
|
||||||
/// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
|
|
||||||
/// \param[in] before CallBack before predict.
|
|
||||||
/// \param[in] after CallBack after predict.
|
|
||||||
///
|
|
||||||
/// \return Status.
|
|
||||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||||
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
|
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
|
||||||
|
|
||||||
/// \brief Obtains all input tensors of the model.
|
|
||||||
///
|
|
||||||
/// \return The vector that includes all input tensors.
|
|
||||||
std::vector<MSTensor> GetInputs();
|
std::vector<MSTensor> GetInputs();
|
||||||
/// \brief Obtains the input tensor of the model by name.
|
|
||||||
///
|
|
||||||
/// \return The input tensor with the given name, if the name is not found, an invalid tensor is returned.
|
|
||||||
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
|
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
|
||||||
|
|
||||||
Status InitMetrics(std::vector<Metrics *> metrics);
|
Status InitMetrics(std::vector<Metrics *> metrics);
|
||||||
std::vector<Metrics *> GetMetrics();
|
std::vector<Metrics *> GetMetrics();
|
||||||
|
|
||||||
/// \brief Obtains all output tensors of the model.
|
|
||||||
///
|
|
||||||
/// \return The vector that includes all output tensors.
|
|
||||||
std::vector<MSTensor> GetOutputs();
|
std::vector<MSTensor> GetOutputs();
|
||||||
/// \brief Obtains names of all output tensors of the model.
|
|
||||||
///
|
|
||||||
/// \return A vector that includes names of all output tensors.
|
|
||||||
inline std::vector<std::string> GetOutputTensorNames();
|
inline std::vector<std::string> GetOutputTensorNames();
|
||||||
/// \brief Obtains the output tensor of the model by name.
|
|
||||||
///
|
|
||||||
/// \return The output tensor with the given name, if the name is not found, an invalid tensor is returned.
|
|
||||||
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
|
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
|
||||||
inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &tensor_name);
|
inline std::vector<MSTensor> GetOutputsByNodeName(const std::string &tensor_name);
|
||||||
|
|
||||||
/// \brief Inference model.
|
|
||||||
///
|
|
||||||
/// \param[in] device_type Device type,options are kGPU, kAscend910, etc.
|
|
||||||
/// \param[in] model_type The type of model file, options are ModelType::kMindIR, ModelType::kOM.
|
|
||||||
///
|
|
||||||
/// \return Is supported or not.
|
|
||||||
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
|
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
|
||||||
|
|
||||||
Status SetTrainMode(bool train);
|
Status SetTrainMode(bool train);
|
||||||
|
|
|
@ -27,43 +27,13 @@
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
/// \brief The Serialization class is used to summarize methods for reading and writing model files.
|
|
||||||
class MS_API Serialization {
|
class MS_API Serialization {
|
||||||
public:
|
public:
|
||||||
/// \brief Loads a model file from memory buffer.
|
|
||||||
///
|
|
||||||
/// \param[in] model_data A buffer filled by model file.
|
|
||||||
/// \param[in] data_size The size of the buffer.
|
|
||||||
/// \param[in] model_type The Type of model file, options are ModelType::kMindIR, ModelType::kOM.
|
|
||||||
/// \param[out] graph The output parameter, an object saves graph data.
|
|
||||||
/// \param[in] dec_key The decryption key, key length is 16, 24, or 32.
|
|
||||||
/// \param[in] dec_mode The decryption mode, optional options are AES-GCM, AES-CBC.
|
|
||||||
///
|
|
||||||
/// \return Status.
|
|
||||||
inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||||
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||||
|
|
||||||
/// \brief Loads a model file from path, is not supported on MindSpore Lite.
|
|
||||||
///
|
|
||||||
/// \param[in] file The path of model file.
|
|
||||||
/// \param[in] model_type The Type of model file, options are ModelType::kMindIR, ModelType::kOM.
|
|
||||||
/// \param[out] graph The output parameter, an object saves graph data.
|
|
||||||
/// \param[in] dec_key The decryption key, key length is 16, 24, or 32.
|
|
||||||
/// \param[in] dec_mode The decryption mode, optional options are AES-GCM, AES-CBC.
|
|
||||||
///
|
|
||||||
/// \return Status.
|
|
||||||
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {},
|
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {},
|
||||||
const std::string &dec_mode = kDecModeAesGcm);
|
const std::string &dec_mode = kDecModeAesGcm);
|
||||||
|
|
||||||
/// \brief Load multiple models from multiple files, MindSpore Lite does not provide this feature.
|
|
||||||
///
|
|
||||||
/// \param[in] files The path of model files.
|
|
||||||
/// \param[in] model_type The Type of model file, options are ModelType::kMindIR, ModelType::kOM.
|
|
||||||
/// \param[out] graph The output parameter, an object saves graph data.
|
|
||||||
/// \param[in] dec_key The decryption key, key length is 16, 24, or 32.
|
|
||||||
/// \param[in] dec_mode The decryption mode, optional options are AES-GCM, AES-CBC.
|
|
||||||
///
|
|
||||||
/// \return Status.
|
|
||||||
inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
|
inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
|
||||||
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||||
|
|
|
@ -25,21 +25,11 @@
|
||||||
#include "include/api/dual_abi_helper.h"
|
#include "include/api/dual_abi_helper.h"
|
||||||
#include "include/api/format.h"
|
#include "include/api/format.h"
|
||||||
|
|
||||||
#ifndef MS_API
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#ifdef _MSC_VER
|
|
||||||
#ifdef BUILDING_DLL
|
|
||||||
#define MS_API __declspec(dllexport)
|
#define MS_API __declspec(dllexport)
|
||||||
#else
|
#else
|
||||||
#define MS_API __declspec(dllimport)
|
|
||||||
#endif
|
|
||||||
#else
|
|
||||||
#define MS_API __declspec(dllexport)
|
|
||||||
#endif
|
|
||||||
#else
|
|
||||||
#define MS_API __attribute__((visibility("default")))
|
#define MS_API __attribute__((visibility("default")))
|
||||||
#endif
|
#endif
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
enum ModelType : uint32_t {
|
enum ModelType : uint32_t {
|
||||||
|
@ -74,64 +64,18 @@ struct QuantParam {
|
||||||
};
|
};
|
||||||
|
|
||||||
class Allocator;
|
class Allocator;
|
||||||
/// \brief The MSTensor class defines a tensor in MindSpore.
|
|
||||||
class MS_API MSTensor {
|
class MS_API MSTensor {
|
||||||
public:
|
public:
|
||||||
class Impl;
|
class Impl;
|
||||||
/// \brief Creates a MSTensor object, whose data need to be copied before accessed by Model, must be used in pairs
|
|
||||||
/// with DestroyTensorPtr.
|
|
||||||
///
|
|
||||||
/// \param[in] name The name of the MSTensor.
|
|
||||||
/// \param[in] type The data type of the MSTensor.
|
|
||||||
/// \param[in] shape The shape of the MSTensor.
|
|
||||||
/// \param[in] data The data pointer that points to allocated memory.
|
|
||||||
/// \param[in] data_len The length of the memory, in bytes.
|
|
||||||
///
|
|
||||||
/// \return A pointer of MSTensor.
|
|
||||||
static inline MSTensor *CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
static inline MSTensor *CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept;
|
const void *data, size_t data_len) noexcept;
|
||||||
/// \brief Creates a MSTensor object, whose data can be directly accessed by Model, must be used in pairs with
|
|
||||||
/// DestroyTensorPtr.
|
|
||||||
///
|
|
||||||
/// \param[in] name The name of the MSTensor.
|
|
||||||
/// \param[in] type The data type of the MSTensor.
|
|
||||||
/// \param[in] shape The shape of the MSTensor.
|
|
||||||
/// \param[in] data The data pointer that points to allocated memory.
|
|
||||||
/// \param[in] data_len The length of the memory, in bytes.
|
|
||||||
///
|
|
||||||
/// \return A pointer of MSTensor.
|
|
||||||
static inline MSTensor *CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
static inline MSTensor *CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept;
|
const void *data, size_t data_len) noexcept;
|
||||||
/// \brief Creates a MSTensor object, whose device data can be directly accessed by Model, must be used in pairs with
|
|
||||||
/// DestroyTensorPtr.
|
|
||||||
///
|
|
||||||
/// \param[in] name The name of the MSTensor.
|
|
||||||
/// \param[in] type The data type of the MSTensor.
|
|
||||||
/// \param[in] shape The shape of the MSTensor.
|
|
||||||
/// \param[in] data The data pointer that points to device memory.
|
|
||||||
/// \param[in] data_len The length of the memory, in bytes.
|
|
||||||
///
|
|
||||||
/// \return A pointer of MSTensor.
|
|
||||||
static inline MSTensor *CreateDevTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
static inline MSTensor *CreateDevTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||||
const void *data, size_t data_len) noexcept;
|
const void *data, size_t data_len) noexcept;
|
||||||
/// \brief Create a string type MSTensor object whose data can be accessed by Model only after being copied, must be
|
|
||||||
/// used in pair with DestroyTensorPtr.
|
|
||||||
///
|
|
||||||
/// \param[in] name The name of the MSTensor.
|
|
||||||
/// \param[in] str A vector container containing several strings.
|
|
||||||
///
|
|
||||||
/// \return A pointer of MSTensor.
|
|
||||||
static inline MSTensor *StringsToTensor(const std::string &name, const std::vector<std::string> &str);
|
static inline MSTensor *StringsToTensor(const std::string &name, const std::vector<std::string> &str);
|
||||||
/// \brief Parse the string type MSTensor object into strings.
|
|
||||||
///
|
|
||||||
/// \param[in] tensor A MSTensor object.
|
|
||||||
///
|
|
||||||
/// \return A vector container containing several strings.
|
|
||||||
static inline std::vector<std::string> TensorToStrings(const MSTensor &tensor);
|
static inline std::vector<std::string> TensorToStrings(const MSTensor &tensor);
|
||||||
/// \brief Destroy an object created by Clone, StringsToTensor, CreateRefTensor, CreateDevTensor or CreateTensor. Do
|
|
||||||
/// not use it to destroy MSTensor from other sources.
|
|
||||||
///
|
|
||||||
/// \param[in] tensor A MSTensor object.
|
|
||||||
static void DestroyTensorPtr(MSTensor *tensor) noexcept;
|
static void DestroyTensorPtr(MSTensor *tensor) noexcept;
|
||||||
|
|
||||||
MSTensor();
|
MSTensor();
|
||||||
|
@ -141,51 +85,19 @@ class MS_API MSTensor {
|
||||||
explicit MSTensor(std::nullptr_t);
|
explicit MSTensor(std::nullptr_t);
|
||||||
~MSTensor();
|
~MSTensor();
|
||||||
|
|
||||||
/// \brief Obtains the name of the MSTensor.
|
|
||||||
///
|
|
||||||
/// \return The name of the MSTensor.
|
|
||||||
inline std::string Name() const;
|
inline std::string Name() const;
|
||||||
/// \brief Obtains the data type of the MSTensor.
|
|
||||||
///
|
|
||||||
/// \return The data type of the MSTensor.
|
|
||||||
enum DataType DataType() const;
|
enum DataType DataType() const;
|
||||||
/// \brief Obtains the shape of the MSTensor.
|
|
||||||
///
|
|
||||||
/// \return The shape of the MSTensor.
|
|
||||||
const std::vector<int64_t> &Shape() const;
|
const std::vector<int64_t> &Shape() const;
|
||||||
/// \brief Obtains the number of elements of the MSTensor.
|
|
||||||
///
|
|
||||||
/// \return The number of elements of the MSTensor.
|
|
||||||
int64_t ElementNum() const;
|
int64_t ElementNum() const;
|
||||||
|
|
||||||
/// \brief Obtains a shared pointer to the copy of data of the MSTensor. The data can be read on host.
|
|
||||||
///
|
|
||||||
/// \return A shared pointer to the copy of data of the MSTensor.
|
|
||||||
std::shared_ptr<const void> Data() const;
|
std::shared_ptr<const void> Data() const;
|
||||||
/// \brief Obtains the pointer to the data of the MSTensor. If the MSTensor is a device tensor, the data cannot be
|
|
||||||
/// accessed directly on host.
|
|
||||||
///
|
|
||||||
/// \return A pointer to the data of the MSTensor.
|
|
||||||
void *MutableData();
|
void *MutableData();
|
||||||
/// \brief Obtains the length of the data of the MSTensor, in bytes.
|
|
||||||
///
|
|
||||||
/// \return The length of the data of the MSTensor, in bytes.
|
|
||||||
size_t DataSize() const;
|
size_t DataSize() const;
|
||||||
/// \brief Gets the boolean value that indicates whether the memory of MSTensor is on device.
|
|
||||||
///
|
|
||||||
/// \return The boolean value that indicates whether the memory of MSTensor is on device.
|
|
||||||
bool IsDevice() const;
|
bool IsDevice() const;
|
||||||
/// \brief Gets a deep copy of the MSTensor, must be used in pair with DestroyTensorPtr.
|
|
||||||
///
|
|
||||||
/// \return A pointer points to a deep copy of the MSTensor.
|
|
||||||
MSTensor *Clone() const;
|
MSTensor *Clone() const;
|
||||||
/// \brief Gets the boolean value that indicates whether the MSTensor is valid.
|
|
||||||
///
|
|
||||||
/// \return The boolean value that indicates whether the MSTensor is valid.
|
|
||||||
bool operator==(std::nullptr_t) const;
|
bool operator==(std::nullptr_t) const;
|
||||||
/// \brief Gets the boolean value that indicates whether the MSTensor is valid.
|
|
||||||
///
|
|
||||||
/// \return The boolean value that indicates whether the MSTensor is valid.
|
|
||||||
bool operator!=(std::nullptr_t) const;
|
bool operator!=(std::nullptr_t) const;
|
||||||
bool operator==(const MSTensor &tensor) const;
|
bool operator==(const MSTensor &tensor) const;
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,6 @@ from itertools import repeat, zip_longest
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore import context
|
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore._c_expression import Tensor as Tensor_
|
from mindspore._c_expression import Tensor as Tensor_
|
||||||
|
@ -148,7 +147,7 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N
|
||||||
Check argument integer.
|
Check argument integer.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
- number = check_number(number, 0, Rel.GE, "number", None) # number >= 0
|
- number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
|
||||||
"""
|
"""
|
||||||
rel_fn = Rel.get_fns(rel)
|
rel_fn = Rel.get_fns(rel)
|
||||||
prim_name = f'in `{prim_name}`' if prim_name else ''
|
prim_name = f'in `{prim_name}`' if prim_name else ''
|
||||||
|
@ -847,10 +846,6 @@ class Validator:
|
||||||
"""Returns an empty Tensor."""
|
"""Returns an empty Tensor."""
|
||||||
return Tensor_(dtype, shape)
|
return Tensor_(dtype, shape)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_type_support(dtype, device, supported_dtypes):
|
|
||||||
return dtype in supported_dtypes or not context.get_context('device_target') == device
|
|
||||||
|
|
||||||
|
|
||||||
def check_input_format(input_param):
|
def check_input_format(input_param):
|
||||||
"""Judge input format."""
|
"""Judge input format."""
|
||||||
|
|
|
@ -18,6 +18,7 @@ from .addn import AddN
|
||||||
from .assign_add import AssignAdd
|
from .assign_add import AssignAdd
|
||||||
from .batchnorm import BatchNorm
|
from .batchnorm import BatchNorm
|
||||||
from .batchnorm_grad import BatchNormGrad
|
from .batchnorm_grad import BatchNormGrad
|
||||||
|
from .bias_add import BiasAdd
|
||||||
from .bias_add_grad import BiasAddGrad
|
from .bias_add_grad import BiasAddGrad
|
||||||
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
|
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
|
||||||
from .conv2d import Conv2D
|
from .conv2d import Conv2D
|
||||||
|
@ -25,6 +26,7 @@ from .complex import CAbs, CAdd, CDiv, CMul, CSub
|
||||||
from .dropout_grad import DropoutGrad
|
from .dropout_grad import DropoutGrad
|
||||||
from .equal_count import EqualCount
|
from .equal_count import EqualCount
|
||||||
from .erfc import Erfc
|
from .erfc import Erfc
|
||||||
|
from .expand_dims import ExpandDims
|
||||||
from .fused_adam import FusedAdam
|
from .fused_adam import FusedAdam
|
||||||
from .fused_adam_weight_decay import FusedAdamWeightDecay
|
from .fused_adam_weight_decay import FusedAdamWeightDecay
|
||||||
from .fused_mul_add import FusedMulAdd
|
from .fused_mul_add import FusedMulAdd
|
||||||
|
@ -49,7 +51,6 @@ from .sigmoid import Sigmoid
|
||||||
from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits
|
from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits
|
||||||
from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad
|
from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad
|
||||||
from .sigmoid_grad import SigmoidGrad
|
from .sigmoid_grad import SigmoidGrad
|
||||||
from .slice import Slice
|
|
||||||
from .softmax import Softmax
|
from .softmax import Softmax
|
||||||
from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits
|
from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits
|
||||||
from .softmax_grad_ext import SoftmaxGradExt
|
from .softmax_grad_ext import SoftmaxGradExt
|
||||||
|
|
|
@ -80,9 +80,6 @@ class Expander:
|
||||||
|
|
||||||
class ExpanderInfoValidator:
|
class ExpanderInfoValidator:
|
||||||
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
|
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Init"""
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_check_function(kls, func):
|
def _add_check_function(kls, func):
|
||||||
"""
|
"""
|
||||||
|
@ -201,8 +198,8 @@ def to_frac_z_axis(ori_shape, ori_axis):
|
||||||
return frac_z_axis
|
return frac_z_axis
|
||||||
|
|
||||||
|
|
||||||
def infer_shape_from_fractalnz(fractal):
|
def infer_shape_from_fractalNz(fractal):
|
||||||
"get original shape from fractalnz shape"
|
"get original shape from fractalNz shape"
|
||||||
shape = []
|
shape = []
|
||||||
dims = len(fractal)
|
dims = len(fractal)
|
||||||
batch = dims - 4
|
batch = dims - 4
|
||||||
|
|
|
@ -24,7 +24,6 @@ from .expand_dims import ExpandDims
|
||||||
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
|
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
|
||||||
class BatchNorm(Expander):
|
class BatchNorm(Expander):
|
||||||
"""BatchNorm expander"""
|
"""BatchNorm expander"""
|
||||||
|
|
||||||
def _expand(self, graph_builder):
|
def _expand(self, graph_builder):
|
||||||
# get op info
|
# get op info
|
||||||
input_x = self.inputs[0]
|
input_x = self.inputs[0]
|
||||||
|
@ -43,40 +42,6 @@ class BatchNorm(Expander):
|
||||||
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
|
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
|
||||||
|
|
||||||
if self.attrs['is_training']:
|
if self.attrs['is_training']:
|
||||||
self.inputs[0] = input_x
|
|
||||||
res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
|
|
||||||
if input_x_new_type != input_x_ori_type:
|
|
||||||
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
|
|
||||||
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
|
|
||||||
# infer mode
|
|
||||||
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
|
|
||||||
input_mean = graph_builder.emit(
|
|
||||||
'Reshape', [input_mean], attrs={'shape': ExpandDims.infer_shape(input_mean.shape, [-1, -1])})
|
|
||||||
input_scale = graph_builder.emit(
|
|
||||||
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
|
|
||||||
input_offset = graph_builder.emit(
|
|
||||||
'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
|
|
||||||
x_sub = graph_builder.emit('Sub', [input_x, input_mean])
|
|
||||||
x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
|
|
||||||
var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
|
|
||||||
var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
|
|
||||||
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
|
|
||||||
var_add_sqrt = graph_builder.emit(
|
|
||||||
'Reshape', [var_add_sqrt], attrs={'shape': ExpandDims.infer_shape(var_add_sqrt.shape, [-1, -1])})
|
|
||||||
x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
|
|
||||||
res_y = graph_builder.emit('Add', [input_offset, x_div])
|
|
||||||
if input_x_new_type != input_x_ori_type:
|
|
||||||
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
|
|
||||||
return res_y, var_add, var_add, var_add, var_add
|
|
||||||
|
|
||||||
def _bn_train(self, graph_builder):
|
|
||||||
"""expand BatchNorm for training mode"""
|
|
||||||
input_x = self.inputs[0]
|
|
||||||
input_scale = self.inputs[1]
|
|
||||||
input_offset = self.inputs[2]
|
|
||||||
input_mean = self.inputs[3]
|
|
||||||
input_variance = self.inputs[4]
|
|
||||||
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
|
|
||||||
reduce_axis = ()
|
reduce_axis = ()
|
||||||
shape_x = input_x.shape
|
shape_x = input_x.shape
|
||||||
if input_x.data_format == DF.NHWC:
|
if input_x.data_format == DF.NHWC:
|
||||||
|
@ -152,4 +117,26 @@ class BatchNorm(Expander):
|
||||||
variance_res = graph_builder.emit(
|
variance_res = graph_builder.emit(
|
||||||
'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
|
'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
|
||||||
attrs={'fake_output': True})
|
attrs={'fake_output': True})
|
||||||
|
if input_x_new_type != input_x_ori_type:
|
||||||
|
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
|
||||||
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
|
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
|
||||||
|
# infer mode
|
||||||
|
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
|
||||||
|
input_mean = graph_builder.emit(
|
||||||
|
'Reshape', [input_mean], attrs={'shape': ExpandDims.infer_shape(input_mean.shape, [-1, -1])})
|
||||||
|
input_scale = graph_builder.emit(
|
||||||
|
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
|
||||||
|
input_offset = graph_builder.emit(
|
||||||
|
'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
|
||||||
|
x_sub = graph_builder.emit('Sub', [input_x, input_mean])
|
||||||
|
x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
|
||||||
|
var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
|
||||||
|
var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
|
||||||
|
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
|
||||||
|
var_add_sqrt = graph_builder.emit(
|
||||||
|
'Reshape', [var_add_sqrt], attrs={'shape': ExpandDims.infer_shape(var_add_sqrt.shape, [-1, -1])})
|
||||||
|
x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
|
||||||
|
res_y = graph_builder.emit('Add', [input_offset, x_div])
|
||||||
|
if input_x_new_type != input_x_ori_type:
|
||||||
|
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
|
||||||
|
return res_y, var_add, var_add, var_add, var_add
|
||||||
|
|
|
@ -17,14 +17,12 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||||
from .expand_dims import ExpandDims
|
from .expand_dims import ExpandDims
|
||||||
|
|
||||||
|
|
||||||
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
||||||
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
||||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
||||||
@VLD.check_attrs('is_training', 'epsilon')
|
@VLD.check_attrs('is_training', 'epsilon')
|
||||||
class BatchNormGrad(Expander):
|
class BatchNormGrad(Expander):
|
||||||
"""BatchNormGrad expander"""
|
"""BatchNormGrad expander"""
|
||||||
|
|
||||||
def _expand(self, graph_builder):
|
def _expand(self, graph_builder):
|
||||||
# get op info
|
# get op info
|
||||||
input_dy = self.inputs[0]
|
input_dy = self.inputs[0]
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for bias_add"""
|
||||||
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||||
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||||
|
from .expand_dims import ExpandDims
|
||||||
|
|
||||||
|
|
||||||
|
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
|
||||||
|
@VLD.add_format(DF.NCHW, DF.DEFAULT)
|
||||||
|
@VLD.add_format(DF.NHWC, DF.DEFAULT)
|
||||||
|
class BiasAdd(Expander):
|
||||||
|
"""BiasAdd expander"""
|
||||||
|
|
||||||
|
def _expand(self, graph_builder):
|
||||||
|
input_x, input_y = self.inputs
|
||||||
|
|
||||||
|
if input_x.data_format == DF.NCHW:
|
||||||
|
input_y_expand = graph_builder.emit(
|
||||||
|
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, [1, 2])})
|
||||||
|
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||||
|
elif input_x.data_format == DF.DEFAULT:
|
||||||
|
if len(input_x.shape) == 2:
|
||||||
|
result = graph_builder.emit('Add', [input_x, input_y])
|
||||||
|
elif len(input_x.shape) == 3:
|
||||||
|
input_y_expand = graph_builder.emit(
|
||||||
|
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, 1)})
|
||||||
|
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||||
|
else: # len == 4
|
||||||
|
input_y_expand = graph_builder.emit(
|
||||||
|
'Reshape', [input_y], attrs={'shape': ExpandDims.infer_shape(input_y.shape, [1, 2])})
|
||||||
|
result = graph_builder.emit('Add', [input_x, input_y_expand])
|
||||||
|
else: # NHWC
|
||||||
|
result = graph_builder.emit('Add', [input_x, input_y])
|
||||||
|
|
||||||
|
return result
|
|
@ -15,7 +15,6 @@
|
||||||
"""generate json desc for FusedMulAdd"""
|
"""generate json desc for FusedMulAdd"""
|
||||||
from ._utils import Expander
|
from ._utils import Expander
|
||||||
|
|
||||||
|
|
||||||
class FusedMulAdd(Expander):
|
class FusedMulAdd(Expander):
|
||||||
"""FusedMulAdd expander"""
|
"""FusedMulAdd expander"""
|
||||||
|
|
||||||
|
|
|
@ -15,15 +15,13 @@
|
||||||
"""generate json desc for LayerNorm"""
|
"""generate json desc for LayerNorm"""
|
||||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||||
from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
|
from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
|
||||||
|
|
||||||
|
|
||||||
@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
|
@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
|
||||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
||||||
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
|
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
|
||||||
class LayerNorm(Expander):
|
class LayerNorm(Expander):
|
||||||
"""LayerNorm expander"""
|
"""LayerNorm expander"""
|
||||||
|
|
||||||
def _expand(self, graph_builder):
|
def _expand(self, graph_builder):
|
||||||
input_x, input_gamma, input_beta = self.inputs
|
input_x, input_gamma, input_beta = self.inputs
|
||||||
processor = self.processor
|
processor = self.processor
|
||||||
|
@ -38,7 +36,7 @@ class LayerNorm(Expander):
|
||||||
|
|
||||||
ori_shape_x = input_x.shape
|
ori_shape_x = input_x.shape
|
||||||
if input_x.data_format == DF.FRAC_NZ:
|
if input_x.data_format == DF.FRAC_NZ:
|
||||||
ori_shape_x = infer_shape_from_fractalnz(ori_shape_x)
|
ori_shape_x = infer_shape_from_fractalNz(ori_shape_x)
|
||||||
|
|
||||||
# Calculate the scaling ratio of the average
|
# Calculate the scaling ratio of the average
|
||||||
if begin_norm_axis < 0:
|
if begin_norm_axis < 0:
|
||||||
|
|
|
@ -17,7 +17,6 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||||
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
||||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||||
|
|
||||||
|
|
||||||
@VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
|
@VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
|
||||||
class MatMul(Expander):
|
class MatMul(Expander):
|
||||||
"""
|
"""
|
||||||
|
@ -25,7 +24,7 @@ class MatMul(Expander):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, expand_info):
|
def __init__(self, expand_info):
|
||||||
super(MatMul, self).__init__(expand_info)
|
super().__init__(expand_info)
|
||||||
self.transpose_a = self.attrs['transpose_a']
|
self.transpose_a = self.attrs['transpose_a']
|
||||||
self.transpose_b = self.attrs['transpose_b']
|
self.transpose_b = self.attrs['transpose_b']
|
||||||
self.left_format = self.attrs['left_format']
|
self.left_format = self.attrs['left_format']
|
||||||
|
@ -48,28 +47,28 @@ class MatMul(Expander):
|
||||||
if input_num < 2:
|
if input_num < 2:
|
||||||
raise GKException("matul inputs number should bigger than 1, but got {}.".format(input_num))
|
raise GKException("matul inputs number should bigger than 1, but got {}.".format(input_num))
|
||||||
|
|
||||||
def _expand(self, graph_builder):
|
def _trans_shape(self, shape):
|
||||||
def transpose(shape):
|
|
||||||
trans_shape = list(shape)
|
trans_shape = list(shape)
|
||||||
trans_shape[-2] = shape[-1]
|
trans_shape[-2] = shape[-1]
|
||||||
trans_shape[-1] = shape[-2]
|
trans_shape[-1] = shape[-2]
|
||||||
return trans_shape
|
return trans_shape
|
||||||
|
|
||||||
|
def _expand(self, graph_builder):
|
||||||
if not self._optimize_to_mul():
|
if not self._optimize_to_mul():
|
||||||
raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
|
raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
|
||||||
#Matmul is replaced by Mul([b m k], [b k n]) when k==1
|
#Matmul is replaced by Mul([b m k], [b k n]) when k==1
|
||||||
input_a = self.inputs[0]
|
input_a = self.inputs[0]
|
||||||
input_b = self.inputs[1]
|
input_b = self.inputs[1]
|
||||||
if self.transpose_a:
|
if self.transpose_a:
|
||||||
shape_a_trans = transpose(self.shape_a)
|
shape_a_trans = self._trans_shape(self.shape_a)
|
||||||
input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
|
input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
|
||||||
if self.transpose_b:
|
if self.transpose_b:
|
||||||
shape_b_trans = transpose(self.shape_b)
|
shape_b_trans = self._trans_shape(self.shape_b)
|
||||||
input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
|
input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
|
||||||
result = graph_builder.emit('Mul', [input_a, input_b])
|
result = graph_builder.emit('Mul', [input_a, input_b])
|
||||||
if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
|
if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
|
||||||
result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
|
result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class BatchMatMul(MatMul):
|
class BatchMatMul(MatMul):
|
||||||
"""BatchMatMul expander"""
|
"""BatchMatMul expander"""
|
||||||
|
|
|
@ -24,7 +24,7 @@ class MinimumGrad(Expander):
|
||||||
def _check(self):
|
def _check(self):
|
||||||
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
|
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
|
||||||
raise GKException("both grad_x and grad_y are False.")
|
raise GKException("both grad_x and grad_y are False.")
|
||||||
return super(MinimumGrad, self)._check()
|
return super()._check()
|
||||||
|
|
||||||
def _expand(self, graph_builder):
|
def _expand(self, graph_builder):
|
||||||
input_x, input_y, input_dout = self.inputs
|
input_x, input_y, input_dout = self.inputs
|
||||||
|
@ -34,8 +34,7 @@ class MinimumGrad(Expander):
|
||||||
dx = graph_builder.emit('Mul', [le_result, input_dout])
|
dx = graph_builder.emit('Mul', [le_result, input_dout])
|
||||||
dy = graph_builder.emit('Sub', [input_dout, dx])
|
dy = graph_builder.emit('Sub', [input_dout, dx])
|
||||||
|
|
||||||
# for minimumgrad op, output_shape should be equal to input_shape,
|
# for minimumgrad op, output_shape should be equal to input_shape, but some elementwise operating may broadcast input_shape
|
||||||
# but some elementwise operating may broadcast input_shape
|
|
||||||
# then output_shape not equal to original input_shape, so need to reduce output to let them equal
|
# then output_shape not equal to original input_shape, so need to reduce output to let them equal
|
||||||
reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
|
reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
|
||||||
reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
|
reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
|
||||||
|
|
|
@ -15,8 +15,7 @@
|
||||||
"""generate json desc for softmax"""
|
"""generate json desc for softmax"""
|
||||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||||
from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
|
from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
|
||||||
|
|
||||||
|
|
||||||
@VLD.add_format(DF.FRAC_NZ)
|
@VLD.add_format(DF.FRAC_NZ)
|
||||||
@VLD.add_format(DF.DEFAULT)
|
@VLD.add_format(DF.DEFAULT)
|
||||||
|
@ -31,7 +30,7 @@ class Softmax(Expander):
|
||||||
|
|
||||||
ori_shape = input_x.shape
|
ori_shape = input_x.shape
|
||||||
if input_x.data_format == DF.FRAC_NZ:
|
if input_x.data_format == DF.FRAC_NZ:
|
||||||
ori_shape = infer_shape_from_fractalnz(input_x.shape)
|
ori_shape = infer_shape_from_fractalNz(input_x.shape)
|
||||||
|
|
||||||
for i, _ in enumerate(list(axis)):
|
for i, _ in enumerate(list(axis)):
|
||||||
if axis[i] < 0:
|
if axis[i] < 0:
|
||||||
|
|
|
@ -15,8 +15,7 @@
|
||||||
"""generate json desc for SoftmaxGradExt"""
|
"""generate json desc for SoftmaxGradExt"""
|
||||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||||
from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
|
from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
|
||||||
|
|
||||||
|
|
||||||
@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
|
@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
|
||||||
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
||||||
|
@ -30,7 +29,7 @@ class SoftmaxGradExt(Expander):
|
||||||
|
|
||||||
ori_shape = x.shape
|
ori_shape = x.shape
|
||||||
if x.data_format == DF.FRAC_NZ:
|
if x.data_format == DF.FRAC_NZ:
|
||||||
ori_shape = infer_shape_from_fractalnz(ori_shape)
|
ori_shape = infer_shape_from_fractalNz(ori_shape)
|
||||||
if not axis:
|
if not axis:
|
||||||
axis = []
|
axis = []
|
||||||
for i, _ in enumerate(ori_shape):
|
for i, _ in enumerate(ori_shape):
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"""generate json desc for SquareSumV1"""
|
"""generate json desc for SquareSumV1"""
|
||||||
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
||||||
from ._utils import Expander, ExpanderInfoValidator as VLD
|
from ._utils import Expander, ExpanderInfoValidator as VLD
|
||||||
from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
|
from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis
|
||||||
|
|
||||||
|
|
||||||
@VLD.add_format(DF.FRAC_NZ)
|
@VLD.add_format(DF.FRAC_NZ)
|
||||||
|
@ -30,7 +30,7 @@ class SquareSumV1(Expander):
|
||||||
|
|
||||||
ori_shape = x.shape
|
ori_shape = x.shape
|
||||||
if x.data_format == DF.FRAC_NZ:
|
if x.data_format == DF.FRAC_NZ:
|
||||||
ori_shape = infer_shape_from_fractalnz(ori_shape)
|
ori_shape = infer_shape_from_fractalNz(ori_shape)
|
||||||
if not axis:
|
if not axis:
|
||||||
axis = []
|
axis = []
|
||||||
for i, _ in enumerate(ori_shape):
|
for i, _ in enumerate(ori_shape):
|
||||||
|
|
|
@ -17,8 +17,6 @@ from .model import PrimLib
|
||||||
|
|
||||||
|
|
||||||
class ParalGain:
|
class ParalGain:
|
||||||
"""Paral Gain"""
|
|
||||||
|
|
||||||
def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
|
def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
|
||||||
self.fusion_type = fusion_type
|
self.fusion_type = fusion_type
|
||||||
self.bottleneck = bottleneck
|
self.bottleneck = bottleneck
|
||||||
|
@ -43,9 +41,7 @@ class ScheduleAnalyzer:
|
||||||
self.ops = graph.ops
|
self.ops = graph.ops
|
||||||
self.dom_op = [out.op for out in outputs]
|
self.dom_op = [out.op for out in outputs]
|
||||||
|
|
||||||
@staticmethod
|
def prod(self, shape):
|
||||||
def prod(shape):
|
|
||||||
"""Compute shape product"""
|
|
||||||
res = shape[0]
|
res = shape[0]
|
||||||
for i in range(1, len(shape)):
|
for i in range(1, len(shape)):
|
||||||
res = res * shape[i]
|
res = res * shape[i]
|
||||||
|
@ -291,5 +287,4 @@ def block_parallel_estimate(graphs):
|
||||||
|
|
||||||
|
|
||||||
def parallel_estimate(graphs):
|
def parallel_estimate(graphs):
|
||||||
"""Estimate parallel gain"""
|
|
||||||
return block_parallel_estimate(graphs)
|
return block_parallel_estimate(graphs)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
"""Cost model splitter"""
|
"""Cost model splitter"""
|
||||||
|
import os
|
||||||
from functools import reduce as prod_reduce
|
from functools import reduce as prod_reduce
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from .model import PrimLib, Graph, Tensor, Operator
|
from .model import PrimLib, Graph, Tensor, Operator
|
||||||
|
@ -38,24 +39,20 @@ class GraphSplitByPattern:
|
||||||
def sync(self, x, y):
|
def sync(self, x, y):
|
||||||
"""sync from y to x"""
|
"""sync from y to x"""
|
||||||
for i in self.alive:
|
for i in self.alive:
|
||||||
self._link(self.map[y][i], x, i)
|
if self.map[y][i] and not self.map[x][i]:
|
||||||
|
self.map[x][i] = True
|
||||||
def _link(self, cond, f, t):
|
|
||||||
"""link from `f` to `t`"""
|
|
||||||
if cond:
|
|
||||||
self.map[f][t] = True
|
|
||||||
|
|
||||||
def fuse(self, x, y):
|
def fuse(self, x, y):
|
||||||
"""fuse y to x"""
|
"""fuse y to x"""
|
||||||
for i in self.alive:
|
for i in self.alive:
|
||||||
# i is the succeeding node of y, links the x's previous nodes to i
|
|
||||||
if self.map[y][i] and not self.map[x][i]:
|
if self.map[y][i] and not self.map[x][i]:
|
||||||
for pre in self.alive:
|
for pre in self.alive:
|
||||||
self._link(self.map[pre][x], pre, i)
|
if self.map[pre][x] and not self.map[pre][i]:
|
||||||
# i is the previous node of y, link i to x's succeeding nodes
|
self.map[pre][i] = True
|
||||||
if self.map[i][y] and not self.map[i][x]:
|
if self.map[i][y] and not self.map[i][x]:
|
||||||
for suc in self.alive:
|
for suc in self.alive:
|
||||||
self._link(self.map[x][suc], i, suc)
|
if self.map[x][suc] and not self.map[i][suc]:
|
||||||
|
self.map[i][suc] = True
|
||||||
self.alive.remove(y)
|
self.alive.remove(y)
|
||||||
|
|
||||||
class Area:
|
class Area:
|
||||||
|
@ -70,10 +67,6 @@ class GraphSplitByPattern:
|
||||||
self.stitch_ops = set()
|
self.stitch_ops = set()
|
||||||
self.stitch_atomic_ops = set()
|
self.stitch_atomic_ops = set()
|
||||||
|
|
||||||
def has_stitch_op(self):
|
|
||||||
"""check stitch_op exists"""
|
|
||||||
return self.stitch_ops or self.stitch_atomic_ops
|
|
||||||
|
|
||||||
def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None):
|
def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None):
|
||||||
self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
|
self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN
|
||||||
self.ops = [] if init_op is None else [init_op]
|
self.ops = [] if init_op is None else [init_op]
|
||||||
|
@ -293,11 +286,11 @@ class GraphSplitByPattern:
|
||||||
|
|
||||||
def fuse(self, selector):
|
def fuse(self, selector):
|
||||||
"""Fuse areas"""
|
"""Fuse areas"""
|
||||||
def _fuse_area():
|
changed = False
|
||||||
|
while True:
|
||||||
for dominant in self.areas:
|
for dominant in self.areas:
|
||||||
result = selector(dominant)
|
result = selector(dominant)
|
||||||
if result is None or not result[0]:
|
if result is not None and result[0]:
|
||||||
continue
|
|
||||||
fuse_areas, is_forward = result
|
fuse_areas, is_forward = result
|
||||||
fuse_areas = self.limit_area_size(dominant, fuse_areas)
|
fuse_areas = self.limit_area_size(dominant, fuse_areas)
|
||||||
if not fuse_areas:
|
if not fuse_areas:
|
||||||
|
@ -314,13 +307,9 @@ class GraphSplitByPattern:
|
||||||
self.set_area_map(forward_area.ops, area)
|
self.set_area_map(forward_area.ops, area)
|
||||||
self.areas.remove(forward_area)
|
self.areas.remove(forward_area)
|
||||||
forward_area = area
|
forward_area = area
|
||||||
return True
|
changed = True
|
||||||
return False
|
break
|
||||||
|
else:
|
||||||
changed, do_again = False, True
|
|
||||||
while do_again:
|
|
||||||
do_again = _fuse_area()
|
|
||||||
changed = changed or do_again
|
|
||||||
return changed
|
return changed
|
||||||
|
|
||||||
def fuse_recom(self, selector):
|
def fuse_recom(self, selector):
|
||||||
|
@ -359,6 +348,21 @@ class GraphSplitByPattern:
|
||||||
graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
|
graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite")
|
||||||
return subgraphs, graphmodes
|
return subgraphs, graphmodes
|
||||||
|
|
||||||
|
def dump_subgraphs(self, subgraphs):
|
||||||
|
"""Dump subgraphs"""
|
||||||
|
if os.environ.get("ENABLE_SUBGRAPHS", "off") == "on":
|
||||||
|
subgraphs_str = "subgraphs:\nlen: " + str(len(subgraphs)) + "\n"
|
||||||
|
for i, sub in enumerate(subgraphs):
|
||||||
|
subgraphs_str += str("============") + str(i) + "\n"
|
||||||
|
subgraphs_str += str(sub)
|
||||||
|
dirname = 'subgraphs'
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
os.makedirs(dirname)
|
||||||
|
graphname = self.graph.name
|
||||||
|
filename = dirname + '/' + graphname + '.log'
|
||||||
|
with os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), 'w+') as f:
|
||||||
|
f.write(subgraphs_str)
|
||||||
|
|
||||||
def pattern_fuse(self, fuse_func=None):
|
def pattern_fuse(self, fuse_func=None):
|
||||||
"""fuse Areas by pattern repeatedly"""
|
"""fuse Areas by pattern repeatedly"""
|
||||||
del fuse_func
|
del fuse_func
|
||||||
|
@ -372,37 +376,33 @@ class GraphSplitByPattern:
|
||||||
# Note: after this function, the input output relation is not maintained.
|
# Note: after this function, the input output relation is not maintained.
|
||||||
self.split_output_reshapes()
|
self.split_output_reshapes()
|
||||||
subgraphs, graphmodes = self.to_subgraphs()
|
subgraphs, graphmodes = self.to_subgraphs()
|
||||||
|
self.dump_subgraphs(subgraphs)
|
||||||
return subgraphs, graphmodes
|
return subgraphs, graphmodes
|
||||||
|
|
||||||
def split_output_reshapes(self):
|
def split_output_reshapes(self):
|
||||||
"""Force split the output Reshapes into other new area"""
|
"""Force split the output reshapes into other new """
|
||||||
def _remove_output_reshape(reshape_ops, other_ops):
|
|
||||||
def _run():
|
|
||||||
for op in reshape_ops:
|
|
||||||
if any([to_op in other_ops for to_op in op.output.to_ops]):
|
|
||||||
reshape_ops.remove(op)
|
|
||||||
other_ops.append(op)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
while _run():
|
|
||||||
pass
|
|
||||||
|
|
||||||
new_areas = []
|
new_areas = []
|
||||||
for area in self.areas:
|
for area in self.areas:
|
||||||
reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
|
out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
|
||||||
other_ops = [op for op in area.ops if op not in reshape_ops]
|
remain_ops = [op for op in area.ops if op not in out_reshape_ops]
|
||||||
if not other_ops or not reshape_ops:
|
if not remain_ops or not out_reshape_ops:
|
||||||
continue
|
continue
|
||||||
# remove the output reshape from "reshape_ops" and add it into "other_ops"
|
changed = True
|
||||||
_remove_output_reshape(reshape_ops, other_ops)
|
while changed:
|
||||||
if not reshape_ops:
|
changed = False
|
||||||
continue
|
for op in out_reshape_ops:
|
||||||
for op in reshape_ops:
|
if any([to_op in remain_ops for to_op in op.output.to_ops]):
|
||||||
|
out_reshape_ops.remove(op)
|
||||||
|
remain_ops.append(op)
|
||||||
|
changed = True
|
||||||
|
break
|
||||||
|
if out_reshape_ops:
|
||||||
|
for op in out_reshape_ops:
|
||||||
a = self.Area(op, False, 0, self.reach_tab)
|
a = self.Area(op, False, 0, self.reach_tab)
|
||||||
self.set_default_mode(a)
|
self.set_default_mode(a)
|
||||||
new_areas.append(a)
|
new_areas.append(a)
|
||||||
area.ops = other_ops
|
area.ops = remain_ops
|
||||||
if len(other_ops) == 1:
|
if len(remain_ops) == 1:
|
||||||
self.set_default_mode(area)
|
self.set_default_mode(area)
|
||||||
if new_areas:
|
if new_areas:
|
||||||
self.areas += new_areas
|
self.areas += new_areas
|
||||||
|
@ -472,8 +472,8 @@ class GraphSplitByPattern:
|
||||||
region_ops.append(op)
|
region_ops.append(op)
|
||||||
return False, None, weight, True
|
return False, None, weight, True
|
||||||
# region fails to grow
|
# region fails to grow
|
||||||
max_weight = 20
|
MAX_WEIGHT = 20
|
||||||
if weight > max_weight or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
|
if weight > MAX_WEIGHT or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST:
|
||||||
return False, None, weight, False
|
return False, None, weight, False
|
||||||
# region grows successfully
|
# region grows successfully
|
||||||
weight = weight + 1
|
weight = weight + 1
|
||||||
|
@ -486,7 +486,7 @@ class GraphSplitByPattern:
|
||||||
cheap_regions = []
|
cheap_regions = []
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
# tensor should have user other than user_area to be fused
|
# tensor should have user other than user_area to be fused
|
||||||
if len(output.to_ops) < 2:
|
if output.para_type != Tensor.PARA_OUTPUT and len(output.to_ops) < 2:
|
||||||
continue
|
continue
|
||||||
region_ops = []
|
region_ops = []
|
||||||
grow = True
|
grow = True
|
||||||
|
@ -533,7 +533,14 @@ class GraphSplitByPattern:
|
||||||
"""find recompute regions and copy them out to new Areas"""
|
"""find recompute regions and copy them out to new Areas"""
|
||||||
def do_recompute_fuse():
|
def do_recompute_fuse():
|
||||||
"""split the unfusing pattern by add recompute area"""
|
"""split the unfusing pattern by add recompute area"""
|
||||||
def recompute_cheap_region(dom):
|
recompute_suc = False
|
||||||
|
orig_areas = []
|
||||||
|
orig_areas.extend(self.areas)
|
||||||
|
for dom in orig_areas:
|
||||||
|
if dom not in self.areas or not dom.out_relations:
|
||||||
|
continue
|
||||||
|
cheap_regions = self.find_cheap_regions(dom)
|
||||||
|
dom_changed = False
|
||||||
for cheap_region in cheap_regions:
|
for cheap_region in cheap_regions:
|
||||||
user_areas = self.select_user_area(cheap_region[-1].output)
|
user_areas = self.select_user_area(cheap_region[-1].output)
|
||||||
if not user_areas:
|
if not user_areas:
|
||||||
|
@ -543,17 +550,12 @@ class GraphSplitByPattern:
|
||||||
self.pattern_fuse(self.fuse_recom)
|
self.pattern_fuse(self.fuse_recom)
|
||||||
self.clear_recompute()
|
self.clear_recompute()
|
||||||
if self.recom_res:
|
if self.recom_res:
|
||||||
return True
|
|
||||||
return False
|
|
||||||
recompute_suc = False
|
|
||||||
orig_areas = []
|
|
||||||
orig_areas.extend(self.areas)
|
|
||||||
for dom in orig_areas:
|
|
||||||
if dom not in self.areas or not dom.out_relations:
|
|
||||||
continue
|
|
||||||
cheap_regions = self.find_cheap_regions(dom)
|
|
||||||
if recompute_cheap_region(dom):
|
|
||||||
recompute_suc = True
|
recompute_suc = True
|
||||||
|
# Copy region at most once for this dom
|
||||||
|
dom_changed = True
|
||||||
|
break
|
||||||
|
if dom_changed:
|
||||||
|
break
|
||||||
return recompute_suc
|
return recompute_suc
|
||||||
|
|
||||||
if self.enable_recompute:
|
if self.enable_recompute:
|
||||||
|
@ -561,6 +563,9 @@ class GraphSplitByPattern:
|
||||||
self.pattern_fuse()
|
self.pattern_fuse()
|
||||||
|
|
||||||
|
|
||||||
|
use_poly_reduce = True
|
||||||
|
|
||||||
|
|
||||||
class GraphSplitGpu(GraphSplitByPattern):
|
class GraphSplitGpu(GraphSplitByPattern):
|
||||||
"""Graph splitter"""
|
"""Graph splitter"""
|
||||||
BORADCAST_FUSE_DEPTH = 20
|
BORADCAST_FUSE_DEPTH = 20
|
||||||
|
@ -611,7 +616,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
return fused, True
|
return fused, True
|
||||||
|
|
||||||
def _broadcast_pat_exclude(dom, a, r):
|
def _broadcast_pat_exclude(dom, a, r):
|
||||||
if a.pattern == PrimLib.REDUCE:
|
if use_poly_reduce and a.pattern == PrimLib.REDUCE:
|
||||||
return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
|
return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
|
||||||
return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST
|
return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST
|
||||||
|
|
||||||
|
@ -636,14 +641,34 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
fused.append(a)
|
fused.append(a)
|
||||||
return fused, False
|
return fused, False
|
||||||
|
|
||||||
|
def _check_reduce_exclude(dom):
|
||||||
|
if use_poly_reduce:
|
||||||
|
return False
|
||||||
|
# exclude large all-reduce
|
||||||
|
if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
|
||||||
|
dom.ops[0].inputs[0].get_size() > 10000:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# exclude multi output
|
||||||
|
for a in dom.in_relations.keys():
|
||||||
|
if len(a.out_relations) > 1:
|
||||||
|
return True
|
||||||
|
if any([op.output.para_type == Tensor.PARA_OUTPUT for op in a.ops]):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _reduce_pat_exclude(_, a, r):
|
def _reduce_pat_exclude(_, a, r):
|
||||||
if len(a.ops) > self.REDUCE_FUSE_DEPTH:
|
if len(a.ops) > self.REDUCE_FUSE_DEPTH:
|
||||||
return True
|
return True
|
||||||
|
if use_poly_reduce:
|
||||||
return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
|
return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
|
||||||
|
return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE
|
||||||
|
|
||||||
def _reduce_depth(dom):
|
def _reduce_depth(dom):
|
||||||
if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
|
if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
|
||||||
return None
|
return None
|
||||||
|
if _check_reduce_exclude(dom):
|
||||||
|
return None
|
||||||
a, r = list(dom.in_relations.items())[0]
|
a, r = list(dom.in_relations.items())[0]
|
||||||
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
|
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
|
||||||
_is_atomic_add_available(dom):
|
_is_atomic_add_available(dom):
|
||||||
|
@ -656,6 +681,8 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
def _reduce_width(dom):
|
def _reduce_width(dom):
|
||||||
if dom.pattern != PrimLib.REDUCE:
|
if dom.pattern != PrimLib.REDUCE:
|
||||||
return None
|
return None
|
||||||
|
if _check_reduce_exclude(dom):
|
||||||
|
return None
|
||||||
fused = []
|
fused = []
|
||||||
for a, r in dom.in_relations.items():
|
for a, r in dom.in_relations.items():
|
||||||
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
|
if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
|
||||||
|
@ -736,16 +763,16 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
|
|
||||||
def _may_stitch(dom, a, r):
|
def _may_stitch(dom, a, r):
|
||||||
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
|
if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
|
||||||
if _reduce_nums(a.ops) >= 2:
|
if _reduce_nums(a.ops) < 2:
|
||||||
return False
|
|
||||||
dom_outs = [op.output for op in dom.ops]
|
dom_outs = [op.output for op in dom.ops]
|
||||||
a_ins = [op_input for op in a.ops for op_input in op.inputs]
|
a_ins = [op_input for op in a.ops for op_input in op.inputs]
|
||||||
a_outs = [op.output for op in a.ops]
|
a_outs = [op.output for op in a.ops]
|
||||||
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
|
a_final_outs = [tensor for tensor in a_outs if tensor not in a_ins]
|
||||||
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
|
stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins]
|
||||||
if not _same_stitch_axis(stitch_tensors, a_final_outs):
|
if _same_stitch_axis(stitch_tensors, a_final_outs):
|
||||||
return False
|
for tensor in stitch_tensors:
|
||||||
return any([_tensor_size(tensor) >= 1024 * 1024 for tensor in stitch_tensors])
|
if _tensor_size(tensor) >= 1024 * 1024:
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _reduce_stitch(dom):
|
def _reduce_stitch(dom):
|
||||||
|
@ -758,8 +785,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
|
|
||||||
fused = []
|
fused = []
|
||||||
for a, r in dom.out_relations.items():
|
for a, r in dom.out_relations.items():
|
||||||
if not _may_stitch(dom, a, r):
|
if _may_stitch(dom, a, r):
|
||||||
continue
|
|
||||||
if a.pattern == PrimLib.REDUCE:
|
if a.pattern == PrimLib.REDUCE:
|
||||||
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
|
if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']:
|
||||||
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
|
dom.stitch_info.stitch_ops.add(dom.ops[0].output.name)
|
||||||
|
@ -778,16 +804,6 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
fused.append(a)
|
fused.append(a)
|
||||||
return fused, True
|
return fused, True
|
||||||
|
|
||||||
def _strided_slice(dom):
|
|
||||||
if dom.dom_op().prim != "StridedSlice":
|
|
||||||
return None
|
|
||||||
fused = []
|
|
||||||
for a, _ in dom.in_relations.items():
|
|
||||||
if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
|
|
||||||
len(a.out_relations) == 1 and not a.is_output:
|
|
||||||
fused.append(a)
|
|
||||||
return fused, True
|
|
||||||
|
|
||||||
def _fuse_loop():
|
def _fuse_loop():
|
||||||
changed = True
|
changed = True
|
||||||
while changed:
|
while changed:
|
||||||
|
@ -798,7 +814,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
changed = self.fuse(_reduce_width) or changed
|
changed = self.fuse(_reduce_width) or changed
|
||||||
changed = self.fuse(_broadcast_depth) or changed
|
changed = self.fuse(_broadcast_depth) or changed
|
||||||
changed = self.fuse(_broadcast_width) or changed
|
changed = self.fuse(_broadcast_width) or changed
|
||||||
changed = self.fuse(_strided_slice) or changed
|
if use_poly_reduce:
|
||||||
changed = self.fuse(_reduce_output) or changed
|
changed = self.fuse(_reduce_output) or changed
|
||||||
if enable_stitch_fusion:
|
if enable_stitch_fusion:
|
||||||
changed = self.fuse(_reduce_stitch) or changed
|
changed = self.fuse(_reduce_stitch) or changed
|
||||||
|
@ -809,6 +825,7 @@ class GraphSplitGpu(GraphSplitByPattern):
|
||||||
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
|
fuse_func(_reduce_depth) or fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or \
|
||||||
fuse_func(_broadcast_width):
|
fuse_func(_broadcast_width):
|
||||||
return
|
return
|
||||||
|
if use_poly_reduce:
|
||||||
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
|
if fuse_func(_reduce_output) or (enable_stitch_fusion and fuse_func(_reduce_stitch)):
|
||||||
return
|
return
|
||||||
fuse_func(_transpose)
|
fuse_func(_transpose)
|
||||||
|
|
|
@ -216,7 +216,6 @@ class PrimLib:
|
||||||
'Transpose': Prim(OPAQUE),
|
'Transpose': Prim(OPAQUE),
|
||||||
'Tile': Prim(BROADCAST),
|
'Tile': Prim(BROADCAST),
|
||||||
'BroadcastTo': Prim(BROADCAST),
|
'BroadcastTo': Prim(BROADCAST),
|
||||||
'StridedSlice': Prim(OPAQUE),
|
|
||||||
'MatMul': Prim(OPAQUE),
|
'MatMul': Prim(OPAQUE),
|
||||||
'TransData': Prim(OPAQUE),
|
'TransData': Prim(OPAQUE),
|
||||||
'BatchMatMul': Prim(OPAQUE),
|
'BatchMatMul': Prim(OPAQUE),
|
||||||
|
@ -422,13 +421,14 @@ class Graph:
|
||||||
for t in op.inputs:
|
for t in op.inputs:
|
||||||
if t not in inputs and t.op not in self.ops:
|
if t not in inputs and t.op not in self.ops:
|
||||||
inputs.append(t)
|
inputs.append(t)
|
||||||
if op.output in outputs:
|
if op.output not in outputs:
|
||||||
continue
|
|
||||||
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
|
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
|
||||||
outputs.append(op.output)
|
outputs.append(op.output)
|
||||||
continue
|
else:
|
||||||
if any([succ not in self.ops for succ in op.output.to_ops]):
|
for d in op.output.to_ops:
|
||||||
|
if d not in self.ops:
|
||||||
outputs.append(op.output)
|
outputs.append(op.output)
|
||||||
|
break
|
||||||
if self.inputs:
|
if self.inputs:
|
||||||
inputs = self.inputs
|
inputs = self.inputs
|
||||||
|
|
||||||
|
|
|
@ -28,13 +28,11 @@ class GraphBuilder:
|
||||||
self.graph = Graph(name, [])
|
self.graph = Graph(name, [])
|
||||||
|
|
||||||
def set_input(self, *para):
|
def set_input(self, *para):
|
||||||
"""set input to graph inputs"""
|
|
||||||
for t in para:
|
for t in para:
|
||||||
t.para_type = Tensor.PARA_INPUT
|
t.para_type = Tensor.PARA_INPUT
|
||||||
self.graph.inputs.append(t)
|
self.graph.inputs.append(t)
|
||||||
|
|
||||||
def set_output(self, *para):
|
def set_output(self, *para):
|
||||||
"""set output to graph inputs"""
|
|
||||||
for t in para:
|
for t in para:
|
||||||
t.para_type = Tensor.PARA_OUTPUT
|
t.para_type = Tensor.PARA_OUTPUT
|
||||||
self.graph.outputs.append(t)
|
self.graph.outputs.append(t)
|
||||||
|
@ -52,8 +50,6 @@ class GraphBuilder:
|
||||||
def graph_scope(self, name):
|
def graph_scope(self, name):
|
||||||
"""The graph scope to be processed"""
|
"""The graph scope to be processed"""
|
||||||
class GraphScope:
|
class GraphScope:
|
||||||
"""Graph Scope"""
|
|
||||||
|
|
||||||
def __init__(self, gb):
|
def __init__(self, gb):
|
||||||
self.gb = gb
|
self.gb = gb
|
||||||
|
|
||||||
|
@ -81,6 +77,7 @@ class GraphBuilder:
|
||||||
"""Create a new Value"""
|
"""Create a new Value"""
|
||||||
if name in (None, ''):
|
if name in (None, ''):
|
||||||
name = self._alloc_tensor_name()
|
name = self._alloc_tensor_name()
|
||||||
|
|
||||||
v = Value(name, dtype, value)
|
v = Value(name, dtype, value)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@ -108,7 +105,6 @@ class GraphBuilder:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get graphs"""
|
|
||||||
return self.graphs
|
return self.graphs
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,14 +123,34 @@ class CompositeGraph:
|
||||||
|
|
||||||
def load(self, desc):
|
def load(self, desc):
|
||||||
"""Load Graph from json"""
|
"""Load Graph from json"""
|
||||||
def _attr_of(op):
|
def _attr_of(op, inputs, output):
|
||||||
if not op['attr']:
|
def _get_axis_while_none(input_shape, output_shape):
|
||||||
return dict()
|
red_axis = []
|
||||||
attr = {}
|
if len(output_shape) == len(input_shape):
|
||||||
for a in op['attr']:
|
for i, s in enumerate(output_shape):
|
||||||
if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
|
if s == 1 and input_shape[i] > 1:
|
||||||
attr['reduce_axis'] = a['value']
|
red_axis.append(i)
|
||||||
else:
|
else:
|
||||||
|
red_axis = list(range(len(output_shape)))
|
||||||
|
return red_axis
|
||||||
|
|
||||||
|
attr = {}
|
||||||
|
if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
|
||||||
|
for a in op['attr']:
|
||||||
|
if a['name'] == 'axis':
|
||||||
|
red_axis, dim_size = [], len(inputs[0].shape)
|
||||||
|
if not a['value']:
|
||||||
|
red_axis = _get_axis_while_none(inputs[0].shape, output.shape)
|
||||||
|
else:
|
||||||
|
if isinstance(a['value'], int):
|
||||||
|
a['value'] = [a['value']]
|
||||||
|
for i in a['value']:
|
||||||
|
red_axis.append(i if i >= 0 else dim_size + i)
|
||||||
|
attr['reduce_axis'] = red_axis
|
||||||
|
if a['name'] == "reduce_output_fuse":
|
||||||
|
attr['reduce_output_fuse'] = a['value']
|
||||||
|
elif op['attr']:
|
||||||
|
for a in op['attr']:
|
||||||
attr[a['name']] = a['value']
|
attr[a['name']] = a['value']
|
||||||
return attr
|
return attr
|
||||||
|
|
||||||
|
@ -150,6 +166,7 @@ class CompositeGraph:
|
||||||
'shape'], out_desc['data_type'], out_desc['format']
|
'shape'], out_desc['data_type'], out_desc['format']
|
||||||
self.tensors[name] = builder.tensor(
|
self.tensors[name] = builder.tensor(
|
||||||
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
|
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
|
||||||
|
cur_fusion = None
|
||||||
for op in desc['op_desc']:
|
for op in desc['op_desc']:
|
||||||
inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
|
inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
|
||||||
out_desc = op['output_desc']
|
out_desc = op['output_desc']
|
||||||
|
@ -160,17 +177,25 @@ class CompositeGraph:
|
||||||
inputs[1].para_type = Tensor.PARA_OUTPUT
|
inputs[1].para_type = Tensor.PARA_OUTPUT
|
||||||
output = inputs[2]
|
output = inputs[2]
|
||||||
self.tensors[name] = output
|
self.tensors[name] = output
|
||||||
continue
|
else:
|
||||||
output = self.tensors.get(name, None)
|
output = self.tensors.get(name, None)
|
||||||
if not output:
|
if not output:
|
||||||
output = builder.tensor(shape, dtype, data_format, name=name)
|
output = builder.tensor(
|
||||||
|
shape, dtype, data_format, name=name)
|
||||||
self.tensors[name] = output
|
self.tensors[name] = output
|
||||||
builder.op(op['name'], output, inputs, attrs=_attr_of(op))
|
builder.op(op['name'], output, inputs,
|
||||||
|
attrs=_attr_of(op, inputs, output))
|
||||||
|
if 'fusion' in op:
|
||||||
|
if cur_fusion is None:
|
||||||
|
cur_fusion = output
|
||||||
|
else:
|
||||||
|
cur_fusion.add_buddy(output)
|
||||||
|
if op['fusion'].endswith('_end'):
|
||||||
|
cur_fusion = None
|
||||||
self.graph = builder.get()[0]
|
self.graph = builder.get()[0]
|
||||||
self.desc = desc
|
self.desc = desc
|
||||||
|
|
||||||
def add_stitch_info(self, subgraph, desc):
|
def add_stitch_info(self, subgraph, desc):
|
||||||
"""add stitch info to desc"""
|
|
||||||
if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
|
if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
|
||||||
buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
|
buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
|
||||||
if subgraph.stitch_info.stitch_atomic_ops:
|
if subgraph.stitch_info.stitch_atomic_ops:
|
||||||
|
@ -179,7 +204,6 @@ class CompositeGraph:
|
||||||
return desc
|
return desc
|
||||||
|
|
||||||
def add_recompute_ops(self, subgraph, desc):
|
def add_recompute_ops(self, subgraph, desc):
|
||||||
"""add recompute ops to desc"""
|
|
||||||
if subgraph.recompute_ops:
|
if subgraph.recompute_ops:
|
||||||
desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
|
desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
|
||||||
return desc
|
return desc
|
||||||
|
@ -203,18 +227,29 @@ class CompositeGraph:
|
||||||
inputs, outputs = subgraph.deduce_parameters()
|
inputs, outputs = subgraph.deduce_parameters()
|
||||||
graph_ops = set(subgraph.ops)
|
graph_ops = set(subgraph.ops)
|
||||||
inplace_assign, inplace_assign_z = self._pre_dump(outputs)
|
inplace_assign, inplace_assign_z = self._pre_dump(outputs)
|
||||||
|
for key in self.desc:
|
||||||
def dump_output(t):
|
if key == 'input_desc':
|
||||||
|
desc[key] = [
|
||||||
|
[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
|
||||||
|
elif key == 'output_desc':
|
||||||
|
out_desc = []
|
||||||
|
for t in outputs:
|
||||||
if t.name in inplace_assign:
|
if t.name in inplace_assign:
|
||||||
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
|
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
|
||||||
return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]}
|
out_desc.append(
|
||||||
return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}
|
{'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign[t.name]})
|
||||||
|
else:
|
||||||
def dump_op_desc(d):
|
out_desc.append(
|
||||||
|
{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name})
|
||||||
|
desc[key] = out_desc
|
||||||
|
elif key == 'op_desc':
|
||||||
|
op_desc = []
|
||||||
|
for d in self.desc[key]:
|
||||||
if d['name'] == 'InplaceAssign':
|
if d['name'] == 'InplaceAssign':
|
||||||
y = d['input_desc'][1][0]['tensor_name']
|
y = d['input_desc'][1][0]['tensor_name']
|
||||||
if self.tensors[y].op in graph_ops:
|
if self.tensors[y].op in graph_ops:
|
||||||
z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (self.tensors[y], True)
|
z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (
|
||||||
|
self.tensors[y], True)
|
||||||
inplace_desc = copy.deepcopy(d)
|
inplace_desc = copy.deepcopy(d)
|
||||||
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
|
inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
|
||||||
z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
|
z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
|
||||||
|
@ -223,20 +258,12 @@ class CompositeGraph:
|
||||||
z_desc['tensor_name'] = z.name
|
z_desc['tensor_name'] = z.name
|
||||||
out_desc['shape'] = z.shape
|
out_desc['shape'] = z.shape
|
||||||
out_desc['data_type'] = z.dtype
|
out_desc['data_type'] = z.dtype
|
||||||
return inplace_desc
|
op_desc.append(inplace_desc)
|
||||||
|
else:
|
||||||
op = self.tensors[d['output_desc'][0]['tensor_name']].op
|
op = self.tensors[d['output_desc'][0]['tensor_name']].op
|
||||||
if op in graph_ops or op in subgraph.recompute_ops:
|
if op in graph_ops or op in subgraph.recompute_ops:
|
||||||
return d
|
op_desc.append(d)
|
||||||
return None
|
desc[key] = op_desc
|
||||||
|
|
||||||
for key in self.desc.keys():
|
|
||||||
if key == 'input_desc':
|
|
||||||
desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
|
|
||||||
elif key == 'output_desc':
|
|
||||||
desc[key] = list(map(dump_output, outputs))
|
|
||||||
elif key == 'op_desc':
|
|
||||||
op_desc = map(dump_op_desc, self.desc[key])
|
|
||||||
desc[key] = [d for d in op_desc if d is not None]
|
|
||||||
elif key == 'op':
|
elif key == 'op':
|
||||||
desc[key] = subgraph.name
|
desc[key] = subgraph.name
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import sys
|
import sys
|
||||||
from functools import reduce as prod_reduce
|
from functools import reduce
|
||||||
from .model import GraphKernelUnsupportedException as GKException
|
from .model import GraphKernelUnsupportedException as GKException
|
||||||
from .model import PrimLib, DataFormat as DF
|
from .model import PrimLib, DataFormat as DF
|
||||||
|
|
||||||
|
@ -101,24 +101,22 @@ class OpInfer:
|
||||||
|
|
||||||
class _Elemwise(OpInfer):
|
class _Elemwise(OpInfer):
|
||||||
"""Common infer for elementwise operators"""
|
"""Common infer for elementwise operators"""
|
||||||
@staticmethod
|
|
||||||
def broadcast_shape(shapes):
|
def _broadcast_shape(self, shapes):
|
||||||
"""deduce broadcast shape using same rules as numpy"""
|
"""deduce broadcast shape using same rules as numpy"""
|
||||||
dim_size = max([len(shape) for shape in shapes])
|
dim_size = max([len(shape) for shape in shapes])
|
||||||
align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
|
align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
|
||||||
out_shape = [1] * dim_size
|
out_shape = [1] * dim_size
|
||||||
for i in range(dim_size):
|
for i in range(dim_size):
|
||||||
for align_shape in align_shapes:
|
for align_shape in align_shapes:
|
||||||
if align_shape[i] == 1:
|
if align_shape[i] > 1:
|
||||||
continue
|
|
||||||
if out_shape[i] == 1:
|
if out_shape[i] == 1:
|
||||||
out_shape[i] = align_shape[i]
|
out_shape[i] = align_shape[i]
|
||||||
elif out_shape[i] != align_shape[i]:
|
if out_shape[i] != align_shape[i]:
|
||||||
raise GKException("shape broadcast failed!")
|
raise GKException("shape broadcast failed!")
|
||||||
return out_shape
|
return out_shape
|
||||||
|
|
||||||
@staticmethod
|
def _to_nz(self, default_shape):
|
||||||
def defaultformat_to_nz(default_shape):
|
|
||||||
"""default format shape to fractal_Nz format shape"""
|
"""default format shape to fractal_Nz format shape"""
|
||||||
if len(default_shape) not in (1, 2):
|
if len(default_shape) not in (1, 2):
|
||||||
raise GKException("shape is too long!")
|
raise GKException("shape is too long!")
|
||||||
|
@ -144,17 +142,17 @@ class _Elemwise(OpInfer):
|
||||||
"""returns the output shape with broadcast"""
|
"""returns the output shape with broadcast"""
|
||||||
|
|
||||||
# in case all inputs are default format/NHWC/NCHW
|
# in case all inputs are default format/NHWC/NCHW
|
||||||
is_default = [op_input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW) for op_input in self.inputs]
|
is_default = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW) for input in self.inputs]
|
||||||
if all(is_default):
|
if all(is_default):
|
||||||
return self.broadcast_shape([op_input.shape for op_input in self.inputs])
|
return self._broadcast_shape([input.shape for input in self.inputs])
|
||||||
|
|
||||||
# in case formats are fractal_nz, default_fromat/NHWC/HCHW(optional)
|
# in case formats are fractal_nz, default_fromat/NHWC/HCHW(optional)
|
||||||
is_default_frac_nz = [op_input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ)
|
is_default_frac_nz = [input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ)
|
||||||
for op_input in self.inputs]
|
for input in self.inputs]
|
||||||
if all(is_default_frac_nz):
|
if all(is_default_frac_nz):
|
||||||
nz_shapes = [self.defaultformat_to_nz(op_input.shape) if op_input.data_format != DF.FRAC_NZ
|
nz_shapes = [self._to_nz(input.shape) if input.data_format != DF.FRAC_NZ else input.shape
|
||||||
else op_input.shape for op_input in self.inputs]
|
for input in self.inputs]
|
||||||
return self.broadcast_shape(nz_shapes)
|
return self._broadcast_shape(nz_shapes)
|
||||||
|
|
||||||
raise GKException("Only support default and fractal_nz")
|
raise GKException("Only support default and fractal_nz")
|
||||||
|
|
||||||
|
@ -216,11 +214,9 @@ class _Reshape(OpInfer):
|
||||||
|
|
||||||
|
|
||||||
class Reshape(_Reshape):
|
class Reshape(_Reshape):
|
||||||
"""Reshape op infer"""
|
|
||||||
|
|
||||||
def _check_shape(self):
|
def _check_shape(self):
|
||||||
size_before_reshape = prod_reduce(lambda x, y: x * y, self.inputs[0].shape)
|
size_before_reshape = reduce(lambda x, y: x * y, self.inputs[0].shape)
|
||||||
size_after_reshape = prod_reduce(lambda x, y: x * y, self.attrs["shape"])
|
size_after_reshape = reduce(lambda x, y: x * y, self.attrs["shape"])
|
||||||
if size_before_reshape != size_after_reshape:
|
if size_before_reshape != size_after_reshape:
|
||||||
raise GKException("The shape product before and after reshaping should be equal")
|
raise GKException("The shape product before and after reshaping should be equal")
|
||||||
|
|
||||||
|
@ -229,15 +225,11 @@ class Reshape(_Reshape):
|
||||||
|
|
||||||
|
|
||||||
class Cast(_Elemwise):
|
class Cast(_Elemwise):
|
||||||
"""Cast op infer"""
|
|
||||||
|
|
||||||
def _infer_type(self):
|
def _infer_type(self):
|
||||||
return self.attrs["dst_type"]
|
return self.attrs["dst_type"]
|
||||||
|
|
||||||
|
|
||||||
class InplaceAssign(_Elemwise):
|
class InplaceAssign(_Elemwise):
|
||||||
"""InplaceAssign op infer"""
|
|
||||||
|
|
||||||
def _infer_shape(self):
|
def _infer_shape(self):
|
||||||
return self.inputs[2].shape
|
return self.inputs[2].shape
|
||||||
|
|
||||||
|
@ -249,8 +241,6 @@ class InplaceAssign(_Elemwise):
|
||||||
|
|
||||||
|
|
||||||
class BroadcastTo(OpInfer):
|
class BroadcastTo(OpInfer):
|
||||||
"""BroadcastTo op infer"""
|
|
||||||
|
|
||||||
def _infer_shape(self):
|
def _infer_shape(self):
|
||||||
return self.attrs["shape"]
|
return self.attrs["shape"]
|
||||||
|
|
||||||
|
@ -266,8 +256,6 @@ class _CompareOp(_Elemwise):
|
||||||
|
|
||||||
|
|
||||||
class CImag(OpInfer):
|
class CImag(OpInfer):
|
||||||
"""CImag op infer"""
|
|
||||||
|
|
||||||
def _check_type(self):
|
def _check_type(self):
|
||||||
if self.inputs[0].dtype != "complex64":
|
if self.inputs[0].dtype != "complex64":
|
||||||
raise GKException(
|
raise GKException(
|
||||||
|
@ -278,8 +266,6 @@ class CImag(OpInfer):
|
||||||
|
|
||||||
|
|
||||||
class CReal(OpInfer):
|
class CReal(OpInfer):
|
||||||
"""CReal op infer"""
|
|
||||||
|
|
||||||
def _check_type(self):
|
def _check_type(self):
|
||||||
if self.inputs[0].dtype != "complex64":
|
if self.inputs[0].dtype != "complex64":
|
||||||
raise GKException(
|
raise GKException(
|
||||||
|
@ -290,8 +276,6 @@ class CReal(OpInfer):
|
||||||
|
|
||||||
|
|
||||||
class Complex(OpInfer):
|
class Complex(OpInfer):
|
||||||
"""Complex op infer"""
|
|
||||||
|
|
||||||
def _check_type(self):
|
def _check_type(self):
|
||||||
if self.inputs[0].dtype != "float32":
|
if self.inputs[0].dtype != "float32":
|
||||||
raise GKException(
|
raise GKException(
|
||||||
|
@ -304,28 +288,26 @@ class Complex(OpInfer):
|
||||||
|
|
||||||
|
|
||||||
class Less(_CompareOp):
|
class Less(_CompareOp):
|
||||||
"""Less op infer"""
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LessEqual(_CompareOp):
|
class LessEqual(_CompareOp):
|
||||||
"""LessEqual op infer"""
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Equal(_CompareOp):
|
class Equal(_CompareOp):
|
||||||
"""Equal op infer"""
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Greater(_CompareOp):
|
class Greater(_CompareOp):
|
||||||
"""Greater op infer"""
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GreaterEqual(_CompareOp):
|
class GreaterEqual(_CompareOp):
|
||||||
"""GreaterEqual op infer"""
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Select(_Elemwise):
|
class Select(_Elemwise):
|
||||||
"""Select op infer"""
|
|
||||||
|
|
||||||
def _check_type(self):
|
def _check_type(self):
|
||||||
if self.inputs[0].dtype != "bool":
|
if self.inputs[0].dtype != "bool":
|
||||||
raise GKException("Select's input[0] should be a bool condition but got {}".format(self.inputs[0].dtype))
|
raise GKException("Select's input[0] should be a bool condition but got {}".format(self.inputs[0].dtype))
|
||||||
|
@ -337,7 +319,6 @@ class Select(_Elemwise):
|
||||||
|
|
||||||
|
|
||||||
def check_format_any(formats, checked_format):
|
def check_format_any(formats, checked_format):
|
||||||
"""Check whether input format in formats list"""
|
|
||||||
if not isinstance(formats, (list, tuple)):
|
if not isinstance(formats, (list, tuple)):
|
||||||
raise GKException("formats {} should be list or tuple, but got {}.".format(formats, type(formats)))
|
raise GKException("formats {} should be list or tuple, but got {}.".format(formats, type(formats)))
|
||||||
if checked_format not in formats:
|
if checked_format not in formats:
|
||||||
|
@ -345,13 +326,11 @@ def check_format_any(formats, checked_format):
|
||||||
|
|
||||||
|
|
||||||
def check_nd(data, nd):
|
def check_nd(data, nd):
|
||||||
"""Check whether data are nd format"""
|
|
||||||
if not isinstance(data, (list, tuple)) or len(data) != nd:
|
if not isinstance(data, (list, tuple)) or len(data) != nd:
|
||||||
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
|
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
|
||||||
|
|
||||||
|
|
||||||
def conv_had_pad(pad_list, pad_mode):
|
def conv_had_pad(pad_list, pad_mode):
|
||||||
"""Check whether conv need to add pad"""
|
|
||||||
if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
|
if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
|
||||||
raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
|
raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
|
||||||
if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
|
if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
|
||||||
|
|
|
@ -57,11 +57,11 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
|
||||||
return
|
return
|
||||||
utils.create_dir(utils.GRAPH_KERNEL_DUMP_PATH)
|
utils.create_dir(utils.GRAPH_KERNEL_DUMP_PATH)
|
||||||
filename = os.path.join(utils.GRAPH_KERNEL_DUMP_PATH, "graph_kernel_split_mode.txt")
|
filename = os.path.join(utils.GRAPH_KERNEL_DUMP_PATH, "graph_kernel_split_mode.txt")
|
||||||
with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT), "a+") as f:
|
with open(filename, "a+") as f:
|
||||||
f.write("********** main graph: {} **********\n".format(graph_desc.name))
|
f.write("********** main graph: {} **********\n".format(graph_desc.name))
|
||||||
f.write("input json:\n{}\n".format(graph_json))
|
f.write("input json:\n{}\n".format(graph_json))
|
||||||
f.write("graph desc:\n{}\n".format(str(graph_desc)))
|
f.write("graph desc:\n{}\n".format(str(graph_desc)))
|
||||||
if len(subgraphs) > 1 or subgraphs[0].stitch_info.has_stitch_op():
|
if len(subgraphs) > 1:
|
||||||
for i, g in enumerate(subgraphs):
|
for i, g in enumerate(subgraphs):
|
||||||
f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
|
f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
|
||||||
f.write("{}\n".format(str(g)))
|
f.write("{}\n".format(str(g)))
|
||||||
|
|
|
@ -26,5 +26,3 @@ def create_dir(pathname):
|
||||||
os.mkdir(pathname)
|
os.mkdir(pathname)
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
|
|
|
@ -50,6 +50,11 @@ def _compile_akg_task_gpu(json_strs, attrs):
|
||||||
if not res:
|
if not res:
|
||||||
raise ValueError("Compile error, args: {}! build attrs: {}".format(json_str, attrs))
|
raise ValueError("Compile error, args: {}! build attrs: {}".format(json_str, attrs))
|
||||||
|
|
||||||
|
pid_path = os.path.realpath("./cuda_meta_" + str(os.getpid()))
|
||||||
|
if os.path.exists(pid_path):
|
||||||
|
copy_json(pid_path, os.path.realpath("./cuda_meta_" + str(os.getppid())))
|
||||||
|
shutil.rmtree(pid_path)
|
||||||
|
|
||||||
|
|
||||||
def _compile_akg_task_ascend(json_strs, attrs):
|
def _compile_akg_task_ascend(json_strs, attrs):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -32,7 +32,7 @@ from te_fusion.parallel_compilation import init_multi_process_env, start_ga_mult
|
||||||
get_finished_compilation_task
|
get_finished_compilation_task
|
||||||
|
|
||||||
from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \
|
from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \
|
||||||
BuildType, adjust_custom_op_info, pack_op_args, get_module_name
|
BuildType, adjust_custom_op_info, pack_op_args
|
||||||
from .tbe_job import TbeJob, JobStatus
|
from .tbe_job import TbeJob, JobStatus
|
||||||
|
|
||||||
PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"]
|
PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"]
|
||||||
|
@ -242,7 +242,7 @@ def check_support(job: TbeJob):
|
||||||
op_func_name = compute_op_info["func_name"]
|
op_func_name = compute_op_info["func_name"]
|
||||||
if op_func_name in ("resize_nearest_neighbor_v2_grad_d", "resize_bilinear_v2_grad"):
|
if op_func_name in ("resize_nearest_neighbor_v2_grad_d", "resize_bilinear_v2_grad"):
|
||||||
attrs.pop(-2)
|
attrs.pop(-2)
|
||||||
op_module_name = get_module_name(compute_op_info)
|
op_module_name = compute_op_info["module_name"]
|
||||||
py_module_path = compute_op_info["py_module_path"]
|
py_module_path = compute_op_info["py_module_path"]
|
||||||
_normalize_module_name(op_module_name, py_module_path)
|
_normalize_module_name(op_module_name, py_module_path)
|
||||||
func_name = "check_supported"
|
func_name = "check_supported"
|
||||||
|
@ -281,7 +281,7 @@ def select_op_format(job: TbeJob):
|
||||||
compute_op_info = compute_op_info_list[0]
|
compute_op_info = compute_op_info_list[0]
|
||||||
adjust_custom_op_info(compute_op_info)
|
adjust_custom_op_info(compute_op_info)
|
||||||
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
||||||
op_module_name = get_module_name(compute_op_info)
|
op_module_name = compute_op_info["module_name"]
|
||||||
py_module_path = compute_op_info["py_module_path"]
|
py_module_path = compute_op_info["py_module_path"]
|
||||||
_normalize_module_name(op_module_name, py_module_path)
|
_normalize_module_name(op_module_name, py_module_path)
|
||||||
op_func_name = "op_select_format"
|
op_func_name = "op_select_format"
|
||||||
|
@ -317,7 +317,7 @@ def _pre_build_compute_op_info(compute_op, job):
|
||||||
if l1_size != -1:
|
if l1_size != -1:
|
||||||
set_L1_info("op_L1_space", -1)
|
set_L1_info("op_L1_space", -1)
|
||||||
inputs, outputs, attrs = assemble_op_args(compute_op)
|
inputs, outputs, attrs = assemble_op_args(compute_op)
|
||||||
op_module_name = get_module_name(compute_op)
|
op_module_name = compute_op["module_name"]
|
||||||
py_module_path = compute_op["py_module_path"]
|
py_module_path = compute_op["py_module_path"]
|
||||||
op_func_name = compute_op["func_name"]
|
op_func_name = compute_op["func_name"]
|
||||||
op_type = compute_op["type"]
|
op_type = compute_op["type"]
|
||||||
|
@ -340,8 +340,8 @@ def _pre_build_compute_op_info(compute_op, job):
|
||||||
job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
|
job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
|
||||||
options = get_options_info(job.content)
|
options = get_options_info(job.content)
|
||||||
dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name, unknown_shape,
|
dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name, unknown_shape,
|
||||||
(inputs, outputs, attrs, options), int64_mode, dynamic_compile_static, unknown_shape,
|
(inputs, outputs, attrs, options), int64_mode, dynamic_compile_static, job.rl_tune_switch,
|
||||||
job.rl_tune_switch, job.rl_tune_list, job.pass_list, job.op_tune_switch, job.op_tune_list)
|
job.rl_tune_list, job.pass_list, job.op_tune_switch, job.op_tune_list)
|
||||||
|
|
||||||
|
|
||||||
def get_prebuild_output(op_name):
|
def get_prebuild_output(op_name):
|
||||||
|
@ -391,7 +391,7 @@ def build_single_pre_op(job: TbeJob):
|
||||||
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
||||||
op_type = compute_op_info["type"]
|
op_type = compute_op_info["type"]
|
||||||
l1_size = job.content["l1_size"]
|
l1_size = job.content["l1_size"]
|
||||||
op_module_name = get_module_name(compute_op_info)
|
op_module_name = compute_op_info["module_name"]
|
||||||
op_kernel_name = compute_op_info["op_name"]
|
op_kernel_name = compute_op_info["op_name"]
|
||||||
py_module_path = compute_op_info["py_module_path"]
|
py_module_path = compute_op_info["py_module_path"]
|
||||||
op_func_name = compute_op_info["func_name"]
|
op_func_name = compute_op_info["func_name"]
|
||||||
|
@ -404,9 +404,9 @@ def build_single_pre_op(job: TbeJob):
|
||||||
fuzz_build_info = get_fuzz_build_info(job.content)
|
fuzz_build_info = get_fuzz_build_info(job.content)
|
||||||
dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name,
|
dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_type, op_func_name,
|
||||||
op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode,
|
op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode,
|
||||||
None, None, dynamic_compile_static, unknown_shape, op_pattern,
|
None, None, dynamic_compile_static, op_pattern, json.dumps(fuzz_build_info),
|
||||||
json.dumps(fuzz_build_info), job.rl_tune_switch, job.rl_tune_list, job.pass_list,
|
job.rl_tune_switch, job.rl_tune_list, job.pass_list, job.op_tune_switch,
|
||||||
job.op_tune_switch, job.op_tune_list)
|
job.op_tune_list)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -487,7 +487,7 @@ def rl_tune_single_op(job: TbeJob):
|
||||||
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
inputs, outputs, attrs = assemble_op_args(compute_op_info)
|
||||||
op_type = compute_op_info["type"]
|
op_type = compute_op_info["type"]
|
||||||
l1_size = job.content["l1_size"]
|
l1_size = job.content["l1_size"]
|
||||||
op_module_name = get_module_name(compute_op_info)
|
op_module_name = compute_op_info["module_name"]
|
||||||
op_kernel_name = compute_op_info["op_name"]
|
op_kernel_name = compute_op_info["op_name"]
|
||||||
full_name = compute_op_info["name"]
|
full_name = compute_op_info["name"]
|
||||||
py_module_path = compute_op_info["py_module_path"]
|
py_module_path = compute_op_info["py_module_path"]
|
||||||
|
@ -503,7 +503,7 @@ def rl_tune_single_op(job: TbeJob):
|
||||||
device_id = job.content["SocInfo"]["deviceId"]
|
device_id = job.content["SocInfo"]["deviceId"]
|
||||||
try:
|
try:
|
||||||
build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape,
|
build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape,
|
||||||
(inputs, outputs, attrs), int64_mode, dynamic_compile_static, unknown_shape, op_pattern,
|
(inputs, outputs, attrs), int64_mode, dynamic_compile_static, op_pattern,
|
||||||
auto_tiling_mode, device_id, json.dumps(fuzz_build_info))
|
auto_tiling_mode, device_id, json.dumps(fuzz_build_info))
|
||||||
# pylint: disable=broad-except
|
# pylint: disable=broad-except
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -547,7 +547,7 @@ def rl_tune_fusion_op(job: TbeJob):
|
||||||
compute_op_list = get_compute_op_list(job.content)
|
compute_op_list = get_compute_op_list(job.content)
|
||||||
op_module_names_str = ""
|
op_module_names_str = ""
|
||||||
for op in compute_op_list:
|
for op in compute_op_list:
|
||||||
op_module_names_str = op_module_names_str + "," + get_module_name(op)
|
op_module_names_str = op_module_names_str + "," + op["module_name"]
|
||||||
op_module_names_str = op_module_names_str[1:]
|
op_module_names_str = op_module_names_str[1:]
|
||||||
from schedule_search.rl_online_tune import dispatch_fusion_tune_task
|
from schedule_search.rl_online_tune import dispatch_fusion_tune_task
|
||||||
res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str,
|
res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str,
|
||||||
|
|
|
@ -179,6 +179,8 @@ def get_options_info(job_content):
|
||||||
options["op_debug_level"] = job_content["SocInfo"]["op_debug_level"]
|
options["op_debug_level"] = job_content["SocInfo"]["op_debug_level"]
|
||||||
options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"]
|
options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"]
|
||||||
options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"]
|
options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"]
|
||||||
|
options["op_compiler_cache_dir"] = job_content["SocInfo"]["op_compiler_cache_dir"]
|
||||||
|
options["op_compiler_cache_mode"] = job_content["SocInfo"]["op_compiler_cache_mode"]
|
||||||
options["mdl_bank_path"] = job_content["SocInfo"]["op_debug_level"]
|
options["mdl_bank_path"] = job_content["SocInfo"]["op_debug_level"]
|
||||||
options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"]
|
options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"]
|
||||||
options["deviceId"] = job_content["SocInfo"]["deviceId"]
|
options["deviceId"] = job_content["SocInfo"]["deviceId"]
|
||||||
|
@ -218,19 +220,6 @@ def get_func_names(job_content):
|
||||||
return func_names
|
return func_names
|
||||||
|
|
||||||
|
|
||||||
def get_module_name(compute_op_info):
|
|
||||||
"""
|
|
||||||
get compute_op_info
|
|
||||||
:param compute_op_info:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
unknown_shape = compute_op_info["unknown_shape"]
|
|
||||||
op_module_name = compute_op_info["module_name"]
|
|
||||||
if unknown_shape:
|
|
||||||
op_module_name = op_module_name.split(".")[0] + ".dynamic." + op_module_name.split(".")[-1]
|
|
||||||
return op_module_name
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_custom_op_info(compute_op_info):
|
def adjust_custom_op_info(compute_op_info):
|
||||||
"""
|
"""
|
||||||
adjust custom op info
|
adjust custom op info
|
||||||
|
|
|
@ -71,13 +71,12 @@ def _get_message(msg, args):
|
||||||
class TbeJob:
|
class TbeJob:
|
||||||
""" Tbe compilation job """
|
""" Tbe compilation job """
|
||||||
|
|
||||||
def __init__(self, source_id, job_id, job_type, content, fusion_op_name, json_str, sys_info):
|
def __init__(self, source_id, job_id, job_type, content, json_str, sys_info):
|
||||||
self.source_id = source_id
|
self.source_id = source_id
|
||||||
self.id = job_id
|
self.id = job_id
|
||||||
self.type = JobType(job_type)
|
self.type = JobType(job_type)
|
||||||
self.status = JobStatus.JOB_INITIAL
|
self.status = JobStatus.JOB_INITIAL
|
||||||
self.content = content
|
self.content = content
|
||||||
self.fusion_op_name = fusion_op_name
|
|
||||||
self.result = ""
|
self.result = ""
|
||||||
self.process_info = []
|
self.process_info = []
|
||||||
self.json_string = json_str
|
self.json_string = json_str
|
||||||
|
@ -150,8 +149,8 @@ class TbeJob:
|
||||||
result["source_id"] = self.source_id
|
result["source_id"] = self.source_id
|
||||||
result["job_id"] = self.id
|
result["job_id"] = self.id
|
||||||
result["job_type"] = self.type.value
|
result["job_type"] = self.type.value
|
||||||
result["fusion_op_name"] = self.fusion_op_name
|
|
||||||
result["result"] = self.result
|
result["result"] = self.result
|
||||||
|
self.debug("Resp result:{}".format(json.dumps(result)))
|
||||||
process_info = []
|
process_info = []
|
||||||
for info in self.process_info:
|
for info in self.process_info:
|
||||||
msg = {"index": info.index, "level": info.level.value, "message": info.info}
|
msg = {"index": info.index, "level": info.level.value, "message": info.info}
|
||||||
|
|
|
@ -102,9 +102,8 @@ class TbeJobManager:
|
||||||
source_id = job_json["source_id"]
|
source_id = job_json["source_id"]
|
||||||
job_type = job_json["job_type"]
|
job_type = job_json["job_type"]
|
||||||
sys_info = self._get_job_sys_info()
|
sys_info = self._get_job_sys_info()
|
||||||
fusion_op_name = "NA" if "fusion_op_name" not in job_json["job_content"] else job_json["job_content"][
|
job = TbeJob(source_id, job_id, job_type, job_json["job_content"], job_str, sys_info)
|
||||||
"fusion_op_name"]
|
job.debug("Req job string: {}".format(job_str))
|
||||||
job = TbeJob(source_id, job_id, job_type, job_json["job_content"], fusion_op_name, job_str, sys_info)
|
|
||||||
post_job(self._all_jobs, job)
|
post_job(self._all_jobs, job)
|
||||||
if not self.tbe_initialize and job.type != JobType.INITIALIZE_JOB:
|
if not self.tbe_initialize and job.type != JobType.INITIALIZE_JOB:
|
||||||
job.error(
|
job.error(
|
||||||
|
@ -116,7 +115,6 @@ class TbeJobManager:
|
||||||
return res
|
return res
|
||||||
# pylint: disable=broad-except
|
# pylint: disable=broad-except
|
||||||
except Exception:
|
except Exception:
|
||||||
# pylint: disable=no-value-for-parameter
|
|
||||||
sys_info = self._get_job_sys_info()
|
sys_info = self._get_job_sys_info()
|
||||||
job = TbeJob(-1, -1, "", None, job_str, sys_info) if job is None else job
|
job = TbeJob(-1, -1, "", None, job_str, sys_info) if job is None else job
|
||||||
job.status = JobStatus.JOB_FAILED
|
job.status = JobStatus.JOB_FAILED
|
||||||
|
@ -263,6 +261,9 @@ class TbeJobManager:
|
||||||
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
|
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
|
||||||
target_job = get_job(self._running_jobs, target_source_id, target_job_id)
|
target_job = get_job(self._running_jobs, target_source_id, target_job_id)
|
||||||
if target_job:
|
if target_job:
|
||||||
|
query_job.debug("Found job in Running jobs, source_id:{}, job_id:{}".format(target_source_id,
|
||||||
|
target_job_id))
|
||||||
|
target_job.debug("Be Queried")
|
||||||
query_job.result = target_job.get_result()
|
query_job.result = target_job.get_result()
|
||||||
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
|
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
|
||||||
target_job = get_job(self._all_jobs, target_source_id, target_job_id)
|
target_job = get_job(self._all_jobs, target_source_id, target_job_id)
|
||||||
|
|
|
@ -159,17 +159,12 @@ def resolve_symbol(namespace, symbol):
|
||||||
if getattr(resolve_, "__hash__") is None:
|
if getattr(resolve_, "__hash__") is None:
|
||||||
return resolve_
|
return resolve_
|
||||||
|
|
||||||
# Raise NotImplementedError when parsing the numpy methods, but not the numpy constant.
|
|
||||||
if namespace.name == "numpy" and isinstance(resolve_, (types.FunctionType, types.MethodType, types.ModuleType)):
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"MindSpore does not support to use the numpy methods in the function construct with the graph mode.")
|
|
||||||
|
|
||||||
# If need trope the obj
|
# If need trope the obj
|
||||||
if resolve_ in convert_object_map:
|
if resolve_ in convert_object_map:
|
||||||
resolve_ = convert_object_map.get(resolve_)
|
resolve_ = convert_object_map.get(resolve_)
|
||||||
logger.debug("convert resolve = %r", resolve_)
|
logger.debug("convert resolve = %r", resolve_)
|
||||||
if resolve_ == NO_IMPLEMENT:
|
if resolve_ == NO_IMPLEMENT:
|
||||||
raise NotImplementedError(f"Not support for `{symbol}`.")
|
raise NotImplementedError(f"Not support for `{symbol}`")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, NotImplementedError):
|
if isinstance(e, NotImplementedError):
|
||||||
raise e
|
raise e
|
||||||
|
|
|
@ -1312,8 +1312,7 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab
|
||||||
>>> print(input_x.sum(axis=1))
|
>>> print(input_x.sum(axis=1))
|
||||||
[10. 35.]
|
[10. 35.]
|
||||||
"""
|
"""
|
||||||
input_x = x.astype(mstype.int32) if x.dtype == mstype.bool_ else x
|
dtype = x.dtype if dtype is None else dtype
|
||||||
dtype = input_x.dtype if dtype is None else dtype
|
|
||||||
if not isinstance(keepdims, int):
|
if not isinstance(keepdims, int):
|
||||||
const_utils.raise_type_error("integer argument expected")
|
const_utils.raise_type_error("integer argument expected")
|
||||||
if initial is not None and not isinstance(initial, (int, float, bool)):
|
if initial is not None and not isinstance(initial, (int, float, bool)):
|
||||||
|
@ -1323,14 +1322,14 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab
|
||||||
else:
|
else:
|
||||||
axis = check_and_canonicalize_axes(axis, x.ndim)
|
axis = check_and_canonicalize_axes(axis, x.ndim)
|
||||||
|
|
||||||
if not check_type_support(input_x.dtype, 'GPU', (mstype.float64, mstype.float32, mstype.float16)):
|
if x.dtype == mstype.bool_:
|
||||||
input_x = input_x.astype(mstype.float32)
|
x = x.astype("int32")
|
||||||
if 0 in x.shape:
|
if 0 in x.shape:
|
||||||
x = const_utils.make_tensor([0], x.dtype)
|
x = const_utils.make_tensor([0], x.dtype)
|
||||||
if keepdims:
|
if keepdims:
|
||||||
res = _reduce_sum_keepdims(input_x, axis)
|
res = _reduce_sum_keepdims(x, axis)
|
||||||
else:
|
else:
|
||||||
res = _reduce_sum_default(input_x, axis)
|
res = _reduce_sum_default(x, axis)
|
||||||
if initial is not None:
|
if initial is not None:
|
||||||
res += initial
|
res += initial
|
||||||
return res.astype(dtype)
|
return res.astype(dtype)
|
||||||
|
@ -1649,7 +1648,6 @@ get_log2_size = constexpr(validator.get_log2_size)
|
||||||
check_axis_type = constexpr(validator.check_axis_type)
|
check_axis_type = constexpr(validator.check_axis_type)
|
||||||
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
|
check_and_canonicalize_axes = constexpr(validator.check_and_canonicalize_axes)
|
||||||
empty_compile = constexpr(validator.empty_compile)
|
empty_compile = constexpr(validator.empty_compile)
|
||||||
check_type_support = constexpr(validator.check_type_support)
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_bool(x):
|
def tensor_bool(x):
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
import os
|
import os
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
|
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
|
||||||
|
from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single
|
||||||
|
|
||||||
|
|
||||||
class Messager:
|
class Messager:
|
||||||
|
@ -145,7 +146,9 @@ class AkgBuilder():
|
||||||
|
|
||||||
def handle(self, messager, arg):
|
def handle(self, messager, arg):
|
||||||
"""Handle message about akg"""
|
"""Handle message about akg"""
|
||||||
if arg == 'AKG/START':
|
if arg == 'AKG/PID':
|
||||||
|
messager.send_res(os.getpid())
|
||||||
|
elif arg == 'AKG/START':
|
||||||
messager.send_ack()
|
messager.send_ack()
|
||||||
process_num_str = messager.get_message()
|
process_num_str = messager.get_message()
|
||||||
messager.send_ack()
|
messager.send_ack()
|
||||||
|
@ -170,8 +173,17 @@ class AkgBuilder():
|
||||||
else:
|
else:
|
||||||
messager.send_ack(False)
|
messager.send_ack(False)
|
||||||
break
|
break
|
||||||
else:
|
elif arg == 'AKG/COMPILE':
|
||||||
raise RuntimeError("Unknown message type: %s" % arg)
|
messager.send_ack()
|
||||||
|
json = messager.get_message()
|
||||||
|
try:
|
||||||
|
akg_compile_single(json, self.attrs)
|
||||||
|
except ValueError:
|
||||||
|
messager.send_ack(False)
|
||||||
|
messager.exit()
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
messager.send_ack()
|
||||||
|
|
||||||
|
|
||||||
def get_logger():
|
def get_logger():
|
||||||
|
|
|
@ -297,14 +297,20 @@ if(MODE_ASCEND_ALL)
|
||||||
${ASCEND_DRIVER_BACK_PATH})
|
${ASCEND_DRIVER_BACK_PATH})
|
||||||
find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}
|
find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}
|
||||||
${ASCEND_DRIVER_BACK_PATH})
|
${ASCEND_DRIVER_BACK_PATH})
|
||||||
find_library(PROFILING msprofiler ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
find_library(PROFILING msprofiler_fwkacl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||||
find_library(ACL ascendcl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
find_library(ACL ascendcl ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||||
find_library(PLATFORM platform ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
find_library(PLATFORM platform ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||||
find_library(OPTILING optiling ${ASCEND_OPP_PATH} ${ASCEND_TOOLKIT_OPP_PATH})
|
find_library(OPTILING optiling ${ASCEND_OPP_PATH} ${ASCEND_TOOLKIT_OPP_PATH})
|
||||||
find_library(OPT_FEATURE opt_feature ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
find_library(OPT_FEATURE opt_feature ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
|
||||||
|
|
||||||
|
add_library(ms_profile SHARED
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/runtime/device/ascend/profiling/profiling_callback_register.cc)
|
||||||
|
set_target_properties(ms_profile PROPERTIES LINKER_LANGUAGE CXX)
|
||||||
|
target_link_options(ms_profile PRIVATE -Wl,-init,common_log_init)
|
||||||
|
target_link_libraries(ms_profile -Wl,--start-group -Wl,--whole-archive ${PROFILING} -Wl,--no-whole-archive
|
||||||
|
mindspore::protobuf -Wl,--end-group)
|
||||||
target_link_libraries(mindspore ${RUNTIME_LIB} ${TSDCLIENT} ${DATATRANSFER} ${ERROR_MANAGER} -Wl,--no-as-needed
|
target_link_libraries(mindspore ${RUNTIME_LIB} ${TSDCLIENT} ${DATATRANSFER} ${ERROR_MANAGER} -Wl,--no-as-needed
|
||||||
${OPTILING} ${PLATFORM} ${ACL} ${OPT_FEATURE} ${PROFILING})
|
${OPTILING} ${PLATFORM} ${ACL} ${OPT_FEATURE})
|
||||||
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)
|
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)
|
||||||
elseif(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
elseif(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece
|
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece
|
||||||
|
@ -319,7 +325,7 @@ endif()
|
||||||
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
|
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
|
||||||
set_property(SOURCE "pipeline/jit/init.cc" PROPERTY
|
set_property(SOURCE "pipeline/jit/init.cc" PROPERTY
|
||||||
COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE)
|
COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE)
|
||||||
pybind11_add_module(_c_expression NO_EXTRAS "pipeline/jit/init.cc" NO_EXTRAS)
|
pybind11_add_module(_c_expression "pipeline/jit/init.cc")
|
||||||
|
|
||||||
MESSAGE(STATUS "operation system is ${CMAKE_SYSTEM}")
|
MESSAGE(STATUS "operation system is ${CMAKE_SYSTEM}")
|
||||||
if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
|
@ -369,6 +375,9 @@ else()
|
||||||
proto_input -Wl,--no-whole-archive)
|
proto_input -Wl,--no-whole-archive)
|
||||||
target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
|
target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module)
|
||||||
target_link_libraries(_c_expression PRIVATE mindspore_gvar)
|
target_link_libraries(_c_expression PRIVATE mindspore_gvar)
|
||||||
|
if(MODE_ASCEND_ALL)
|
||||||
|
target_link_libraries(_c_expression PRIVATE -Wl,--no-as-needed ms_profile)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_GLOG)
|
if(USE_GLOG)
|
||||||
|
|
|
@ -35,8 +35,6 @@ if(ENABLE_CPU)
|
||||||
"cpu/fl/*.cc"
|
"cpu/fl/*.cc"
|
||||||
"cpu/ps/*.cc"
|
"cpu/ps/*.cc"
|
||||||
"cpu/quantum/*.cc"
|
"cpu/quantum/*.cc"
|
||||||
"cpu/pyfunc/*.cc"
|
|
||||||
"cpu/rl/*.cc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if(NOT ENABLE_MPI)
|
if(NOT ENABLE_MPI)
|
||||||
|
@ -85,7 +83,6 @@ if(NOT ENABLE_CPU OR WIN32)
|
||||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_model_kernel.cc")
|
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_model_kernel.cc")
|
||||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc")
|
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc")
|
||||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc")
|
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc")
|
||||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/push_metrics_kernel.cc")
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(ENABLE_GPU)
|
if(ENABLE_GPU)
|
||||||
|
|
|
@ -16,11 +16,6 @@
|
||||||
|
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
|
||||||
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <errno.h>
|
|
||||||
#include <fcntl.h>
|
|
||||||
#include <unistd.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -28,7 +23,6 @@
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <iostream>
|
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
#include "ir/dtype.h"
|
#include "ir/dtype.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
|
@ -40,346 +34,17 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
|
||||||
#define INIT_SET_FROM_2D_ARRAY(set_var, list_idx) \
|
|
||||||
std::set<size_t> set_var(kernel_lists_[list_idx], kernel_lists_[list_idx] + kernel_lists_[list_idx][kMaxKernelNum_]);
|
|
||||||
|
|
||||||
#define LIST_BEGIN(list_idx) kernel_lists_[list_idx]
|
|
||||||
#define LIST_END(list_idx) (kernel_lists_[list_idx] + kernel_lists_[list_idx][kMaxKernelNum_])
|
|
||||||
#define RESET_LIST_SIZE(list_idx, val) kernel_lists_[list_idx][kMaxKernelNum_] = val
|
|
||||||
|
|
||||||
#define INCREASE_LIST_SIZE(list_idx, val) kernel_lists_[list_idx][kMaxKernelNum_] += val
|
|
||||||
|
|
||||||
constexpr int32_t PROCESS_NUM = 16;
|
constexpr int32_t PROCESS_NUM = 16;
|
||||||
constexpr int32_t TIME_OUT = 300;
|
constexpr int32_t TIME_OUT = 300;
|
||||||
|
|
||||||
bool AkgKernelPool::LockMng::TryLock() {
|
std::vector<std::string> AkgKernelBuilder::GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args) {
|
||||||
// Try to lock 100 times. Return errno if lock unsuccessfully
|
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
|
||||||
uint32_t trial = 100;
|
std::vector<std::string> jsons;
|
||||||
|
|
||||||
int32_t ret = -1;
|
|
||||||
while (trial > 0) {
|
|
||||||
ret = lockf(fd_, F_TLOCK, 0);
|
|
||||||
if (ret == 0 || (errno != EACCES && errno != EAGAIN)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
trial--;
|
|
||||||
usleep(5000);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ret == -1) {
|
|
||||||
MS_LOG(ERROR) << "Failed to acquire the lock, errno:" << strerror(errno) << ".";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void AkgKernelPool::LockMng::Unlock() {
|
|
||||||
auto ret = lockf(fd_, F_ULOCK, 0);
|
|
||||||
if (ret == -1) {
|
|
||||||
MS_LOG(ERROR) << "Failed to release the lock, errno:" << strerror(errno);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string AkgKernelPool::GetCurrentPath() {
|
|
||||||
char cwd[PATH_MAX];
|
|
||||||
char *ret = getcwd(cwd, sizeof(cwd));
|
|
||||||
if (ret == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Get current work directory failed, errno:" << strerror(errno);
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
char abspath[PATH_MAX];
|
|
||||||
char *res = realpath(cwd, abspath);
|
|
||||||
if (res == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Change to realpath failed, errno:" << strerror(errno);
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::string(abspath);
|
|
||||||
}
|
|
||||||
|
|
||||||
void *AkgKernelPool::CreateSharedMem(const std::string &path) {
|
|
||||||
is_creator_ = false;
|
|
||||||
|
|
||||||
auto hash_id = std::hash<std::string>()(path);
|
|
||||||
auto key_id = static_cast<key_t>(hash_id);
|
|
||||||
auto mem_size = sizeof(size_t) * kListNum_ * (kMaxKernelNum_ + 1) + 512;
|
|
||||||
|
|
||||||
{
|
|
||||||
LockMng lock(fd_);
|
|
||||||
if (!lock.locked_) {
|
|
||||||
MS_LOG(ERROR) << "Failed to acquire lock.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if the shared memory exists or not.
|
|
||||||
// remove shared memory if exists and the nattach is 0
|
|
||||||
struct shmid_ds buf;
|
|
||||||
auto id = shmget(key_id, mem_size, 0);
|
|
||||||
if (id != -1) {
|
|
||||||
auto ret = shmctl(id, IPC_STAT, &buf);
|
|
||||||
if (ret == -1) {
|
|
||||||
MS_LOG(ERROR) << "Failed to get the info of shared memory, errno:" << strerror(errno);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (buf.shm_nattch == 0) {
|
|
||||||
ret = shmctl(id, IPC_RMID, nullptr);
|
|
||||||
if (ret < 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "Realse shared_mem failed, errno:" << strerror(errno);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
LockMng lock(fd_);
|
|
||||||
if (!lock.locked_) {
|
|
||||||
MS_LOG(ERROR) << "Failed to acquire lock.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
shm_id_ = shmget(key_id, mem_size, IPC_CREAT | IPC_EXCL | 0600);
|
|
||||||
if (shm_id_ == -1) {
|
|
||||||
if (errno == EEXIST) {
|
|
||||||
shm_id_ = shmget(key_id, mem_size, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shm_id_ == -1) {
|
|
||||||
MS_LOG(ERROR) << "Create shared_mem failed, error no:" << strerror(errno);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
is_creator_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto local_addr = shmat(shm_id_, nullptr, 0);
|
|
||||||
if (local_addr == reinterpret_cast<void *>(-1)) {
|
|
||||||
MS_LOG(ERROR) << "Attach to shared_mem failed, error no:" << strerror(errno);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (is_creator_) {
|
|
||||||
(void)memset(local_addr, 0, mem_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
return local_addr;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t AkgKernelPool::Init(const std::vector<JsonNodePair> &build_args) {
|
|
||||||
auto cp = GetCurrentPath();
|
|
||||||
if (cp.empty()) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
fd_ = open(kKeyName_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
|
|
||||||
if (fd_ == -1) {
|
|
||||||
MS_LOG(ERROR) << "open file <" << kKeyName_ << "> failed, errno:" << strerror(errno);
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto addr = CreateSharedMem(cp);
|
|
||||||
if (addr == nullptr) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
InitKernelLists(addr);
|
|
||||||
|
|
||||||
auto ret = AddKernels(build_args);
|
|
||||||
if (ret != 0) {
|
|
||||||
MS_LOG(ERROR) << "AkgKernelPool AddKernels failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
AkgKernelPool::~AkgKernelPool() {
|
|
||||||
{
|
|
||||||
LockMng lock(fd_);
|
|
||||||
if (!lock.locked_) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failed to acquire lock.";
|
|
||||||
}
|
|
||||||
|
|
||||||
struct shmid_ds buf;
|
|
||||||
auto ret = shmctl(shm_id_, IPC_STAT, &buf);
|
|
||||||
if (ret == -1) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failed to get the info of shared memory, errno:" << strerror(errno);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool need_delete_by_last = false;
|
|
||||||
|
|
||||||
// if the creator exits unexpectedly and fails to delete the shm, the last process will try to delete the shm
|
|
||||||
if (((buf.shm_perm.mode & SHM_DEST) == 0) && (buf.shm_nattch == 1)) {
|
|
||||||
need_delete_by_last = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detach shared memory
|
|
||||||
ret = shmdt(reinterpret_cast<void *>(kernel_lists_[0]));
|
|
||||||
if (ret < 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "Shared_mem detach failed, errno:" << strerror(errno);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Realse shared_memroy
|
|
||||||
if (is_creator_ || need_delete_by_last) {
|
|
||||||
ret = shmctl(shm_id_, IPC_RMID, nullptr);
|
|
||||||
if (ret < 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "Realse shared_mem failed, errno:" << strerror(errno);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close key file
|
|
||||||
if (fd_ != -1) {
|
|
||||||
(void)close(fd_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t AkgKernelPool::AddKernels(const std::vector<JsonNodePair> &build_args) {
|
|
||||||
LockMng lock(fd_);
|
|
||||||
if (!lock.locked_) {
|
|
||||||
MS_LOG(ERROR) << "Failed to acquire lock.";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
INIT_SET_FROM_2D_ARRAY(todo_list, kToDoIdx_);
|
|
||||||
INIT_SET_FROM_2D_ARRAY(doing_list, kDoingIdx_);
|
|
||||||
INIT_SET_FROM_2D_ARRAY(done_list, kDoneIdx_);
|
|
||||||
|
|
||||||
for (const auto &[json_generator, anf_node] : build_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
auto kernel_name = json_generator.kernel_name();
|
|
||||||
|
|
||||||
auto hash_id = std::hash<std::string>()(kernel_name);
|
|
||||||
if (self_kernel_ids_.count(hash_id) != 0) {
|
|
||||||
MS_LOG(ERROR) << "Duplicated hash_id in list.";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
self_kernel_ids_.emplace(hash_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::set<size_t> diff_from_todo;
|
|
||||||
std::set<size_t> diff_from_doing;
|
|
||||||
std::set<size_t> diff_from_done;
|
|
||||||
|
|
||||||
// add the unique kernel only once, so need to check if it exists in todo_list, doing_list, or done_list
|
|
||||||
std::set_difference(self_kernel_ids_.begin(), self_kernel_ids_.end(), todo_list.begin(), todo_list.end(),
|
|
||||||
std::inserter(diff_from_todo, diff_from_todo.begin()));
|
|
||||||
std::set_difference(diff_from_todo.begin(), diff_from_todo.end(), doing_list.begin(), doing_list.end(),
|
|
||||||
std::inserter(diff_from_doing, diff_from_doing.begin()));
|
|
||||||
std::set_difference(diff_from_doing.begin(), diff_from_doing.end(), done_list.begin(), done_list.end(),
|
|
||||||
std::inserter(diff_from_done, diff_from_done.begin()));
|
|
||||||
|
|
||||||
auto new_kernel_size = diff_from_done.size();
|
|
||||||
if (new_kernel_size + todo_list.size() > static_cast<size_t>(kMaxKernelNum_)) {
|
|
||||||
MS_LOG(ERROR) << "The size of kernels is " << new_kernel_size << ", while the left space of the pool is "
|
|
||||||
<< kMaxKernelNum_ - todo_list.size();
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::copy(diff_from_done.begin(), diff_from_done.end(), LIST_END(kToDoIdx_));
|
|
||||||
INCREASE_LIST_SIZE(kToDoIdx_, new_kernel_size);
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t AkgKernelPool::FetchKernels(std::set<size_t> *out) {
|
|
||||||
LockMng lock(fd_);
|
|
||||||
if (!lock.locked_) {
|
|
||||||
MS_LOG(ERROR) << "Failed to acquire lock.";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::set<size_t> left_in_todo_list;
|
|
||||||
|
|
||||||
// filter out kernels which belongs to other processes
|
|
||||||
auto FilterBySelfList = [&left_in_todo_list, &out, this](size_t id) {
|
|
||||||
if (this->self_kernel_ids_.count(id) != 0) {
|
|
||||||
out->emplace(id);
|
|
||||||
} else {
|
|
||||||
left_in_todo_list.emplace(id);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::for_each(LIST_BEGIN(kToDoIdx_), LIST_END(kToDoIdx_), FilterBySelfList);
|
|
||||||
|
|
||||||
std::copy(out->begin(), out->end(), LIST_END(kDoingIdx_));
|
|
||||||
INCREASE_LIST_SIZE(kDoingIdx_, out->size());
|
|
||||||
|
|
||||||
std::copy(left_in_todo_list.begin(), left_in_todo_list.end(), LIST_BEGIN(kToDoIdx_));
|
|
||||||
RESET_LIST_SIZE(kToDoIdx_, left_in_todo_list.size());
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t AkgKernelPool::UpdateAndWait(const std::set<size_t> &ids) {
|
|
||||||
if (!ids.empty()) {
|
|
||||||
LockMng lock(fd_);
|
|
||||||
if (!lock.locked_) {
|
|
||||||
MS_LOG(ERROR) << "Failed to acquire lock.";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// update the state of finished kernels to `done`
|
|
||||||
std::copy(ids.begin(), ids.end(), LIST_END(kDoneIdx_));
|
|
||||||
INCREASE_LIST_SIZE(kDoneIdx_, ids.size());
|
|
||||||
|
|
||||||
// delete the finished kernels from doing_list
|
|
||||||
std::vector<size_t> left_in_doing_list;
|
|
||||||
INIT_SET_FROM_2D_ARRAY(doing_list, kDoingIdx_);
|
|
||||||
std::set_difference(doing_list.begin(), doing_list.end(), ids.begin(), ids.end(),
|
|
||||||
std::inserter(left_in_doing_list, left_in_doing_list.begin()));
|
|
||||||
|
|
||||||
std::copy(left_in_doing_list.begin(), left_in_doing_list.end(), LIST_BEGIN(kDoingIdx_));
|
|
||||||
RESET_LIST_SIZE(kDoingIdx_, left_in_doing_list.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto ret = Wait();
|
|
||||||
if (ret != 0) {
|
|
||||||
MS_LOG(ERROR) << "AkgKernelPool Wait failed.";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t AkgKernelPool::Wait() {
|
|
||||||
// wait until all the kernels which belong to this process finish compiling
|
|
||||||
uint32_t trials = 1000;
|
|
||||||
|
|
||||||
while (trials > 0) {
|
|
||||||
{
|
|
||||||
LockMng lock(fd_);
|
|
||||||
if (!lock.locked_) {
|
|
||||||
MS_LOG(ERROR) << "Failed to acquire lock.";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
INIT_SET_FROM_2D_ARRAY(done_list, kDoneIdx_);
|
|
||||||
|
|
||||||
if (std::all_of(self_kernel_ids_.begin(), self_kernel_ids_.end(),
|
|
||||||
[&done_list](size_t id) { return done_list.count(id) != 0; })) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
usleep(1000000);
|
|
||||||
trials--;
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(ERROR) << "Time out while wait kernel compiling";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<JsonNodePair> AkgKernelBuilder::GetNotCachedKernels(const std::vector<JsonNodePair> &build_args) {
|
|
||||||
std::unordered_set<std::string> kernel_name_set;
|
std::unordered_set<std::string> kernel_name_set;
|
||||||
std::vector<JsonNodePair> new_build_args;
|
|
||||||
for (const auto &[json_generator, anf_node] : build_args) {
|
for (const auto &[json_generator, anf_node] : build_args) {
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
auto kernel_name = json_generator.kernel_name();
|
auto kernel_name = json_generator.kernel_name();
|
||||||
|
MS_LOG(DEBUG) << "Akg start compile op: " << kernel_name;
|
||||||
|
|
||||||
auto cached_kernel_pack = AkgSearchCache(kernel_name);
|
auto cached_kernel_pack = AkgSearchCache(kernel_name);
|
||||||
if (cached_kernel_pack != nullptr) {
|
if (cached_kernel_pack != nullptr) {
|
||||||
|
@ -394,9 +59,11 @@ std::vector<JsonNodePair> AkgKernelBuilder::GetNotCachedKernels(const std::vecto
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
kernel_name_set.insert(kernel_name);
|
kernel_name_set.insert(kernel_name);
|
||||||
new_build_args.push_back({json_generator, anf_node});
|
auto kernel_json = json_generator.kernel_json_str();
|
||||||
|
AkgSaveJsonInfo(kernel_name, kernel_json);
|
||||||
|
jsons.push_back(kernel_json);
|
||||||
}
|
}
|
||||||
return new_build_args;
|
return jsons;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AkgKernelBuilder::InsertToCache(const std::vector<JsonNodePair> &build_args) {
|
bool AkgKernelBuilder::InsertToCache(const std::vector<JsonNodePair> &build_args) {
|
||||||
|
@ -423,57 +90,20 @@ bool AkgKernelBuilder::HandleRepeatNodes() {
|
||||||
<< anf_node->fullname_with_scope() << "].";
|
<< anf_node->fullname_with_scope() << "].";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
MS_LOG(INFO) << "Use just compiled kernel, kernel_name[" << kernel_name << "], fullname_with_scope["
|
||||||
<< anf_node->fullname_with_scope() << "].";
|
<< anf_node->fullname_with_scope() << "].";
|
||||||
AkgSetKernelMod(cached_kernel_pack, json_generator, anf_node);
|
AkgSetKernelMod(cached_kernel_pack, json_generator, anf_node);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> AkgKernelBuilder::GetKernelJsonsByHashId(const std::vector<JsonNodePair> &build_args,
|
|
||||||
std::set<size_t> fetched_ids) {
|
|
||||||
std::vector<std::string> jsons;
|
|
||||||
for (const auto &[json_generator, anf_node] : build_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
|
||||||
auto kernel_name = json_generator.kernel_name();
|
|
||||||
|
|
||||||
auto hash_id = std::hash<std::string>()(kernel_name);
|
|
||||||
|
|
||||||
if (fetched_ids.count(hash_id) == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto kernel_json = json_generator.kernel_json_str();
|
|
||||||
AkgSaveJsonInfo(kernel_name, kernel_json);
|
|
||||||
jsons.push_back(kernel_json);
|
|
||||||
}
|
|
||||||
return jsons;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AkgKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args) {
|
bool AkgKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args) {
|
||||||
repeat_nodes_.clear();
|
repeat_nodes_.clear();
|
||||||
auto new_build_args = GetNotCachedKernels(build_args);
|
auto jsons = GetNotCachedKernelJsons(build_args);
|
||||||
if (new_build_args.empty()) {
|
if (jsons.empty()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
AkgKernelPool kp;
|
|
||||||
auto ret = kp.Init(new_build_args);
|
|
||||||
if (ret != 0) {
|
|
||||||
MS_LOG(ERROR) << "AkgKernelPool init failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::set<size_t> fetched_ids;
|
|
||||||
ret = kp.FetchKernels(&fetched_ids);
|
|
||||||
if (ret != 0) {
|
|
||||||
MS_LOG(ERROR) << "AkgKernelPool FetchKernels failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!fetched_ids.empty()) {
|
|
||||||
auto jsons = GetKernelJsonsByHashId(new_build_args, fetched_ids);
|
|
||||||
|
|
||||||
auto client = GetClient();
|
auto client = GetClient();
|
||||||
MS_EXCEPTION_IF_NULL(client);
|
MS_EXCEPTION_IF_NULL(client);
|
||||||
if (!client->AkgStart(PROCESS_NUM, TIME_OUT)) {
|
if (!client->AkgStart(PROCESS_NUM, TIME_OUT)) {
|
||||||
|
@ -493,14 +123,6 @@ bool AkgKernelBuilder::AkgOpParallelBuild(const std::vector<JsonNodePair> &build
|
||||||
MS_LOG(ERROR) << "Akg compile failed.";
|
MS_LOG(ERROR) << "Akg compile failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
ret = kp.UpdateAndWait(fetched_ids);
|
|
||||||
if (ret != 0) {
|
|
||||||
MS_LOG(ERROR) << "AkgKernelPool UpdateAndWait failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// All unique done here, cache them and set kernel.
|
// All unique done here, cache them and set kernel.
|
||||||
if (!InsertToCache(build_args)) {
|
if (!InsertToCache(build_args)) {
|
||||||
MS_LOG(ERROR) << "Insert cache failed.";
|
MS_LOG(ERROR) << "Insert cache failed.";
|
||||||
|
@ -546,7 +168,7 @@ bool AkgKernelBuilder::AkgKernelParallelBuild(const std::vector<AnfNodePtr> &anf
|
||||||
}
|
}
|
||||||
|
|
||||||
if (json_and_node.empty()) {
|
if (json_and_node.empty()) {
|
||||||
MS_LOG(INFO) << "There is no akg kernel to be compiled.";
|
MS_LOG(DEBUG) << "There is no kernel needed to be compiled.";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,13 +17,10 @@
|
||||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_BUILD_H_
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_BUILD_H_
|
||||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_BUILD_H_
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AKG_AKG_KERNEL_BUILD_H_
|
||||||
|
|
||||||
#include <sys/shm.h>
|
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "backend/kernel_compiler/kernel.h"
|
#include "backend/kernel_compiler/kernel.h"
|
||||||
#include "backend/session/kernel_build_client.h"
|
#include "backend/session/kernel_build_client.h"
|
||||||
|
@ -47,83 +44,13 @@ class AkgKernelBuilder {
|
||||||
bool AkgKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
|
bool AkgKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<JsonNodePair> GetNotCachedKernels(const std::vector<JsonNodePair> &build_args);
|
std::vector<std::string> GetNotCachedKernelJsons(const std::vector<JsonNodePair> &build_args);
|
||||||
std::vector<std::string> GetKernelJsonsByHashId(const std::vector<JsonNodePair> &build_args,
|
|
||||||
std::set<size_t> fetched_ids);
|
|
||||||
bool InsertToCache(const std::vector<JsonNodePair> &build_args);
|
bool InsertToCache(const std::vector<JsonNodePair> &build_args);
|
||||||
bool HandleRepeatNodes();
|
bool HandleRepeatNodes();
|
||||||
bool AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args);
|
bool AkgOpParallelBuild(const std::vector<JsonNodePair> &build_args);
|
||||||
std::vector<JsonNodePair> repeat_nodes_;
|
std::vector<JsonNodePair> repeat_nodes_;
|
||||||
std::string CollectBuildAttrs();
|
std::string CollectBuildAttrs();
|
||||||
};
|
};
|
||||||
|
|
||||||
class AkgKernelPool {
|
|
||||||
public:
|
|
||||||
class LockMng {
|
|
||||||
public:
|
|
||||||
explicit LockMng(int32_t fd) {
|
|
||||||
fd_ = fd;
|
|
||||||
locked_ = TryLock();
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual ~LockMng() {
|
|
||||||
if (locked_) {
|
|
||||||
Unlock();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool locked_{false};
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool TryLock();
|
|
||||||
void Unlock();
|
|
||||||
|
|
||||||
int32_t fd_{-1};
|
|
||||||
};
|
|
||||||
|
|
||||||
public:
|
|
||||||
AkgKernelPool() = default;
|
|
||||||
virtual ~AkgKernelPool();
|
|
||||||
|
|
||||||
int32_t Init(const std::vector<JsonNodePair> &build_args);
|
|
||||||
int32_t FetchKernels(std::set<size_t> *out);
|
|
||||||
int32_t UpdateAndWait(const std::set<size_t> &ids);
|
|
||||||
|
|
||||||
constexpr inline static size_t kMaxKernelNum_{1000};
|
|
||||||
|
|
||||||
// allocate memory for todo_list, doing_list, done_list
|
|
||||||
constexpr inline static size_t kListNum_{3};
|
|
||||||
|
|
||||||
constexpr inline static auto kKeyName_ = "./akg_build_tmp.key";
|
|
||||||
|
|
||||||
constexpr inline static int32_t kToDoIdx_ = 0;
|
|
||||||
constexpr inline static int32_t kDoingIdx_ = 1;
|
|
||||||
constexpr inline static int32_t kDoneIdx_ = 2;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void *CreateSharedMem(const std::string &path);
|
|
||||||
std::string GetCurrentPath();
|
|
||||||
|
|
||||||
inline void InitKernelLists(void *addr) {
|
|
||||||
kernel_lists_[kToDoIdx_] = reinterpret_cast<size_t *>(addr);
|
|
||||||
kernel_lists_[kDoingIdx_] = kernel_lists_[kToDoIdx_] + kMaxKernelNum_ + 1;
|
|
||||||
kernel_lists_[kDoneIdx_] = kernel_lists_[kDoingIdx_] + kMaxKernelNum_ + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t AddKernels(const std::vector<JsonNodePair> &kernel_jsons);
|
|
||||||
int32_t Wait();
|
|
||||||
|
|
||||||
int32_t shm_id_{-1};
|
|
||||||
bool is_creator_{false};
|
|
||||||
int32_t fd_{-1};
|
|
||||||
|
|
||||||
// includes 3 lists: todo_list, doing_list, done_list.
|
|
||||||
// each list has kMaxKernelNum_ + 1 elements and, the count of elements in each list
|
|
||||||
// is stored in kernel_lists_[xx][kMaxKernelNum_]
|
|
||||||
size_t *kernel_lists_[kListNum_]{nullptr, nullptr, nullptr};
|
|
||||||
|
|
||||||
std::set<size_t> self_kernel_ids_;
|
|
||||||
};
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,12 @@
|
||||||
*/
|
*/
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
|
|
@ -16,6 +16,12 @@
|
||||||
|
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <tuple>
|
||||||
#if ENABLE_GPU
|
#if ENABLE_GPU
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "backend/kernel_compiler/akg/akg_kernel_metadata.h"
|
#include "backend/kernel_compiler/akg/akg_kernel_metadata.h"
|
||||||
|
#include <memory>
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
|
|
@ -16,6 +16,13 @@
|
||||||
|
|
||||||
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h"
|
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
#include "ir/dtype.h"
|
#include "ir/dtype.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
@ -27,20 +34,18 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
KernelPackPtr AkgAscendKernelBuilder::AkgSearchCache(const std::string &kernel_name) {
|
KernelPackPtr AkgAscendKernelBuilder::AkgSearchCache(const std::string &kernel_name) {
|
||||||
return tbe::TbeUtils::SearchCache(kernel_name, true);
|
return tbe::TbeUtils::SearchCache(kernel_name, kProcessorAiCore);
|
||||||
}
|
}
|
||||||
|
|
||||||
KernelPackPtr AkgAscendKernelBuilder::AkgInsertCache(const std::string &kernel_name) {
|
KernelPackPtr AkgAscendKernelBuilder::AkgInsertCache(const std::string &kernel_name) {
|
||||||
return tbe::TbeUtils::InsertCache(kernel_name, kProcessorAiCore, true);
|
return tbe::TbeUtils::InsertCache(kernel_name, kProcessorAiCore);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AkgAscendKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack,
|
void AkgAscendKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack,
|
||||||
const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node) {
|
const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node) {
|
||||||
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(kernel_pack);
|
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(kernel_pack);
|
||||||
auto kernel_json_info = kernel_pack->kernel_json_info();
|
|
||||||
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
||||||
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
||||||
kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces);
|
|
||||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ const std::vector<size_t> &AkgKernelMod::GetOutputSizeList() const { return outp
|
||||||
|
|
||||||
const std::vector<size_t> &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
const std::vector<size_t> &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||||
|
|
||||||
bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||||
if (stream_ptr == nullptr) {
|
if (stream_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
|
MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
|
||||||
|
@ -74,10 +74,6 @@ bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
|
||||||
[](const AddressPtr &input) -> void * { return input->addr; });
|
[](const AddressPtr &input) -> void * { return input->addr; });
|
||||||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args),
|
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args),
|
||||||
[](const AddressPtr &output) -> void * { return output->addr; });
|
[](const AddressPtr &output) -> void * { return output->addr; });
|
||||||
if (!workspace.empty()) {
|
|
||||||
(void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtime_args),
|
|
||||||
[](const AddressPtr &addr) -> void * { return addr->addr; });
|
|
||||||
}
|
|
||||||
|
|
||||||
rtL2Ctrl_t *l2ctrl = nullptr;
|
rtL2Ctrl_t *l2ctrl = nullptr;
|
||||||
auto stream = static_cast<rtStream_t *>(stream_ptr);
|
auto stream = static_cast<rtStream_t *>(stream_ptr);
|
||||||
|
@ -90,8 +86,7 @@ bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &inputs,
|
std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &workspace,
|
|
||||||
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
|
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
|
||||||
if (kernel_pack_ == nullptr) {
|
if (kernel_pack_ == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "kernel pack should not be nullptr.";
|
MS_LOG(EXCEPTION) << "kernel pack should not be nullptr.";
|
||||||
|
@ -112,10 +107,6 @@ std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &in
|
||||||
[](const AddressPtr &input) -> void * { return input->addr; });
|
[](const AddressPtr &input) -> void * { return input->addr; });
|
||||||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
|
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
|
||||||
[](const AddressPtr &output) -> void * { return output->addr; });
|
[](const AddressPtr &output) -> void * { return output->addr; });
|
||||||
if (!workspace.empty()) {
|
|
||||||
(void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(workspace_addrs),
|
|
||||||
[](const AddressPtr &workspace) -> void * { return workspace->addr; });
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1.
|
uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1.
|
||||||
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
|
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
|
||||||
|
|
|
@ -39,15 +39,14 @@ KernelPackPtr AkgGpuKernelBuilder::AkgInsertCache(const std::string &kernel_name
|
||||||
void AkgGpuKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack,
|
void AkgGpuKernelBuilder::AkgSetKernelMod(const KernelPackPtr &kernel_pack,
|
||||||
const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node) {
|
const AkgKernelJsonGenerator &json_generator, const AnfNodePtr &anf_node) {
|
||||||
auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
|
auto kernel_mod_ptr = std::make_shared<GpuKernelMod>(kernel_pack);
|
||||||
auto kernel_json_info = kernel_pack->kernel_json_info();
|
|
||||||
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
kernel_mod_ptr->SetInputSizeList(json_generator.input_size_list());
|
||||||
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
kernel_mod_ptr->SetOutputSizeList(json_generator.output_size_list());
|
||||||
kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces);
|
|
||||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void AkgGpuKernelBuilder::AkgSaveJsonInfo(const string &kernel_name, const string &kernel_json) {
|
void AkgGpuKernelBuilder::AkgSaveJsonInfo(const string &kernel_name, const string &kernel_json) {
|
||||||
kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
|
kernel::SaveJsonInfo(kernel_name, kernel_json, kernel::KernelMeta::GetInstance()->kernel_meta_path());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -15,7 +15,8 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h"
|
#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h"
|
||||||
|
#include <fstream>
|
||||||
|
#include <algorithm>
|
||||||
#include "nlohmann/json.hpp"
|
#include "nlohmann/json.hpp"
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
|
|
||||||
|
@ -91,15 +92,13 @@ void GpuKernelMod::SetInputSizeList(const std::vector<size_t> &size_list) { inpu
|
||||||
|
|
||||||
void GpuKernelMod::SetOutputSizeList(const std::vector<size_t> &size_list) { output_size_list_ = size_list; }
|
void GpuKernelMod::SetOutputSizeList(const std::vector<size_t> &size_list) { output_size_list_ = size_list; }
|
||||||
|
|
||||||
void GpuKernelMod::SetWorkspaceSizeList(const std::vector<size_t> &size_list) { workspace_size_list_ = size_list; }
|
|
||||||
|
|
||||||
const std::vector<size_t> &GpuKernelMod::GetInputSizeList() const { return input_size_list_; }
|
const std::vector<size_t> &GpuKernelMod::GetInputSizeList() const { return input_size_list_; }
|
||||||
|
|
||||||
const std::vector<size_t> &GpuKernelMod::GetOutputSizeList() const { return output_size_list_; }
|
const std::vector<size_t> &GpuKernelMod::GetOutputSizeList() const { return output_size_list_; }
|
||||||
|
|
||||||
const std::vector<size_t> &GpuKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
const std::vector<size_t> &GpuKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; }
|
||||||
|
|
||||||
bool GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||||
if (stream_ptr == 0) {
|
if (stream_ptr == 0) {
|
||||||
MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
|
MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
|
||||||
|
@ -123,10 +122,6 @@ bool GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
|
||||||
[](const AddressPtr &input) -> void * { return reinterpret_cast<void *>(&(input->addr)); });
|
[](const AddressPtr &input) -> void * { return reinterpret_cast<void *>(&(input->addr)); });
|
||||||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs),
|
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs),
|
||||||
[](const AddressPtr &output) -> void * { return reinterpret_cast<void *>(&(output->addr)); });
|
[](const AddressPtr &output) -> void * { return reinterpret_cast<void *>(&(output->addr)); });
|
||||||
if (!workspace.empty()) {
|
|
||||||
(void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs),
|
|
||||||
[](const AddressPtr &addr) -> void * { return reinterpret_cast<void *>(&(addr->addr)); });
|
|
||||||
}
|
|
||||||
result = cuLaunchKernel(kernel_addr, thread_info[0], thread_info[1], thread_info[2], thread_info[3], thread_info[4],
|
result = cuLaunchKernel(kernel_addr, thread_info[0], thread_info[1], thread_info[2], thread_info[3], thread_info[4],
|
||||||
thread_info[5], 0, reinterpret_cast<CUstream>(stream_ptr),
|
thread_info[5], 0, reinterpret_cast<CUstream>(stream_ptr),
|
||||||
reinterpret_cast<void **>(&runtimeargs[0]), 0);
|
reinterpret_cast<void **>(&runtimeargs[0]), 0);
|
||||||
|
|
|
@ -60,7 +60,6 @@ class GpuKernelMod : public KernelMod {
|
||||||
|
|
||||||
void SetInputSizeList(const std::vector<size_t> &size_list);
|
void SetInputSizeList(const std::vector<size_t> &size_list);
|
||||||
void SetOutputSizeList(const std::vector<size_t> &size_list);
|
void SetOutputSizeList(const std::vector<size_t> &size_list);
|
||||||
void SetWorkspaceSizeList(const std::vector<size_t> &size_list);
|
|
||||||
const std::vector<size_t> &GetInputSizeList() const override;
|
const std::vector<size_t> &GetInputSizeList() const override;
|
||||||
const std::vector<size_t> &GetOutputSizeList() const override;
|
const std::vector<size_t> &GetOutputSizeList() const override;
|
||||||
const std::vector<size_t> &GetWorkspaceSizeList() const override;
|
const std::vector<size_t> &GetWorkspaceSizeList() const override;
|
||||||
|
|
|
@ -141,8 +141,14 @@ FusionType GetFusionTypeByName(const std::string &name) {
|
||||||
return iter->first;
|
return iter->first;
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelMeta::Initialize() {
|
void KernelMeta::Initialize(int pid) {
|
||||||
kernel_meta_path_ = std::string(kGpuKernelMeta) + "/";
|
if (pid == -1) {
|
||||||
|
kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/";
|
||||||
|
} else {
|
||||||
|
kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(pid) + "/";
|
||||||
|
}
|
||||||
|
// remove old kernel cache
|
||||||
|
RemoveKernelCache();
|
||||||
|
|
||||||
#if defined(_WIN32) || defined(_WIN64)
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
auto ret = mkdir(kernel_meta_path_.c_str());
|
auto ret = mkdir(kernel_meta_path_.c_str());
|
||||||
|
@ -155,6 +161,21 @@ void KernelMeta::Initialize() {
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void KernelMeta::RemoveKernelCache() {
|
||||||
|
DIR *dir = opendir(kernel_meta_path_.c_str());
|
||||||
|
if (dir == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
struct dirent *entry;
|
||||||
|
while ((entry = readdir(dir)) != nullptr) {
|
||||||
|
std::string kernel_file = entry->d_name;
|
||||||
|
std::string kernel_file_realpath = kernel_meta_path_ + kernel_file;
|
||||||
|
(void)remove(kernel_file_realpath.c_str());
|
||||||
|
}
|
||||||
|
(void)closedir(dir);
|
||||||
|
(void)rmdir(kernel_meta_path_.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
std::string KernelMeta::Search(const std::string &kernel_name) const {
|
std::string KernelMeta::Search(const std::string &kernel_name) const {
|
||||||
if (!initialized_) {
|
if (!initialized_) {
|
||||||
return "";
|
return "";
|
||||||
|
@ -206,7 +227,7 @@ KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &pro
|
||||||
KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
|
KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
|
||||||
// just a tmp solution.
|
// just a tmp solution.
|
||||||
if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
|
if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
|
||||||
MS_LOG(ERROR) << "Read cache json and bin file failed[" << kernel_json << "].";
|
MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "].";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
} else {
|
} else {
|
||||||
return kernel_pack;
|
return kernel_pack;
|
||||||
|
@ -229,7 +250,7 @@ KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &pro
|
||||||
(void)kernel_json.append(kernel_name).append(kJsonSuffix);
|
(void)kernel_json.append(kernel_name).append(kJsonSuffix);
|
||||||
KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
|
KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
|
||||||
if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
|
if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
|
||||||
MS_LOG(ERROR) << "Read json and bin file failed[" << kernel_json << "].";
|
MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "].";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -693,9 +714,6 @@ void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNode
|
||||||
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
|
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
|
||||||
auto input_node = cnode->input(input_idx);
|
auto input_node = cnode->input(input_idx);
|
||||||
MS_EXCEPTION_IF_NULL(input_node);
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
if (input_node->isa<CNode>() && AnfAlgo::GetInputTensorNum(input_node) == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
|
output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -970,39 +988,5 @@ size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &
|
||||||
}
|
}
|
||||||
return offset;
|
return offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t UnitSizeInBytes(const mindspore::TypeId &t) {
|
|
||||||
size_t bytes = 0;
|
|
||||||
switch (t) {
|
|
||||||
case kNumberTypeBool:
|
|
||||||
case kNumberTypeInt8:
|
|
||||||
case kNumberTypeUInt8:
|
|
||||||
bytes = sizeof(int8_t);
|
|
||||||
break;
|
|
||||||
case kNumberTypeInt16:
|
|
||||||
case kNumberTypeUInt16:
|
|
||||||
case kNumberTypeFloat16:
|
|
||||||
bytes = sizeof(int16_t);
|
|
||||||
break;
|
|
||||||
case kNumberTypeInt:
|
|
||||||
case kNumberTypeUInt:
|
|
||||||
case kNumberTypeInt32:
|
|
||||||
case kNumberTypeUInt32:
|
|
||||||
case kNumberTypeFloat:
|
|
||||||
case kNumberTypeFloat32:
|
|
||||||
bytes = sizeof(int32_t);
|
|
||||||
break;
|
|
||||||
case kNumberTypeUInt64:
|
|
||||||
case kNumberTypeInt64:
|
|
||||||
case kNumberTypeFloat64:
|
|
||||||
bytes = sizeof(int64_t);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
MS_LOG(EXCEPTION) << "Invalid types " << t;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return bytes;
|
|
||||||
}
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -55,7 +55,8 @@ using KernelMetaPtr = std::shared_ptr<KernelMetaInfo>;
|
||||||
class KernelMeta {
|
class KernelMeta {
|
||||||
public:
|
public:
|
||||||
KernelMeta() = default;
|
KernelMeta() = default;
|
||||||
void Initialize();
|
void Initialize(int pid);
|
||||||
|
void RemoveKernelCache();
|
||||||
std::string Search(const std::string &kernel_name) const;
|
std::string Search(const std::string &kernel_name) const;
|
||||||
bool Insert(const std::string &kernel_name, const std::string &kernel_json);
|
bool Insert(const std::string &kernel_name, const std::string &kernel_json);
|
||||||
std::string kernel_meta_path() const { return kernel_meta_path_; }
|
std::string kernel_meta_path() const { return kernel_meta_path_; }
|
||||||
|
@ -143,7 +144,6 @@ size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &
|
||||||
std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape);
|
std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape);
|
||||||
size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
|
size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
|
||||||
const std::vector<int64_t> &stop);
|
const std::vector<int64_t> &stop);
|
||||||
size_t UnitSizeInBytes(const mindspore::TypeId &t);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -83,7 +83,7 @@ void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &input
|
||||||
MS_LOG(EXCEPTION) << "AdamFp32 failed.";
|
MS_LOG(EXCEPTION) << "AdamFp32 failed.";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelForAutoSearch(task, lens, ¶llel_search_info_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
|
|
|
@ -26,26 +26,46 @@ namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
constexpr size_t kSizeFloat16 = sizeof(float16);
|
constexpr size_t kSizeFloat16 = sizeof(float16);
|
||||||
constexpr size_t kSizeFloat32 = sizeof(float);
|
constexpr size_t kSizeFloat32 = sizeof(float);
|
||||||
constexpr size_t kScalarIndex = 0;
|
|
||||||
constexpr size_t kAdamWeightDecayInputSize = 9;
|
constexpr size_t kAdamWeightDecayInputSize = 9;
|
||||||
constexpr size_t kAdamWeightDecayOutputSize = 3;
|
constexpr size_t kAdamWeightDecayOutputSize = 3;
|
||||||
|
|
||||||
|
void AdamWeightDecayCPUKernel::ParallelForAdam(const CTask &task, size_t count) {
|
||||||
|
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||||
|
const float block_size = 128.0;
|
||||||
|
const float align_size = 16.0;
|
||||||
|
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
|
||||||
|
std::vector<common::Task> tasks;
|
||||||
|
size_t start = 0;
|
||||||
|
size_t once_compute_size = align_size * std::ceil(count / (align_size * thread_num));
|
||||||
|
while (start < count) {
|
||||||
|
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
|
||||||
|
auto block = [&, start, end]() {
|
||||||
|
task(start, end);
|
||||||
|
return common::SUCCESS;
|
||||||
|
};
|
||||||
|
tasks.emplace_back(block);
|
||||||
|
start += once_compute_size;
|
||||||
|
}
|
||||||
|
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename S>
|
template <typename T, typename S>
|
||||||
void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &) {
|
void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &inputs,
|
||||||
auto var = reinterpret_cast<T *>(inputs[VAR]->addr);
|
const std::vector<AddressPtr> &outputs) {
|
||||||
auto m = reinterpret_cast<T *>(inputs[M]->addr);
|
auto var = reinterpret_cast<T *>(inputs[0]->addr);
|
||||||
auto v = reinterpret_cast<T *>(inputs[V]->addr);
|
auto m = reinterpret_cast<T *>(inputs[1]->addr);
|
||||||
auto lr = reinterpret_cast<T *>(inputs[LR]->addr)[kScalarIndex];
|
auto v = reinterpret_cast<T *>(inputs[2]->addr);
|
||||||
auto beta1 = reinterpret_cast<T *>(inputs[BETA1]->addr)[kScalarIndex];
|
auto lr = reinterpret_cast<T *>(inputs[3]->addr)[0];
|
||||||
auto beta2 = reinterpret_cast<T *>(inputs[BETA2]->addr)[kScalarIndex];
|
auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
|
||||||
auto epsilon = reinterpret_cast<T *>(inputs[EPSILON]->addr)[kScalarIndex];
|
auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
|
||||||
auto decay = reinterpret_cast<T *>(inputs[DECAY]->addr)[kScalarIndex];
|
auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
|
||||||
auto gradient16 = reinterpret_cast<S *>(inputs[GRAD]->addr);
|
auto decay = reinterpret_cast<T *>(inputs[7]->addr)[0];
|
||||||
|
auto gradient16 = reinterpret_cast<S *>(inputs[8]->addr);
|
||||||
const auto beta1_minus = 1 - beta1;
|
const auto beta1_minus = 1 - beta1;
|
||||||
const auto beta2_minus = 1 - beta2;
|
const auto beta2_minus = 1 - beta2;
|
||||||
|
|
||||||
// multithreading
|
// multithreading
|
||||||
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(float)) : 1;
|
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||||
std::function<void(size_t, size_t)> task;
|
std::function<void(size_t, size_t)> task;
|
||||||
|
|
||||||
task = [&](size_t start, size_t end) {
|
task = [&](size_t start, size_t end) {
|
||||||
|
@ -61,27 +81,28 @@ void AdamWeightDecayCPUKernel::LaunchFusedAdam(const std::vector<AddressPtr> &in
|
||||||
var[i] -= lr * update;
|
var[i] -= lr * update;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
CPUKernelUtils::ParallelFor(task, lens);
|
ParallelForAdam(task, lens);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs,
|
void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPtr> &inputs,
|
||||||
const std::vector<AddressPtr> &) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
auto var = reinterpret_cast<T *>(inputs[VAR]->addr);
|
auto var = reinterpret_cast<T *>(inputs[0]->addr);
|
||||||
auto m = reinterpret_cast<T *>(inputs[M]->addr);
|
auto m = reinterpret_cast<T *>(inputs[1]->addr);
|
||||||
auto v = reinterpret_cast<T *>(inputs[V]->addr);
|
auto v = reinterpret_cast<T *>(inputs[2]->addr);
|
||||||
auto lr = reinterpret_cast<T *>(inputs[LR]->addr)[kScalarIndex];
|
auto lr = reinterpret_cast<T *>(inputs[3]->addr)[0];
|
||||||
auto beta1 = reinterpret_cast<T *>(inputs[BETA1]->addr)[kScalarIndex];
|
auto beta1 = reinterpret_cast<T *>(inputs[4]->addr)[0];
|
||||||
auto beta2 = reinterpret_cast<T *>(inputs[BETA2]->addr)[kScalarIndex];
|
auto beta2 = reinterpret_cast<T *>(inputs[5]->addr)[0];
|
||||||
auto epsilon = reinterpret_cast<T *>(inputs[EPSILON]->addr)[kScalarIndex];
|
auto epsilon = reinterpret_cast<T *>(inputs[6]->addr)[0];
|
||||||
auto decay = reinterpret_cast<T *>(inputs[DECAY]->addr)[kScalarIndex];
|
auto decay = reinterpret_cast<T *>(inputs[7]->addr)[0];
|
||||||
auto gradient = reinterpret_cast<T *>(inputs[GRAD]->addr);
|
auto gradient = reinterpret_cast<T *>(inputs[8]->addr);
|
||||||
const auto beta1_minus = 1 - beta1;
|
const auto beta1_minus = 1 - beta1;
|
||||||
const auto beta2_minus = 1 - beta2;
|
const auto beta2_minus = 1 - beta2;
|
||||||
|
|
||||||
// multithreading
|
// multithreading
|
||||||
size_t lens = inputs[VAR]->size > 0 ? static_cast<size_t>(inputs[VAR]->size / sizeof(float)) : 1;
|
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
|
||||||
std::function<void(size_t, size_t)> task;
|
std::function<void(size_t, size_t)> task;
|
||||||
|
|
||||||
task = [&](size_t start, size_t end) {
|
task = [&](size_t start, size_t end) {
|
||||||
size_t i = AdamWeightDecayFp32(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, start, end);
|
size_t i = AdamWeightDecayFp32(var, m, v, lr, beta1, beta2, epsilon, decay, gradient, start, end);
|
||||||
// remaining
|
// remaining
|
||||||
|
@ -93,14 +114,14 @@ void AdamWeightDecayCPUKernel::LaunchAdamWeightDecay(const std::vector<AddressPt
|
||||||
var[i] -= lr * update;
|
var[i] -= lr * update;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
CPUKernelUtils::ParallelFor(task, lens);
|
ParallelForAdam(task, lens);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
void AdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, VAR);
|
std::vector<size_t> var_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, VAR);
|
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||||
gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, GRAD);
|
gradient_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 8);
|
||||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
if (input_num != kAdamWeightDecayInputSize) {
|
if (input_num != kAdamWeightDecayInputSize) {
|
||||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
|
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but AdamWeightDecay needs 9 inputs.";
|
||||||
|
@ -134,12 +155,12 @@ void AdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr>
|
||||||
}
|
}
|
||||||
size_t elem1_size = elem_num_ * kSizeFloat32;
|
size_t elem1_size = elem_num_ * kSizeFloat32;
|
||||||
size_t elem2_size = gradient_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : elem1_size;
|
size_t elem2_size = gradient_dtype_ == kNumberTypeFloat16 ? elem_num_ * kSizeFloat16 : elem1_size;
|
||||||
if (inputs[VAR]->size != elem1_size || inputs[M]->size != elem1_size || inputs[V]->size != elem1_size ||
|
if (inputs[0]->size != elem1_size || inputs[1]->size != elem1_size || inputs[2]->size != elem1_size ||
|
||||||
inputs[GRAD]->size != elem2_size) {
|
inputs[8]->size != elem2_size) {
|
||||||
MS_LOG(EXCEPTION) << "Error input data size!";
|
MS_LOG(EXCEPTION) << "Error input data size!";
|
||||||
}
|
}
|
||||||
if (inputs[LR]->size != kSizeFloat32 || inputs[BETA1]->size != kSizeFloat32 || inputs[BETA2]->size != kSizeFloat32 ||
|
if (inputs[3]->size != kSizeFloat32 || inputs[4]->size != kSizeFloat32 || inputs[5]->size != kSizeFloat32 ||
|
||||||
inputs[EPSILON]->size != kSizeFloat32 || inputs[DECAY]->size != kSizeFloat32) {
|
inputs[6]->size != kSizeFloat32 || inputs[7]->size != kSizeFloat32) {
|
||||||
MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
|
MS_LOG(EXCEPTION) << "The attribute beta, lr, epsilon and weight decay must be float!";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,7 @@ class AdamWeightDecayCPUKernel : public CPUKernel {
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void ParallelForAdam(const CTask &task, size_t count);
|
||||||
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||||
template <typename T, typename S>
|
template <typename T, typename S>
|
||||||
void LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
void LaunchFusedAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||||
|
@ -40,7 +41,6 @@ class AdamWeightDecayCPUKernel : public CPUKernel {
|
||||||
size_t elem_num_{0};
|
size_t elem_num_{0};
|
||||||
TypeId dtype_{kTypeUnknown};
|
TypeId dtype_{kTypeUnknown};
|
||||||
TypeId gradient_dtype_{kTypeUnknown};
|
TypeId gradient_dtype_{kTypeUnknown};
|
||||||
enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
|
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(AdamWeightDecay,
|
MS_REG_CPU_KERNEL(AdamWeightDecay,
|
||||||
|
|
|
@ -76,10 +76,27 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||||
|
|
||||||
// multithreading
|
// multithreading
|
||||||
size_t length = inputs[0]->size / sizeof(T);
|
size_t length = inputs[0]->size / sizeof(T);
|
||||||
auto task = [this, &var, &accum, lr, gradient](size_t start, size_t end) {
|
size_t max_thread_num = std::thread::hardware_concurrency();
|
||||||
LaunchApplyAdagrad(var, accum, lr, gradient, start, end);
|
size_t use_thread_num = length < 128 * max_thread_num ? std::ceil(length / 128.0) : max_thread_num;
|
||||||
};
|
std::vector<std::thread> threads;
|
||||||
CPUKernelUtils::ParallelForAutoSearch(task, length, ¶llel_search_info_);
|
threads.reserve(use_thread_num);
|
||||||
|
size_t start = 0;
|
||||||
|
const size_t batch_size = (length + use_thread_num - 1) / use_thread_num;
|
||||||
|
|
||||||
|
if (batch_size == 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Error occur in launch kernel";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
while (start < length) {
|
||||||
|
size_t end = (start + batch_size) > length ? length : (start + batch_size);
|
||||||
|
threads.emplace_back(
|
||||||
|
std::thread(&ApplyAdagradCPUKernel::LaunchApplyAdagrad<T *>, this, var, accum, lr, gradient, start, end));
|
||||||
|
start += batch_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &it : threads) {
|
||||||
|
it.join();
|
||||||
|
}
|
||||||
|
|
||||||
// Copy result to output tensor
|
// Copy result to output tensor
|
||||||
auto output_var = reinterpret_cast<T *>(outputs[0]->addr);
|
auto output_var = reinterpret_cast<T *>(outputs[0]->addr);
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
#include "runtime/device/cpu/cpu_device_address.h"
|
#include "runtime/device/cpu/cpu_device_address.h"
|
||||||
#include "nnacl/fp32/power_fp32.h"
|
#include "nnacl/fp32/power_fp32.h"
|
||||||
#include "nnacl/fp32/sub_fp32.h"
|
#include "nnacl/fp32/sub_fp32.h"
|
||||||
#include "nnacl/fp32/mul_fp32.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
@ -55,7 +54,7 @@ void ArithmeticCPUKernel<T>::Sub(const T *input1, const T *input2, T *out) {
|
||||||
auto task = [&](size_t start, size_t end) {
|
auto task = [&](size_t start, size_t end) {
|
||||||
ElementSub(input1 + start, input2 + start, out + start, end - start);
|
ElementSub(input1 + start, input2 + start, out + start, end - start);
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelFor(task, output_size_, MAX_SUB_SERIAL_SIZE);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (op_para.in_elements_num0_ == 1 || op_para.in_elements_num1_ == 1) {
|
if (op_para.in_elements_num0_ == 1 || op_para.in_elements_num1_ == 1) {
|
||||||
|
@ -66,7 +65,7 @@ void ArithmeticCPUKernel<T>::Sub(const T *input1, const T *input2, T *out) {
|
||||||
ElementOptSub(input1 + start, input2, out + start, end - start, &op_para);
|
ElementOptSub(input1 + start, input2, out + start, end - start, &op_para);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelFor(task, output_size_, MAX_SUB_SERIAL_SIZE);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -85,26 +84,6 @@ void ArithmeticCPUKernel<T>::Sub(const T *input1, const T *input2, T *out) {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ArithmeticCPUKernel<T>::Mul(const T *input1, const T *input2, T *out) {
|
void ArithmeticCPUKernel<T>::Mul(const T *input1, const T *input2, T *out) {
|
||||||
if constexpr (std::is_same_v<T, float>) {
|
|
||||||
if (input_shape1_ == input_shape2_) {
|
|
||||||
auto task = [&](size_t start, size_t end) {
|
|
||||||
ElementMul(input1 + start, input2 + start, out + start, end - start);
|
|
||||||
};
|
|
||||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (op_para.in_elements_num0_ == 1 || op_para.in_elements_num1_ == 1) {
|
|
||||||
auto task = [&](size_t start, size_t end) {
|
|
||||||
if (op_para.in_elements_num0_ == 1) {
|
|
||||||
ElementOptMul(input1, input2 + start, out + start, end - start, &op_para);
|
|
||||||
} else {
|
|
||||||
ElementOptMul(input1 + start, input2, out + start, end - start, &op_para);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||||
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
||||||
auto iter = base_iter;
|
auto iter = base_iter;
|
||||||
|
@ -149,21 +128,21 @@ void ArithmeticCPUKernel<T>::RealDiv(const T *input1, const T *input2, T *out) {
|
||||||
auto task = [&](size_t start, size_t end) {
|
auto task = [&](size_t start, size_t end) {
|
||||||
ElementRealDiv<T>(input1 + start, input2 + start, out + start, end - start, 1, 1);
|
ElementRealDiv<T>(input1 + start, input2 + start, out + start, end - start, 1, 1);
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelFor(task, output_size_, MAX_DIV_SERIAL_SIZE);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (op_para.in_elements_num0_ == 1) {
|
if (op_para.in_elements_num0_ == 1) {
|
||||||
auto task = [&](size_t start, size_t end) {
|
auto task = [&](size_t start, size_t end) {
|
||||||
ElementRealDiv<T>(input1, input2 + start, out + start, end - start, 0, 1);
|
ElementRealDiv<T>(input1, input2 + start, out + start, end - start, 0, 1);
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelFor(task, output_size_, MAX_DIV_SERIAL_SIZE);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (op_para.in_elements_num1_ == 1) {
|
if (op_para.in_elements_num1_ == 1) {
|
||||||
auto task = [&](size_t start, size_t end) {
|
auto task = [&](size_t start, size_t end) {
|
||||||
ElementRealDiv<T>(input1 + start, input2, out + start, end - start, 1, 0);
|
ElementRealDiv<T>(input1 + start, input2, out + start, end - start, 1, 0);
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelFor(task, output_size_, MAX_DIV_SERIAL_SIZE);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -360,7 +339,7 @@ void ArithmeticCPUKernel<T>::SquaredDifference(const T *input1, const T *input2,
|
||||||
iter.GenNextPos();
|
iter.GenNextPos();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelFor(task, output_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -77,8 +77,6 @@ MS_REG_CPU_KERNEL_T(RealDiv, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
||||||
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, float);
|
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, float);
|
||||||
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
MS_REG_CPU_KERNEL_T(Div, KernelAttr(), ArithmeticCPUKernel, int64_t);
|
||||||
MS_REG_CPU_KERNEL_T(Mul, KernelAttr(), ArithmeticCPUKernel, float);
|
|
||||||
MS_REG_CPU_KERNEL_T(Mul, KernelAttr(), ArithmeticCPUKernel, int32_t);
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
MS_REG_CPU_KERNEL_T(
|
||||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||||
ArithmeticCPUKernel, int64_t);
|
ArithmeticCPUKernel, int64_t);
|
||||||
|
|
|
@ -13,12 +13,10 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h"
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <functional>
|
#include "backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h"
|
||||||
#include "runtime/device/cpu/cpu_device_address.h"
|
#include "runtime/device/cpu/cpu_device_address.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -31,9 +29,7 @@ void ArithmeticLogicCPUKernel<T>::Less(const T *input1, const T *input2, bool *o
|
||||||
auto iter = base_iter;
|
auto iter = base_iter;
|
||||||
iter.SetPos(start);
|
iter.SetPos(start);
|
||||||
for (size_t i = start; i < end; i++) {
|
for (size_t i = start; i < end; i++) {
|
||||||
auto x = input1[iter.GetInputPosA()];
|
out[i] = input1[iter.GetInputPosA()] < input2[iter.GetInputPosB()];
|
||||||
auto y = input2[iter.GetInputPosB()];
|
|
||||||
out[i] = std::less<T>()(x, y);
|
|
||||||
iter.GenNextPos();
|
iter.GenNextPos();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -41,9 +37,7 @@ void ArithmeticLogicCPUKernel<T>::Less(const T *input1, const T *input2, bool *o
|
||||||
} else {
|
} else {
|
||||||
base_iter.SetPos(0);
|
base_iter.SetPos(0);
|
||||||
for (size_t i = 0; i < output_size_; i++) {
|
for (size_t i = 0; i < output_size_; i++) {
|
||||||
auto x = input1[base_iter.GetInputPosA()];
|
out[i] = input1[base_iter.GetInputPosA()] < input2[base_iter.GetInputPosB()];
|
||||||
auto y = input2[base_iter.GetInputPosB()];
|
|
||||||
out[i] = std::less<T>()(x, y);
|
|
||||||
base_iter.GenNextPos();
|
base_iter.GenNextPos();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,9 +50,7 @@ void ArithmeticLogicCPUKernel<T>::Equal(const T *input1, const T *input2, bool *
|
||||||
auto iter = base_iter;
|
auto iter = base_iter;
|
||||||
iter.SetPos(start);
|
iter.SetPos(start);
|
||||||
for (size_t i = start; i < end; i++) {
|
for (size_t i = start; i < end; i++) {
|
||||||
auto x = input1[iter.GetInputPosA()];
|
out[i] = input1[iter.GetInputPosA()] == input2[iter.GetInputPosB()];
|
||||||
auto y = input2[iter.GetInputPosB()];
|
|
||||||
out[i] = std::equal_to<T>()(x, y);
|
|
||||||
iter.GenNextPos();
|
iter.GenNextPos();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -72,9 +64,7 @@ void ArithmeticLogicCPUKernel<T>::NotEqual(const T *input1, const T *input2, boo
|
||||||
auto iter = base_iter;
|
auto iter = base_iter;
|
||||||
iter.SetPos(start);
|
iter.SetPos(start);
|
||||||
for (size_t i = start; i < end; i++) {
|
for (size_t i = start; i < end; i++) {
|
||||||
auto x = input1[iter.GetInputPosA()];
|
out[i] = input1[iter.GetInputPosA()] != input2[iter.GetInputPosB()];
|
||||||
auto y = input2[iter.GetInputPosB()];
|
|
||||||
out[i] = std::not_equal_to<T>()(x, y);
|
|
||||||
iter.GenNextPos();
|
iter.GenNextPos();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -116,9 +106,7 @@ void ArithmeticLogicCPUKernel<T>::Greater(const T *input1, const T *input2, bool
|
||||||
auto iter = base_iter;
|
auto iter = base_iter;
|
||||||
iter.SetPos(start);
|
iter.SetPos(start);
|
||||||
for (size_t i = start; i < end; i++) {
|
for (size_t i = start; i < end; i++) {
|
||||||
auto x = input1[iter.GetInputPosA()];
|
out[i] = input1[iter.GetInputPosA()] > input2[iter.GetInputPosB()];
|
||||||
auto y = input2[iter.GetInputPosB()];
|
|
||||||
out[i] = std::greater<T>()(x, y);
|
|
||||||
iter.GenNextPos();
|
iter.GenNextPos();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -132,9 +120,7 @@ void ArithmeticLogicCPUKernel<T>::GreaterEqual(const T *input1, const T *input2,
|
||||||
auto iter = base_iter;
|
auto iter = base_iter;
|
||||||
iter.SetPos(start);
|
iter.SetPos(start);
|
||||||
for (size_t i = start; i < end; i++) {
|
for (size_t i = start; i < end; i++) {
|
||||||
auto x = input1[iter.GetInputPosA()];
|
out[i] = input1[iter.GetInputPosA()] >= input2[iter.GetInputPosB()];
|
||||||
auto y = input2[iter.GetInputPosB()];
|
|
||||||
out[i] = std::greater_equal<T>()(x, y);
|
|
||||||
iter.GenNextPos();
|
iter.GenNextPos();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -148,9 +134,7 @@ void ArithmeticLogicCPUKernel<T>::LessEqual(const T *input1, const T *input2, bo
|
||||||
auto iter = base_iter;
|
auto iter = base_iter;
|
||||||
iter.SetPos(start);
|
iter.SetPos(start);
|
||||||
for (size_t i = start; i < end; i++) {
|
for (size_t i = start; i < end; i++) {
|
||||||
auto x = input1[iter.GetInputPosA()];
|
out[i] = input1[iter.GetInputPosA()] <= input2[iter.GetInputPosB()];
|
||||||
auto y = input2[iter.GetInputPosB()];
|
|
||||||
out[i] = std::less_equal<T>()(x, y);
|
|
||||||
iter.GenNextPos();
|
iter.GenNextPos();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h"
|
#include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h"
|
||||||
#include "runtime/device/cpu/cpu_device_address.h"
|
#include "runtime/device/cpu/cpu_device_address.h"
|
||||||
#include "nnacl/fp32/exp_fp32.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
@ -32,15 +31,7 @@ void Square(const T *in, T *out, size_t size) {
|
||||||
out[i] = in[i] * in[i];
|
out[i] = in[i] * in[i];
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ParallelLaunch(task, size, MAX_SQUARE_SERIAL_SIZE);
|
CPUKernelUtils::ParallelFor(task, size, MAX_SQUARE_SERIAL_SIZE);
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void Exp(const T *in, T *out, size_t size) {
|
|
||||||
if constexpr (std::is_same_v<T, float>) {
|
|
||||||
auto task = [&in, &out](size_t start, size_t end) { ExpFp32(in + start, out + start, end - start); };
|
|
||||||
ParallelLaunch(task, size, MAX_EXP_SERIAL_SIZE);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -66,7 +57,7 @@ void Neg(const T *in, T *out, size_t size) {
|
||||||
out[i] = -in[i];
|
out[i] = -in[i];
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ParallelLaunch(task, size, MAX_NEG_SERIAL_SIZE);
|
CPUKernelUtils::ParallelFor(task, size, MAX_NEG_SERIAL_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -271,7 +262,6 @@ void Identity(const T *in, T *out, size_t size) {
|
||||||
static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG},
|
static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG},
|
||||||
{prim::kPrimSquare->name(), SQUARE},
|
{prim::kPrimSquare->name(), SQUARE},
|
||||||
{prim::kPrimOnesLike->name(), ONESLIKE},
|
{prim::kPrimOnesLike->name(), ONESLIKE},
|
||||||
{prim::kPrimExp->name(), EXP},
|
|
||||||
{prim::kPrimZerosLike->name(), ZEROSLIKE},
|
{prim::kPrimZerosLike->name(), ZEROSLIKE},
|
||||||
{prim::kPrimLogicalNot->name(), LOGICALNOT},
|
{prim::kPrimLogicalNot->name(), LOGICALNOT},
|
||||||
{prim::kPrimSign->name(), SIGN},
|
{prim::kPrimSign->name(), SIGN},
|
||||||
|
@ -334,29 +324,17 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
|
||||||
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||||
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
|
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
|
||||||
static const std::map<OperateType, std::function<void(const T *in, T *out, size_t size)>> kArithmeticOpFuncMap = {
|
static const std::map<OperateType, std::function<void(const T *in, T *out, size_t size)>> kArithmeticOpFuncMap = {
|
||||||
{SQUARE, Square<T>},
|
{SQUARE, Square<T>}, {SIGN, Sign<T>},
|
||||||
{SIGN, Sign<T>},
|
{NEG, Neg<T>}, {LOGICALNOT, LogicalNot<T>},
|
||||||
{NEG, Neg<T>},
|
{ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>},
|
||||||
{LOGICALNOT, LogicalNot<T>},
|
{FLOOR, Floor<T>}, {RECIPROCAL, Reciprocal<T>},
|
||||||
{ONESLIKE, OnesLike<T>},
|
{GELU, Gelu<T>}, {SIN, Sin<T>},
|
||||||
{ZEROSLIKE, ZerosLike<T>},
|
{COS, Cos<T>}, {TAN, Tan<T>},
|
||||||
{FLOOR, Floor<T>},
|
{ASIN, Asin<T>}, {ACOS, ACos<T>},
|
||||||
{RECIPROCAL, Reciprocal<T>},
|
{ATAN, Atan<T>}, {SINH, Sinh<T>},
|
||||||
{GELU, Gelu<T>},
|
{COSH, Cosh<T>}, {ASINH, Asinh<T>},
|
||||||
{SIN, Sin<T>},
|
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>},
|
||||||
{COS, Cos<T>},
|
{RINT, Rint<T>}, {ROUND, Round<T>}};
|
||||||
{TAN, Tan<T>},
|
|
||||||
{ASIN, Asin<T>},
|
|
||||||
{ACOS, ACos<T>},
|
|
||||||
{ATAN, Atan<T>},
|
|
||||||
{SINH, Sinh<T>},
|
|
||||||
{COSH, Cosh<T>},
|
|
||||||
{ASINH, Asinh<T>},
|
|
||||||
{ACOSH, Acosh<T>},
|
|
||||||
{ATANH, Atanh<T>},
|
|
||||||
{RINT, Rint<T>},
|
|
||||||
{ROUND, Round<T>},
|
|
||||||
{EXP, Exp<T>}};
|
|
||||||
if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
|
if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
|
||||||
kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
|
kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -20,9 +20,8 @@
|
||||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||||
|
|
||||||
const float MAX_NEG_SERIAL_SIZE = 5000;
|
const float MAX_NEG_SERIAL_SIZE = 20000;
|
||||||
const float MAX_SQUARE_SERIAL_SIZE = 5000;
|
const float MAX_SQUARE_SERIAL_SIZE = 20000;
|
||||||
const float MAX_EXP_SERIAL_SIZE = 15000;
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
@ -59,10 +58,6 @@ class IdentityCPUKernel : public ArithmeticSelfCPUKernel {
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
ArithmeticSelfCPUKernel);
|
ArithmeticSelfCPUKernel);
|
||||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ArithmeticSelfCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
ArithmeticSelfCPUKernel);
|
|
||||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
ArithmeticSelfCPUKernel);
|
ArithmeticSelfCPUKernel);
|
||||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
|
|
@ -90,7 +90,7 @@ bool BiasAddCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
||||||
ElementAdd(src_addr + n_offset, bias_addr, output_addr + n_offset, input_shape_[1]);
|
ElementAdd(src_addr + n_offset, bias_addr, output_addr + n_offset, input_shape_[1]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, input_shape_[0], this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelForAutoSearch(task, input_shape_[0], ¶llel_search_info_);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,7 +55,7 @@ bool BiasAddGradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const s
|
||||||
auto task = [&](size_t start, size_t end) {
|
auto task = [&](size_t start, size_t end) {
|
||||||
ReduceSumDim2Axis0(end - start, input_shape_[1], input_shape_[0], input_addr + start, output_addr + start);
|
ReduceSumDim2Axis0(end - start, input_shape_[1], input_shape_[0], input_addr + start, output_addr + start);
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, input_shape_[1], this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelForAutoSearch(task, input_shape_[1], ¶llel_search_info_);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,7 @@ bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
ParallelLaunchAutoSearch(task, before_axis, this, ¶llel_search_info_);
|
CPUKernelUtils::ParallelForAutoSearch(task, before_axis, ¶llel_search_info_);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -138,77 +138,6 @@ void CPUKernelUtils::ParallelForAutoSearch(const CTask &task, size_t count, Para
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ActorThreadPool *GetActorMgrInnerThreadPool() {
|
|
||||||
auto actor_manager = ActorMgr::GetActorMgrRef();
|
|
||||||
auto thread_pool = actor_manager->GetActorThreadPool();
|
|
||||||
// Init thread_pool if env is windows or ascend, in case that it won't be init in graph_scheduler.
|
|
||||||
if (thread_pool == nullptr) {
|
|
||||||
const size_t kMaxThreadNum = 23;
|
|
||||||
size_t max_thread_num = std::thread::hardware_concurrency() - 1;
|
|
||||||
if (max_thread_num < 1) {
|
|
||||||
max_thread_num = 1;
|
|
||||||
}
|
|
||||||
max_thread_num = max_thread_num < kMaxThreadNum ? max_thread_num : kMaxThreadNum;
|
|
||||||
actor_manager->Initialize(true, 0, max_thread_num);
|
|
||||||
thread_pool = actor_manager->GetActorThreadPool();
|
|
||||||
MS_EXCEPTION_IF_NULL(thread_pool);
|
|
||||||
}
|
|
||||||
return thread_pool;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use threadpool of mindrt
|
|
||||||
void ParallelLaunch(const CTask &task, size_t count, float block_size, Content content) {
|
|
||||||
auto thread_pool = GetActorMgrInnerThreadPool();
|
|
||||||
size_t kernel_thread_num = thread_pool->GetKernelThreadNum();
|
|
||||||
if (kernel_thread_num == 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "Actor inner pool has been init, but kernel thread is 0!";
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t thread_num = count < block_size * kernel_thread_num ? std::ceil(count / block_size) : kernel_thread_num;
|
|
||||||
size_t once_compute_size = (count + thread_num - 1) / thread_num;
|
|
||||||
size_t task_num = count / once_compute_size;
|
|
||||||
if (count % once_compute_size != 0) {
|
|
||||||
task_num += 1;
|
|
||||||
}
|
|
||||||
auto func = [&](void *, int task_id, float, float) {
|
|
||||||
size_t start = task_id * once_compute_size;
|
|
||||||
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
|
|
||||||
task(start, end);
|
|
||||||
return common::SUCCESS;
|
|
||||||
};
|
|
||||||
thread_pool->ParallelLaunch(func, content, task_num);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ParallelLaunchAutoSearch(const CTask &task, size_t count, Content content,
|
|
||||||
ParallelSearchInfo *parallel_search_info) {
|
|
||||||
const size_t MAX_POW = 6;
|
|
||||||
const size_t AVG_COUNT = 5;
|
|
||||||
size_t current_pow = parallel_search_info->search_count / AVG_COUNT;
|
|
||||||
if (current_pow < MAX_POW) {
|
|
||||||
if (parallel_search_info->search_count % AVG_COUNT == 0) {
|
|
||||||
parallel_search_info->tmp_sum_cost_time = 0;
|
|
||||||
}
|
|
||||||
float block_size = static_cast<float>(count) / std::pow(2.0f, current_pow);
|
|
||||||
double start_time = GetTime();
|
|
||||||
ParallelLaunch(task, count, block_size, content);
|
|
||||||
double cost_time = GetTime() - start_time;
|
|
||||||
parallel_search_info->tmp_sum_cost_time += cost_time;
|
|
||||||
parallel_search_info->search_count++;
|
|
||||||
if (parallel_search_info->search_count % AVG_COUNT == 0) {
|
|
||||||
double avg_time = parallel_search_info->tmp_sum_cost_time / AVG_COUNT;
|
|
||||||
if (parallel_search_info->min_cost_time > avg_time) {
|
|
||||||
parallel_search_info->min_cost_time = avg_time;
|
|
||||||
parallel_search_info->best_block_size = block_size;
|
|
||||||
parallel_search_info->best_pow = current_pow;
|
|
||||||
} else if (current_pow - parallel_search_info->best_pow >= 2) {
|
|
||||||
parallel_search_info->search_count = AVG_COUNT * MAX_POW;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ParallelLaunch(task, count, parallel_search_info->best_block_size, content);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) {
|
std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) {
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis = axis + SizeToInt(shape.size());
|
axis = axis + SizeToInt(shape.size());
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -25,8 +25,6 @@
|
||||||
#include "backend/session/anf_runtime_algorithm.h"
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "runtime/framework/graph_scheduler.h"
|
|
||||||
#include "actor/actormgr.h"
|
|
||||||
|
|
||||||
using mindspore::kernel::Address;
|
using mindspore::kernel::Address;
|
||||||
using mindspore::kernel::AddressPtr;
|
using mindspore::kernel::AddressPtr;
|
||||||
|
@ -64,7 +62,6 @@ const char DELTA[] = "delta";
|
||||||
const char SORTED[] = "sorted";
|
const char SORTED[] = "sorted";
|
||||||
const char ADJ_ST[] = "adjoint_st";
|
const char ADJ_ST[] = "adjoint_st";
|
||||||
const char ADJ_dT[] = "adjoint_dt";
|
const char ADJ_dT[] = "adjoint_dt";
|
||||||
const char PERIODS[] = "periods";
|
|
||||||
|
|
||||||
enum OperateType {
|
enum OperateType {
|
||||||
ADD = 0,
|
ADD = 0,
|
||||||
|
@ -122,7 +119,6 @@ enum OperateType {
|
||||||
ATAN2,
|
ATAN2,
|
||||||
RINT,
|
RINT,
|
||||||
ROUND,
|
ROUND,
|
||||||
EXP,
|
|
||||||
IDENTITY,
|
IDENTITY,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -156,19 +152,6 @@ class CPUKernel : public kernel::KernelMod {
|
||||||
std::vector<size_t> output_size_list_;
|
std::vector<size_t> output_size_list_;
|
||||||
std::vector<size_t> workspace_size_list_;
|
std::vector<size_t> workspace_size_list_;
|
||||||
ParallelSearchInfo parallel_search_info_;
|
ParallelSearchInfo parallel_search_info_;
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
|
|
||||||
if (index >= addr_list.size()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((addr_list[index] == nullptr) || (addr_list[index]->addr == nullptr) || (addr_list[index]->size == 0)) {
|
|
||||||
MS_LOG(EXCEPTION) << "The device address is empty, address index: " << index;
|
|
||||||
}
|
|
||||||
|
|
||||||
return reinterpret_cast<T *>(addr_list[index]->addr);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class CPUKernelUtils {
|
class CPUKernelUtils {
|
||||||
|
@ -226,12 +209,6 @@ class TransposeIterator {
|
||||||
std::vector<size_t> axes_;
|
std::vector<size_t> axes_;
|
||||||
size_t pos_{0};
|
size_t pos_{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
ActorThreadPool *GetActorMgrInnerThreadPool();
|
|
||||||
void ParallelLaunch(const CTask &task, size_t count, float block_size = 128.0, Content content = nullptr);
|
|
||||||
void ParallelLaunchAutoSearch(const CTask &task, size_t count, Content content,
|
|
||||||
ParallelSearchInfo *parallel_search_info);
|
|
||||||
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "runtime/device/kernel_info.h"
|
#include "runtime/device/kernel_info.h"
|
||||||
#include "runtime/device/cpu/kernel_select_cpu.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
@ -112,11 +111,6 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
|
||||||
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
||||||
return std::make_pair(false, 0);
|
return std::make_pair(false, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (device::cpu::IsDynamicParamKernel(kernel_name)) {
|
|
||||||
return std::make_pair(true, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
|
auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
|
||||||
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
|
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
|
||||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
||||||
|
|
|
@ -144,7 +144,8 @@ bool CropAndResizeCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &in
|
||||||
const int bottom_y_index = ceilf(target_y);
|
const int bottom_y_index = ceilf(target_y);
|
||||||
const int left_x_index = floorf(target_x);
|
const int left_x_index = floorf(target_x);
|
||||||
const int right_x_index = ceilf(target_x);
|
const int right_x_index = ceilf(target_x);
|
||||||
|
const float y_lerp = target_y - top_y_index;
|
||||||
|
const float x_lerp = target_x - left_x_index;
|
||||||
const float top_left = static_cast<float>(
|
const float top_left = static_cast<float>(
|
||||||
input_image[((box_index * input_height_ + top_y_index) * input_width_ + left_x_index) * channel_ +
|
input_image[((box_index * input_height_ + top_y_index) * input_width_ + left_x_index) * channel_ +
|
||||||
pos_channel]);
|
pos_channel]);
|
||||||
|
@ -157,9 +158,9 @@ bool CropAndResizeCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &in
|
||||||
const float bottom_right = static_cast<float>(
|
const float bottom_right = static_cast<float>(
|
||||||
input_image[((box_index * input_height_ + bottom_y_index) * input_width_ + right_x_index) * channel_ +
|
input_image[((box_index * input_height_ + bottom_y_index) * input_width_ + right_x_index) * channel_ +
|
||||||
pos_channel]);
|
pos_channel]);
|
||||||
const float top = top_left + (top_right - top_left) * (target_x - left_x_index);
|
const float top = top_left + (top_right - top_left) * x_lerp;
|
||||||
const float bottom = bottom_left + (bottom_right - bottom_left) * (target_x - left_x_index);
|
const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
|
||||||
output[pos] = top + (bottom - top) * (target_y - top_y_index);
|
output[pos] = top + (bottom - top) * y_lerp;
|
||||||
} else if (method_ == 3) {
|
} else if (method_ == 3) {
|
||||||
int y1h = static_cast<int>(y1 * input_height_);
|
int y1h = static_cast<int>(y1 * input_height_);
|
||||||
int x1w = static_cast<int>(x1 * input_width_);
|
int x1w = static_cast<int>(x1 * input_width_);
|
||||||
|
@ -169,37 +170,36 @@ bool CropAndResizeCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &in
|
||||||
int h = ((y2h - y1h + 1) > 1) ? y2h - y1h + 1 : 1;
|
int h = ((y2h - y1h + 1) > 1) ? y2h - y1h + 1 : 1;
|
||||||
|
|
||||||
float y_point = (pos_y + 0.5) * (h / static_cast<float>(final_height_)) - 0.5;
|
float y_point = (pos_y + 0.5) * (h / static_cast<float>(final_height_)) - 0.5;
|
||||||
int top_y_index = std::min(std::max(0, static_cast<int>(floorf(y_point))), h - 1);
|
int top_y_index = floorf(y_point);
|
||||||
int bottom_y_index = std::min(std::max(0, static_cast<int>(ceilf(y_point))), h - 1);
|
top_y_index = std::min(std::max(0, top_y_index), h - 1);
|
||||||
|
|
||||||
|
int bottom_y_index = ceilf(y_point);
|
||||||
|
bottom_y_index = std::min(std::max(0, bottom_y_index), h - 1);
|
||||||
|
|
||||||
float x_point = (pos_x + 0.5) * (w / static_cast<float>(final_width_)) - 0.5;
|
float x_point = (pos_x + 0.5) * (w / static_cast<float>(final_width_)) - 0.5;
|
||||||
int left_x_index = std::min(std::max(0, static_cast<int>(floorf(x_point))), w - 1);
|
int left_x_index = floorf(x_point);
|
||||||
int right_x_index = std::min(std::max(0, static_cast<int>(ceilf(x_point))), w - 1);
|
left_x_index = std::min(std::max(0, left_x_index), w - 1);
|
||||||
|
|
||||||
|
int right_x_index = ceilf(x_point);
|
||||||
|
right_x_index = std::min(std::max(0, right_x_index), w - 1);
|
||||||
|
|
||||||
const float y_lerp = y_point - top_y_index;
|
const float y_lerp = y_point - top_y_index;
|
||||||
const float x_lerp = x_point - left_x_index;
|
const float x_lerp = x_point - left_x_index;
|
||||||
|
const int y_top_index = box_index * input_height_ + y1h + top_y_index;
|
||||||
|
const int y_bottom_index = box_index * input_height_ + y1h + bottom_y_index;
|
||||||
|
|
||||||
const int y_top_index = std::max(0, y1h + top_y_index);
|
const float top_left =
|
||||||
const int y_bottom_index = std::max(0, y1h + bottom_y_index);
|
static_cast<float>(input_image[(y_top_index * input_width_ + x1w + left_x_index) * channel_ + pos_channel]);
|
||||||
const int x_left_index = std::max(0, x1w + left_x_index);
|
const float top_right =
|
||||||
const int x_right_index = std::max(0, x1w + right_x_index);
|
static_cast<float>(input_image[(y_top_index * input_width_ + x1w + right_x_index) * channel_ + pos_channel]);
|
||||||
|
|
||||||
const float top_left = static_cast<float>(
|
|
||||||
input_image[((box_index * input_height_ + y_top_index) * input_width_ + x_left_index) * channel_ +
|
|
||||||
pos_channel]);
|
|
||||||
const float top_right = static_cast<float>(
|
|
||||||
input_image[((box_index * input_height_ + y_top_index) * input_width_ + x_right_index) * channel_ +
|
|
||||||
pos_channel]);
|
|
||||||
const float bottom_left = static_cast<float>(
|
const float bottom_left = static_cast<float>(
|
||||||
input_image[((box_index * input_height_ + y_bottom_index) * input_width_ + x_left_index) * channel_ +
|
input_image[(y_bottom_index * input_width_ + x1w + left_x_index) * channel_ + pos_channel]);
|
||||||
pos_channel]);
|
|
||||||
const float bottom_right = static_cast<float>(
|
const float bottom_right = static_cast<float>(
|
||||||
input_image[((box_index * input_height_ + y_bottom_index) * input_width_ + x_right_index) * channel_ +
|
input_image[(y_bottom_index * input_width_ + x1w + right_x_index) * channel_ + pos_channel]);
|
||||||
pos_channel]);
|
|
||||||
|
|
||||||
output[pos] = top_left * (1 - y_lerp) * (1 - x_lerp) + bottom_right * y_lerp * x_lerp +
|
float ret = top_left * (1 - y_lerp) * (1 - x_lerp) + bottom_right * y_lerp * x_lerp +
|
||||||
top_right * (1 - y_lerp) * x_lerp + bottom_left * y_lerp * (1 - x_lerp);
|
top_right * (1 - y_lerp) * x_lerp + bottom_left * y_lerp * (1 - x_lerp);
|
||||||
|
output[pos] = ret;
|
||||||
} else {
|
} else {
|
||||||
// Nearest Neighbour
|
// Nearest Neighbour
|
||||||
const int closest_x_index = roundf(target_x);
|
const int closest_x_index = roundf(target_x);
|
||||||
|
|
|
@ -35,14 +35,15 @@ class CropAndResizeCPUKernel : public CPUKernel {
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int method_{1};
|
int method_;
|
||||||
float extrapolation_value_{0.0};
|
float extrapolation_value_;
|
||||||
int output_size_{0};
|
int input_crop_size_;
|
||||||
int input_height_{0};
|
int output_size_;
|
||||||
int input_width_{0};
|
int input_height_;
|
||||||
int final_height_{0};
|
int input_width_;
|
||||||
int final_width_{0};
|
int final_height_;
|
||||||
int channel_{0};
|
int final_width_;
|
||||||
|
int channel_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||||
|
|
|
@ -43,9 +43,9 @@ void DropoutGradCpuBwdKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
bool DropoutGradCpuBwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool DropoutGradCpuBwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
if (dtype_ == kNumberTypeFloat16) {
|
if (dtype_ == kNumberTypeFloat16) {
|
||||||
DropoutBackwardKernel<float16>(inputs, outputs, keep_prob_);
|
DropoutBackwardKernel<float16>(inputs, outputs, num_count_, keep_prob_);
|
||||||
} else if (dtype_ == kNumberTypeFloat32) {
|
} else if (dtype_ == kNumberTypeFloat32) {
|
||||||
DropoutBackwardKernel<float>(inputs, outputs, keep_prob_);
|
DropoutBackwardKernel<float>(inputs, outputs, num_count_, keep_prob_);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Input data type: " << dtype_ << " is not supported for DropoutGrad kernel for CPU.";
|
MS_LOG(ERROR) << "Input data type: " << dtype_ << " is not supported for DropoutGrad kernel for CPU.";
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,8 @@ bool DropoutGradCpuBwdKernel::Launch(const std::vector<AddressPtr> &inputs, cons
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void DropoutGradCpuBwdKernel::DropoutBackwardKernel(const std::vector<AddressPtr> &inputs,
|
void DropoutGradCpuBwdKernel::DropoutBackwardKernel(const std::vector<AddressPtr> &inputs,
|
||||||
const std::vector<AddressPtr> &outputs, float keep_prob) {
|
const std::vector<AddressPtr> &outputs, size_t num_count,
|
||||||
|
float keep_prob) {
|
||||||
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||||
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
|
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||||
const auto *mask = reinterpret_cast<T *>(inputs[1]->addr);
|
const auto *mask = reinterpret_cast<T *>(inputs[1]->addr);
|
||||||
|
@ -69,7 +70,7 @@ void DropoutGradCpuBwdKernel::DropoutBackwardKernel(const std::vector<AddressPtr
|
||||||
input_tmp[i] = static_cast<float>(input[i]);
|
input_tmp[i] = static_cast<float>(input[i]);
|
||||||
mask_tmp[i] = static_cast<float>(mask[i]);
|
mask_tmp[i] = static_cast<float>(mask[i]);
|
||||||
}
|
}
|
||||||
DropoutGrad(input_tmp, mask_tmp, output_tmp, SizeToInt(num_count_), scale);
|
DropoutGrad(input_tmp, mask_tmp, output_tmp, num_count_, scale);
|
||||||
for (size_t i = 0; i < num_count_; ++i) {
|
for (size_t i = 0; i < num_count_; ++i) {
|
||||||
output[i] = static_cast<float16>(output_tmp[i]);
|
output[i] = static_cast<float16>(output_tmp[i]);
|
||||||
}
|
}
|
||||||
|
@ -77,7 +78,7 @@ void DropoutGradCpuBwdKernel::DropoutBackwardKernel(const std::vector<AddressPtr
|
||||||
delete[] output_tmp;
|
delete[] output_tmp;
|
||||||
delete[] mask_tmp;
|
delete[] mask_tmp;
|
||||||
} else if constexpr (std::is_same_v<T, float>) {
|
} else if constexpr (std::is_same_v<T, float>) {
|
||||||
DropoutGrad(input, mask, output, SizeToInt(num_count_), scale);
|
DropoutGrad(input, mask, output, num_count_, scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
|
|
|
@ -40,7 +40,7 @@ class DropoutGradCpuBwdKernel : public CPUKernel {
|
||||||
TypeId dtype_{kTypeUnknown};
|
TypeId dtype_{kTypeUnknown};
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void DropoutBackwardKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
|
void DropoutBackwardKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
|
||||||
float keep_prob);
|
size_t num_count, float keep_prob);
|
||||||
};
|
};
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL(DropoutGrad, KernelAttr(), DropoutGradCpuBwdKernel);
|
MS_REG_CPU_KERNEL(DropoutGrad, KernelAttr(), DropoutGradCpuBwdKernel);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,10 +13,8 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h"
|
|
||||||
#include <string>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include "backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h"
|
||||||
#include "common/thread_pool.h"
|
#include "common/thread_pool.h"
|
||||||
#include "runtime/device/cpu/cpu_device_address.h"
|
#include "runtime/device/cpu/cpu_device_address.h"
|
||||||
#include "nnacl/fp32_grad/activation_grad.h"
|
#include "nnacl/fp32_grad/activation_grad.h"
|
||||||
|
@ -27,49 +25,49 @@ namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EltWiseGradCPUKernel<T>::ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
void EltWiseGradCPUKernel<T>::ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
||||||
if constexpr (!std::is_same<T, float>::value) {
|
if constexpr (std::is_same_v<T, float>) {
|
||||||
MS_LOG(EXCEPTION) << "ReLUGrad only support float";
|
|
||||||
}
|
|
||||||
|
|
||||||
int ret = ::ReluGrad(input1 + start, input2 + start, end - start, out + start);
|
int ret = ::ReluGrad(input1 + start, input2 + start, end - start, out + start);
|
||||||
if (ret == NNACL_ERR) {
|
if (ret == NNACL_ERR) {
|
||||||
MS_LOG(EXCEPTION) << "ReLUGrad execute failed.";
|
MS_LOG(EXCEPTION) << "ReLUGrad failed.";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "ReLUGrad only support float";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EltWiseGradCPUKernel<T>::ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
void EltWiseGradCPUKernel<T>::ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
||||||
if constexpr (!std::is_same<T, float>::value) {
|
if constexpr (std::is_same_v<T, float>) {
|
||||||
MS_LOG(EXCEPTION) << "ReLU6Grad only support float";
|
|
||||||
}
|
|
||||||
|
|
||||||
int ret = ::Relu6Grad(input1 + start, input2 + start, end - start, out + start);
|
int ret = ::Relu6Grad(input1 + start, input2 + start, end - start, out + start);
|
||||||
if (ret == NNACL_ERR) {
|
if (ret == NNACL_ERR) {
|
||||||
MS_LOG(EXCEPTION) << "ReLU6Grad execute failed.";
|
MS_LOG(EXCEPTION) << "ReLU6Grad failed.";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "ReLU6Grad only support float";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EltWiseGradCPUKernel<T>::AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
void EltWiseGradCPUKernel<T>::AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
||||||
if constexpr (!std::is_same<T, float>::value) {
|
if constexpr (std::is_same_v<T, float>) {
|
||||||
MS_LOG(EXCEPTION) << "AbsGrad only support float";
|
|
||||||
}
|
|
||||||
|
|
||||||
int ret = ::ElementAbsGrad(input1 + start, input2 + start, out + start, end - start);
|
int ret = ::ElementAbsGrad(input1 + start, input2 + start, out + start, end - start);
|
||||||
if (ret == NNACL_ERR) {
|
if (ret == NNACL_ERR) {
|
||||||
MS_LOG(EXCEPTION) << "AbsGrad execute failed.";
|
MS_LOG(EXCEPTION) << "AbsGrad failed.";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "AbsGrad only support float";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EltWiseGradCPUKernel<T>::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
void EltWiseGradCPUKernel<T>::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
||||||
if constexpr (!std::is_same<T, float>::value) {
|
if constexpr (std::is_same_v<T, float>) {
|
||||||
MS_LOG(EXCEPTION) << "SigmoidGrad only support float";
|
|
||||||
}
|
|
||||||
|
|
||||||
int ret = ::SigmoidGrad(input2 + start, input1 + start, end - start, out + start);
|
int ret = ::SigmoidGrad(input2 + start, input1 + start, end - start, out + start);
|
||||||
if (ret == NNACL_ERR) {
|
if (ret == NNACL_ERR) {
|
||||||
MS_LOG(EXCEPTION) << "SigmoidGrad execute failed.";
|
MS_LOG(EXCEPTION) << "SigmoidGrad failed.";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "SigmoidGrad only support float";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,13 +80,13 @@ void EltWiseGradCPUKernel<T>::SqrtGrad(const T *input1, const T *input2, T *out,
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EltWiseGradCPUKernel<T>::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
void EltWiseGradCPUKernel<T>::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
||||||
if constexpr (!std::is_same<T, float>::value) {
|
if constexpr (std::is_same_v<T, float>) {
|
||||||
MS_LOG(EXCEPTION) << "TanhGrad only support float";
|
|
||||||
}
|
|
||||||
|
|
||||||
int ret = ::TanhGrad(input2 + start, input1 + start, end - start, out + start);
|
int ret = ::TanhGrad(input2 + start, input1 + start, end - start, out + start);
|
||||||
if (ret == NNACL_ERR) {
|
if (ret == NNACL_ERR) {
|
||||||
MS_LOG(EXCEPTION) << "TanhGrad execute failed.";
|
MS_LOG(EXCEPTION) << "TanhGrad failed.";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "TanhGrad only support float";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,18 +207,6 @@ void EltWiseGradCPUKernel<T>::AcoshGrad(const T *input1, const T *input2, T *out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void EltWiseGradCPUKernel<T>::SoftplusGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
|
|
||||||
if constexpr (!std::is_same<T, float>::value) {
|
|
||||||
MS_LOG(EXCEPTION) << "SoftplusGrad only support float";
|
|
||||||
}
|
|
||||||
|
|
||||||
int ret = ::SoftplusGrad(input1 + start, input2 + start, end - start, out + start);
|
|
||||||
if (ret == NNACL_ERR) {
|
|
||||||
MS_LOG(EXCEPTION) << "SoftplusGrad execute failed.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void EltWiseGradCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
void EltWiseGradCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
|
@ -233,19 +219,12 @@ bool EltWiseGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inpu
|
||||||
const std::vector<kernel::AddressPtr> &outputs) {
|
const std::vector<kernel::AddressPtr> &outputs) {
|
||||||
static const std::map<std::string,
|
static const std::map<std::string,
|
||||||
std::function<void(EltWiseGradCPUKernel *, const T *, const T *, T *, size_t, size_t)>>
|
std::function<void(EltWiseGradCPUKernel *, const T *, const T *, T *, size_t, size_t)>>
|
||||||
elt_map{{prim::kPrimReluGrad->name(), &EltWiseGradCPUKernel<T>::ReluGrad},
|
elt_map{{"ReluGrad", &EltWiseGradCPUKernel<T>::ReluGrad}, {"ReLU6Grad", &EltWiseGradCPUKernel<T>::ReLU6Grad},
|
||||||
{prim::kPrimRelu6Grad->name(), &EltWiseGradCPUKernel<T>::ReLU6Grad},
|
{"SigmoidGrad", &EltWiseGradCPUKernel<T>::SigmoidGrad}, {"AbsGrad", &EltWiseGradCPUKernel<T>::AbsGrad},
|
||||||
{prim::kPrimSigmoidGrad->name(), &EltWiseGradCPUKernel<T>::SigmoidGrad},
|
{"TanhGrad", &EltWiseGradCPUKernel<T>::TanhGrad}, {"SqrtGrad", &EltWiseGradCPUKernel<T>::SqrtGrad},
|
||||||
{prim::kPrimAbsGrad->name(), &EltWiseGradCPUKernel<T>::AbsGrad},
|
{"GeLUGrad", &EltWiseGradCPUKernel<T>::GeluGrad}, {"AsinGrad", &EltWiseGradCPUKernel<T>::AsinGrad},
|
||||||
{prim::kPrimTanhGrad->name(), &EltWiseGradCPUKernel<T>::TanhGrad},
|
{"ACosGrad", &EltWiseGradCPUKernel<T>::ACosGrad}, {"AtanGrad", &EltWiseGradCPUKernel<T>::AtanGrad},
|
||||||
{prim::kPrimSqrtGrad->name(), &EltWiseGradCPUKernel<T>::SqrtGrad},
|
{"AsinhGrad", &EltWiseGradCPUKernel<T>::AsinhGrad}, {"AcoshGrad", &EltWiseGradCPUKernel<T>::AcoshGrad}};
|
||||||
{prim::kPrimGeLUGrad->name(), &EltWiseGradCPUKernel<T>::GeluGrad},
|
|
||||||
{prim::kPrimAsinGrad->name(), &EltWiseGradCPUKernel<T>::AsinGrad},
|
|
||||||
{prim::kPrimACosGrad->name(), &EltWiseGradCPUKernel<T>::ACosGrad},
|
|
||||||
{prim::kPrimAtanGrad->name(), &EltWiseGradCPUKernel<T>::AtanGrad},
|
|
||||||
{prim::kPrimAsinhGrad->name(), &EltWiseGradCPUKernel<T>::AsinhGrad},
|
|
||||||
{prim::kPrimAcoshGrad->name(), &EltWiseGradCPUKernel<T>::AcoshGrad},
|
|
||||||
{prim::kPrimSoftplusGrad->name(), &EltWiseGradCPUKernel<T>::SoftplusGrad}};
|
|
||||||
if (inputs.size() < 2 || outputs.size() != 1) {
|
if (inputs.size() < 2 || outputs.size() != 1) {
|
||||||
MS_LOG(ERROR) << kernel_name_ << " requires at least 2 inputs and 1 output, but got " << inputs.size()
|
MS_LOG(ERROR) << kernel_name_ << " requires at least 2 inputs and 1 output, but got " << inputs.size()
|
||||||
<< " inputs and " << outputs.size() << " output.";
|
<< " inputs and " << outputs.size() << " output.";
|
||||||
|
@ -259,9 +238,9 @@ bool EltWiseGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inpu
|
||||||
const auto input1 = reinterpret_cast<T *>(inputs[1]->addr);
|
const auto input1 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||||
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||||
|
|
||||||
ParallelLaunchAutoSearch(
|
CPUKernelUtils::ParallelForAutoSearch(
|
||||||
std::bind(elt_map.at(kernel_name_), this, input0, input1, output, std::placeholders::_1, std::placeholders::_2),
|
std::bind(elt_map.at(kernel_name_), this, input0, input1, output, std::placeholders::_1, std::placeholders::_2),
|
||||||
outputs[0]->size / sizeof(T), this, ¶llel_search_info_);
|
outputs[0]->size / sizeof(T), ¶llel_search_info_);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -48,7 +48,6 @@ class EltWiseGradCPUKernel : public CPUKernel {
|
||||||
void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
||||||
void SoftplusGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
|
|
||||||
|
|
||||||
std::string kernel_name_ = "";
|
std::string kernel_name_ = "";
|
||||||
};
|
};
|
||||||
|
@ -104,10 +103,6 @@ MS_REG_CPU_KERNEL_T(
|
||||||
AcoshGrad,
|
AcoshGrad,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||||
EltWiseGradCPUKernel, float);
|
EltWiseGradCPUKernel, float);
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
SoftplusGrad,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
|
||||||
EltWiseGradCPUKernel, float);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// The duration between two PullWeights requests when return code is ResponseCode_SucNotReady.
|
// The duration between two downloading requests when return code is ResponseCode_SucNotReady.
|
||||||
constexpr int kRetryDurationOfPullWeights = 200;
|
constexpr int kRetryDurationOfPullWeights = 200;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class FusedPullWeightKernel : public CPUKernel {
|
class FusedPullWeightKernel : public CPUKernel {
|
||||||
|
@ -51,17 +51,19 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
MS_EXCEPTION_IF_NULL(fbb);
|
MS_EXCEPTION_IF_NULL(fbb);
|
||||||
|
|
||||||
total_iteration_++;
|
total_iteration_++;
|
||||||
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
|
|
||||||
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
||||||
MS_LOG(INFO) << "Try to pull weights. Local step number: " << total_iteration_
|
if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
|
||||||
<< ", step number needs to run per iteration: " << step_num_per_iteration;
|
fl::kTrainBeginStepNum) {
|
||||||
if (step_num_per_iteration != fl::kOneStepPerIteration &&
|
|
||||||
total_iteration_ % step_num_per_iteration != fl::kTrainBeginStepNum) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
fl_iteration_++;
|
fl_iteration_++;
|
||||||
MS_LOG(INFO) << "Launching pulling weight for federated learning iteration " << fl_iteration_;
|
if (fl_iteration_ > ps::PSContext::instance()->fl_iteration_num()) {
|
||||||
|
MS_LOG(INFO) << ps::PSContext::instance()->fl_iteration_num() << " iterations are completed.";
|
||||||
|
fl_iteration_ = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Start pulling weight for federated learning iteration " << fl_iteration_;
|
||||||
if (!BuildPullWeightReq(fbb)) {
|
if (!BuildPullWeightReq(fbb)) {
|
||||||
MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed.";
|
MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -71,16 +73,11 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
|
const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
|
||||||
int retcode = schema::ResponseCode_SucNotReady;
|
int retcode = schema::ResponseCode_SucNotReady;
|
||||||
while (retcode == schema::ResponseCode_SucNotReady) {
|
while (retcode == schema::ResponseCode_SucNotReady) {
|
||||||
if (!fl::worker::FLWorker::GetInstance().running()) {
|
|
||||||
MS_LOG(WARNING) << "Worker has finished.";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (!fl::worker::FLWorker::GetInstance().SendToServer(
|
if (!fl::worker::FLWorker::GetInstance().SendToServer(
|
||||||
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
|
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
|
||||||
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. Retry later.";
|
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped.";
|
||||||
retcode = schema::ResponseCode_SucNotReady;
|
fl::worker::FLWorker::GetInstance().SetIterationRunning();
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights));
|
return true;
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
|
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
|
||||||
|
|
||||||
|
@ -91,8 +88,6 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
fl_iteration_ = pull_weight_rsp->iteration();
|
fl_iteration_ = pull_weight_rsp->iteration();
|
||||||
MS_LOG(DEBUG) << "Server is not ready for downloading yet. Reason: " << pull_weight_rsp->reason()->str()
|
MS_LOG(DEBUG) << "Server is not ready for downloading yet. Reason: " << pull_weight_rsp->reason()->str()
|
||||||
<< ". Retry later.";
|
<< ". Retry later.";
|
||||||
// Recreate fbb to avoid memory leak of FlatBuffers.
|
|
||||||
fbb = std::make_shared<fl::FBBuilder>();
|
|
||||||
if (!BuildPullWeightReq(fbb)) {
|
if (!BuildPullWeightReq(fbb)) {
|
||||||
MS_LOG(EXCEPTION) << "Building request for FusedDownloadWeightsByKeys failed.";
|
MS_LOG(EXCEPTION) << "Building request for FusedDownloadWeightsByKeys failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -121,7 +116,7 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " success. Iteration: " << fl_iteration_;
|
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
|
||||||
fl::worker::FLWorker::GetInstance().SetIterationRunning();
|
fl::worker::FLWorker::GetInstance().SetIterationRunning();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// The duration between two PushWeights requests when return code is ResponseCode_SucNotReady.
|
// The duration between two uploading requests when return code is ResponseCode_SucNotReady.
|
||||||
constexpr int kRetryDurationOfPushWeights = 200;
|
constexpr int kRetryDurationOfPushWeights = 200;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class FusedPushWeightKernel : public CPUKernel {
|
class FusedPushWeightKernel : public CPUKernel {
|
||||||
|
@ -49,17 +49,19 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
MS_EXCEPTION_IF_NULL(fbb);
|
MS_EXCEPTION_IF_NULL(fbb);
|
||||||
|
|
||||||
total_iteration_++;
|
total_iteration_++;
|
||||||
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
|
|
||||||
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
||||||
MS_LOG(INFO) << "Try to push weights. Local step number: " << total_iteration_
|
if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
|
||||||
<< ", step number needs to run per iteration: " << step_num_per_iteration;
|
fl::kTrainBeginStepNum) {
|
||||||
if (step_num_per_iteration != fl::kOneStepPerIteration &&
|
|
||||||
total_iteration_ % step_num_per_iteration != fl::kTrainEndStepNum) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
fl_iteration_++;
|
fl_iteration_++;
|
||||||
MS_LOG(INFO) << "Launching pushing weight for federated learning iteration " << fl_iteration_;
|
if (fl_iteration_ > ps::PSContext::instance()->fl_iteration_num()) {
|
||||||
|
MS_LOG(INFO) << ps::PSContext::instance()->fl_iteration_num() << " iterations are completed.";
|
||||||
|
fl_iteration_ = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Start pushing weight for federated learning iteration " << fl_iteration_;
|
||||||
if (!BuildPushWeightReq(fbb, inputs)) {
|
if (!BuildPushWeightReq(fbb, inputs)) {
|
||||||
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
|
MS_LOG(EXCEPTION) << "Building request for FusedPushWeight failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -71,17 +73,13 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
const schema::ResponsePushWeight *push_weight_rsp = nullptr;
|
const schema::ResponsePushWeight *push_weight_rsp = nullptr;
|
||||||
int retcode = schema::ResponseCode_SucNotReady;
|
int retcode = schema::ResponseCode_SucNotReady;
|
||||||
while (retcode == schema::ResponseCode_SucNotReady) {
|
while (retcode == schema::ResponseCode_SucNotReady) {
|
||||||
if (!fl::worker::FLWorker::GetInstance().running()) {
|
|
||||||
MS_LOG(WARNING) << "Worker has finished.";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
|
if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
|
||||||
ps::core::TcpUserCommand::kPushWeight,
|
ps::core::TcpUserCommand::kPushWeight,
|
||||||
&push_weight_rsp_msg)) {
|
&push_weight_rsp_msg)) {
|
||||||
MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i << " failed.";
|
MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i
|
||||||
retcode = schema::ResponseCode_SucNotReady;
|
<< " failed. This iteration is dropped.";
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushWeights));
|
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
|
||||||
continue;
|
return true;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
|
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
|
||||||
|
|
||||||
|
@ -107,7 +105,8 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " success. Iteration: " << fl_iteration_;
|
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
|
||||||
|
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -52,26 +52,6 @@ MS_REG_CPU_KERNEL_T(
|
||||||
MaskedSelect,
|
MaskedSelect,
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
|
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
|
||||||
MaskedSelectCPUKernel, int);
|
MaskedSelectCPUKernel, int);
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
MaskedSelect,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16),
|
|
||||||
MaskedSelectCPUKernel, int16_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
MaskedSelect,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
|
|
||||||
MaskedSelectCPUKernel, int64_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
MaskedSelect,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
MaskedSelectCPUKernel, float16);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(
|
|
||||||
MaskedSelect,
|
|
||||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
MaskedSelectCPUKernel, double);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_CPU_KERNEL_H_
|
||||||
|
|
|
@ -58,38 +58,6 @@ MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
||||||
.AddInputAttr(kNumberTypeInt32)
|
.AddInputAttr(kNumberTypeInt32)
|
||||||
.AddOutputAttr(kNumberTypeInt32),
|
.AddOutputAttr(kNumberTypeInt32),
|
||||||
MaskedSelectGradCPUKernel, int);
|
MaskedSelectGradCPUKernel, int);
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddInputAttr(kNumberTypeBool)
|
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat16),
|
|
||||||
MaskedSelectGradCPUKernel, float16);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
|
||||||
.AddInputAttr(kNumberTypeBool)
|
|
||||||
.AddInputAttr(kNumberTypeFloat64)
|
|
||||||
.AddOutputAttr(kNumberTypeFloat64),
|
|
||||||
MaskedSelectGradCPUKernel, double);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt16)
|
|
||||||
.AddInputAttr(kNumberTypeBool)
|
|
||||||
.AddInputAttr(kNumberTypeInt16)
|
|
||||||
.AddOutputAttr(kNumberTypeInt16),
|
|
||||||
MaskedSelectGradCPUKernel, int16_t);
|
|
||||||
|
|
||||||
MS_REG_CPU_KERNEL_T(MaskedSelectGrad,
|
|
||||||
KernelAttr()
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddInputAttr(kNumberTypeBool)
|
|
||||||
.AddInputAttr(kNumberTypeInt64)
|
|
||||||
.AddOutputAttr(kNumberTypeInt64),
|
|
||||||
MaskedSelectGradCPUKernel, int64_t);
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MASKED_SELECTED_GRAD_CPU_KERNEL_H_
|
||||||
|
|
|
@ -86,8 +86,6 @@ bool MirrorPadCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, c
|
||||||
LaunchKernel<float16>(inputs, outputs);
|
LaunchKernel<float16>(inputs, outputs);
|
||||||
} else if (dtype_ == kNumberTypeFloat32) {
|
} else if (dtype_ == kNumberTypeFloat32) {
|
||||||
LaunchKernel<float>(inputs, outputs);
|
LaunchKernel<float>(inputs, outputs);
|
||||||
} else if (dtype_ == kNumberTypeFloat64) {
|
|
||||||
LaunchKernel<double>(inputs, outputs);
|
|
||||||
} else if (dtype_ == kNumberTypeInt32) {
|
} else if (dtype_ == kNumberTypeInt32) {
|
||||||
LaunchKernel<int>(inputs, outputs);
|
LaunchKernel<int>(inputs, outputs);
|
||||||
} else {
|
} else {
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue