forked from mindspore-Ecosystem/mindspore
!35051 [MS][Lite][Task] add kernel graph related files to cloud infer framework
Merge pull request !35051 from 刘力力/feature_cloud_infer_runtime_develop_kernel_graph_merge
This commit is contained in:
commit
221f51e2f5
|
@ -683,6 +683,10 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_MODEL_ENCRYPTION AND NOT ENABLE_CLOUD_AND_LITE)
|
||||
set(MSLITE_DEPS_OPENSSL on)
|
||||
endif()
|
||||
|
||||
include(${LITE_DIR}/cmake/lite_dependences.cmake)
|
||||
|
||||
if(MSLITE_GPU_BACKEND STREQUAL opencl)
|
||||
|
@ -725,11 +729,11 @@ if(MSLITE_ENABLE_FP16)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_MODEL_ENCRYPTION AND NOT ENABLE_CLOUD_AND_LITE)
|
||||
find_required_package(Patch)
|
||||
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
|
||||
add_compile_definitions(ENABLE_OPENSSL)
|
||||
endif()
|
||||
# if(MSLITE_ENABLE_MODEL_ENCRYPTION AND NOT ENABLE_CLOUD_AND_LITE)
|
||||
# find_required_package(Patch)
|
||||
# include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
|
||||
# add_compile_definitions(ENABLE_OPENSSL)
|
||||
# endif()
|
||||
|
||||
if(MSLITE_ENABLE_MINDRT)
|
||||
add_compile_definitions(ENABLE_MINDRT)
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
add_compile_definitions(BUILD_LITE)
|
||||
|
||||
if(ENABLE_CLOUD_AND_LITE)
|
||||
remove_definitions(-DUSE_GLOG)
|
||||
add_compile_definitions(ENABLE_CLOUD_AND_LITE)
|
||||
endif()
|
||||
|
||||
add_definitions(-DVERSION_STR=\"${VERSION_STR}\")
|
||||
|
||||
if(MACHINE_LINUX_ARM64)
|
||||
add_compile_definitions(MACHINE_LINUX_ARM64)
|
||||
add_compile_definitions(LINUX_RUNTIME)
|
||||
endif()
|
||||
if(PLATFORM_X86_64)
|
||||
add_compile_definitions(LINUX_RUNTIME)
|
||||
endif()
|
||||
if(TOOLCHAIN_NAME STREQUAL "himix200")
|
||||
add_compile_definitions(SUPPORT_NNIE)
|
||||
elseif(TOOLCHAIN_NAME STREQUAL "himix100")
|
||||
add_compile_definitions(SUPPORT_NNIE)
|
||||
elseif(TOOLCHAIN_NAME STREQUAL "mix210")
|
||||
add_compile_definitions(SUPPORT_34XX)
|
||||
elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite")
|
||||
SET_PROPERTY(GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS TRUE)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE)
|
||||
add_compile_definitions(DYNAMIC_THREAD_DISTRIBUTE)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_BFC_MEMORY)
|
||||
add_compile_definitions(BFC_MEMORY)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_PARALLEL_INFERENCE)
|
||||
add_compile_definitions(PARALLEL_INFERENCE)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_SHARING_MODEL_WEIGHT)
|
||||
add_compile_definitions(SHARING_MODEL_WEIGHT)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_CONVERTER)
|
||||
add_compile_definitions(ENABLE_CONVERTER)
|
||||
endif()
|
|
@ -50,4 +50,9 @@ if(MSLITE_DEPS_PYBIND11)
|
|||
include(${TOP_DIR}/cmake/external_libs/pybind11.cmake)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(MSLITE_DEPS_OPENSSL)
|
||||
include(${TOP_DIR}/cmake/external_libs/openssl.cmake)
|
||||
add_compile_definitions(ENABLE_OPENSSL)
|
||||
endif()
|
|
@ -0,0 +1,393 @@
|
|||
set(BUILD_LITE "on")
|
||||
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/secure_option.cmake)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/compile_link_option.cmake)
|
||||
|
||||
#Options that can be configured through environment variables or manually
|
||||
set(MSLITE_GPU_BACKEND "" CACHE STRING "enable gpu backend, \
|
||||
opencl only support arm64 and x86_64 , tensorrt only support x86_64, opencl/cuda/tensorrt/off")
|
||||
set(MSLITE_REGISTRY_DEVICE "off" CACHE STRING "Compile Mindspore Lite that supports specific devices, \
|
||||
currently supported devices: Hi3516D/Hi3519A/Hi3559A/SD3403")
|
||||
set(MSLITE_MICRO_PLATFORM "auto" CACHE STRING "Platform of micro static library micro static, \
|
||||
currently supported : cortex-m7/auto")
|
||||
if(NOT ENABLE_CLOUD_AND_LITE)
|
||||
set(MSLITE_MINDDATA_IMPLEMENT "lite_cv" CACHE STRING "off, lite_cv, cloud, or full")
|
||||
else()
|
||||
if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
||||
set(PLATFORM_ARM64 "on")
|
||||
set(MACHINE_LINUX_ARM64 "on")
|
||||
endif()
|
||||
endif()
|
||||
option(MSLITE_ENABLE_NPU "enable npu, only arm64 or arm32 support" off)
|
||||
option(MSLITE_ENABLE_TRAIN "enable train" on)
|
||||
option(MSLITE_ENABLE_SSE "enable SSE instruction set, only x86_64 support" off)
|
||||
option(MSLITE_ENABLE_AVX "enable AVX instruction set, only x86_64 support" off)
|
||||
option(MSLITE_ENABLE_AVX512 "enable AVX512 instruction set, only x86_64 support" off)
|
||||
option(MSLITE_ENABLE_CONVERTER "enable converter" on)
|
||||
option(MSLITE_ENABLE_TOOLS "enable tools" on)
|
||||
option(MSLITE_ENABLE_TESTCASES "enable testcase" off)
|
||||
option(MSLITE_ENABLE_RUNTIME_PASS "enable runtime pass" on)
|
||||
option(MSLITE_ENABLE_HIGH_PERFORMANCE "enable high performance" off)
|
||||
option(MSLITE_ENABLE_STRING_KERNEL "enable string kernel" on)
|
||||
option(MSLITE_ENABLE_CONTROLFLOW "enable control and tensorlist" on)
|
||||
option(MSLITE_ENABLE_AUTO_PARALLEL "enable automatic parallelism" on)
|
||||
option(MSLITE_ENABLE_WEIGHT_DECODE "enable weight decode" on)
|
||||
option(MSLITE_ENABLE_CUSTOM_KERNEL "enable extend kernel registry" on)
|
||||
option(MSLITE_ENABLE_MINDRT "enable mindrt use" on)
|
||||
option(MSLITE_ENABLE_DELEGATE "enable delegate use" on)
|
||||
option(MSLITE_ENABLE_FP16 "Whether to compile Fp16 operator" off)
|
||||
option(MSLITE_ENABLE_INT8 "Whether to compile Int8 operator" on)
|
||||
option(MSLITE_ENABLE_ACL "enable ACL" off)
|
||||
option(MSLITE_ENABLE_MODEL_ENCRYPTION "enable model encryption" off)
|
||||
option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off)
|
||||
option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off)
|
||||
option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off)
|
||||
option(MSLITE_ENABLE_COVERAGE "enable code coverage" off)
|
||||
option(MSLITE_ENABLE_SERVER_INFERENCE "enable inference on server" off)
|
||||
option(MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE "enable distribute thread dynamically" off)
|
||||
option(MSLITE_ENABLE_BFC_MEMORY "enable distribute BFC memory" off)
|
||||
option(MSLITE_ENABLE_PARALLEL_INFERENCE "enable parallel inference interface" off)
|
||||
option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off)
|
||||
option(MSLITE_ENABLE_EXPERIMENTAL_KERNEL "enable experimental kernel" on)
|
||||
option(MSLITE_ENABLE_GRAPH_KERNEL "enable graph kernel" off)
|
||||
option(MSLITE_ENABLE_CONVERT_PYTORCH_MODEL "enable to convert pytorch model" off)
|
||||
option(MSLITE_ENABLE_KERNEL_EXECUTOR "enable kernel executor" off)
|
||||
option(MSLITE_ENABLE_GITEE_MIRROR "enable download third_party from gitee mirror" off)
|
||||
option(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE "enable cloud and device fusion inference architecture" off)
|
||||
|
||||
#Option that can be configured through manually
|
||||
option(ENABLE_VERBOSE "" off)
|
||||
option(ENABLE_MODEL_OBF "if support model obfuscation" off)
|
||||
set(VERSION_STR "1.7.0" CACHE STRING "get from version")
|
||||
|
||||
if(MACHINE_LINUX_ARM64)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8-a+fp16")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8-a+fp16")
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_EXPERIMENTAL_KERNEL})
|
||||
set(MSLITE_ENABLE_EXPERIMENTAL_KERNEL $ENV{MSLITE_ENABLE_EXPERIMENTAL_KERNEL})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_GPU_BACKEND})
|
||||
set(MSLITE_GPU_BACKEND $ENV{MSLITE_GPU_BACKEND})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_REGISTRY_DEVICE})
|
||||
set(MSLITE_REGISTRY_DEVICE $ENV{MSLITE_REGISTRY_DEVICE})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_MICRO_PLATFORM})
|
||||
set(MSLITE_MICRO_PLATFORM $ENV{MSLITE_MICRO_PLATFORM})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_NPU})
|
||||
set(MSLITE_ENABLE_NPU $ENV{MSLITE_ENABLE_NPU})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_TRAIN})
|
||||
set(MSLITE_ENABLE_TRAIN $ENV{MSLITE_ENABLE_TRAIN})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_SERVER_INFERENCE})
|
||||
set(MSLITE_ENABLE_SERVER_INFERENCE $ENV{MSLITE_ENABLE_SERVER_INFERENCE})
|
||||
endif()
|
||||
if(MSLITE_ENABLE_SERVER_INFERENCE)
|
||||
set(MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE on)
|
||||
set(MSLITE_ENABLE_BFC_MEMORY on)
|
||||
set(MSLITE_ENABLE_PARALLEL_INFERENCE on)
|
||||
set(MSLITE_ENABLE_SHARING_MODEL_WEIGHT on)
|
||||
set(MSLITE_ENABLE_RUNTIME_GLOG on)
|
||||
set(MSLITE_ENABLE_AVX512 on)
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_SSE})
|
||||
set(MSLITE_ENABLE_SSE $ENV{MSLITE_ENABLE_SSE})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_AVX})
|
||||
set(MSLITE_ENABLE_AVX $ENV{MSLITE_ENABLE_AVX})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_AVX512})
|
||||
set(MSLITE_ENABLE_AVX512 $ENV{MSLITE_ENABLE_AVX512})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_CONVERTER})
|
||||
set(MSLITE_ENABLE_CONVERTER $ENV{MSLITE_ENABLE_CONVERTER})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_CONVERT})
|
||||
set(MSLITE_ENABLE_RUNTIME_CONVERT $ENV{MSLITE_ENABLE_RUNTIME_CONVERT})
|
||||
endif()
|
||||
if(DEFINED ENV{ENABLE_AKG} AND NOT MSLITE_ENABLE_RUNTIME_CONVERT)
|
||||
set(MSLITE_ENABLE_GRAPH_KERNEL $ENV{ENABLE_AKG})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_TOOLS})
|
||||
set(MSLITE_ENABLE_TOOLS $ENV{MSLITE_ENABLE_TOOLS})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_TESTCASES})
|
||||
set(MSLITE_ENABLE_TESTCASES $ENV{MSLITE_ENABLE_TESTCASES})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_PASS})
|
||||
set(MSLITE_ENABLE_RUNTIME_PASS $ENV{MSLITE_ENABLE_RUNTIME_PASS})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_HIGH_PERFORMANCE})
|
||||
set(MSLITE_ENABLE_HIGH_PERFORMANCE $ENV{MSLITE_ENABLE_HIGH_PERFORMANCE})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_STRING_KERNEL})
|
||||
set(MSLITE_ENABLE_STRING_KERNEL $ENV{MSLITE_ENABLE_STRING_KERNEL})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_CONTROLFLOW})
|
||||
set(MSLITE_ENABLE_CONTROLFLOW $ENV{MSLITE_ENABLE_CONTROLFLOW})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_AUTO_PARALLEL})
|
||||
set(MSLITE_ENABLE_AUTO_PARALLEL $ENV{MSLITE_ENABLE_AUTO_PARALLEL})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_WEIGHT_DECODE})
|
||||
set(MSLITE_ENABLE_WEIGHT_DECODE $ENV{MSLITE_ENABLE_WEIGHT_DECODE})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_CUSTOM_KERNEL})
|
||||
set(MSLITE_ENABLE_CUSTOM_KERNEL $ENV{MSLITE_ENABLE_CUSTOM_KERNEL})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_MINDRT})
|
||||
set(MSLITE_ENABLE_MINDRT $ENV{MSLITE_ENABLE_MINDRT})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_DELEGATE})
|
||||
set(MSLITE_ENABLE_DELEGATE $ENV{MSLITE_ENABLE_DELEGATE})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_FP16})
|
||||
set(MSLITE_ENABLE_FP16 $ENV{MSLITE_ENABLE_FP16})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_INT8})
|
||||
set(MSLITE_ENABLE_INT8 $ENV{MSLITE_ENABLE_INT8})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_SPARSE_COMPUTE})
|
||||
set(MSLITE_ENABLE_SPARSE_COMPUTE $ENV{MSLITE_ENABLE_SPARSE_COMPUTE})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_ACL})
|
||||
set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT})
|
||||
set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
||||
if((${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND PLATFORM_X86_64)
|
||||
OR((PLATFORM_ARM64 OR PLATFORM_ARM32) AND ANDROID_NDK_TOOLCHAIN_INCLUDED))
|
||||
set(MSLITE_ENABLE_MODEL_ENCRYPTION $ENV{MSLITE_ENABLE_MODEL_ENCRYPTION})
|
||||
else()
|
||||
set(MSLITE_ENABLE_MODEL_ENCRYPTION OFF)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_COVERAGE})
|
||||
set(MSLITE_ENABLE_COVERAGE $ENV{MSLITE_ENABLE_COVERAGE})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_SERVING})
|
||||
set(MSLITE_ENABLE_SERVING $ENV{MSLITE_ENABLE_SERVING})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_KERNEL_EXECUTOR})
|
||||
set(MSLITE_ENABLE_KERNEL_EXECUTOR $ENV{MSLITE_ENABLE_KERNEL_EXECUTOR})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL} AND DEFINED ENV{LIB_TORCH_PATH})
|
||||
set(ENABLE_CONVERT_PYTORCH_MODEL $ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL})
|
||||
set(LIB_TORCH_PATH $ENV{LIB_TORCH_PATH})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_GITEE_MIRROR})
|
||||
set(MSLITE_ENABLE_GITEE_MIRROR $ENV{MSLITE_ENABLE_GITEE_MIRROR})
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_GITEE_MIRROR)
|
||||
set(ENABLE_GITEE ON)
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_CLOUD_FUSION_INFERENCE})
|
||||
set(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE $ENV{MSLITE_ENABLE_CLOUD_FUSION_INFERENCE})
|
||||
endif()
|
||||
|
||||
if(TOOLCHAIN_NAME STREQUAL "himix200")
|
||||
set(TARGET_HIMIX on)
|
||||
set(TARGET_HIMIX200 on)
|
||||
elseif(TOOLCHAIN_NAME STREQUAL "himix100")
|
||||
set(TARGET_HIMIX on)
|
||||
set(TARGET_HIMIX100 on)
|
||||
elseif(TOOLCHAIN_NAME STREQUAL "mix210")
|
||||
set(TARGET_MIX210 on)
|
||||
elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite")
|
||||
set(TARGET_OHOS_LITE on)
|
||||
endif()
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0
|
||||
AND NOT TARGET_HIMIX AND NOT TARGET_MIX210)
|
||||
message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
|
||||
endif()
|
||||
|
||||
if(NOT MSLITE_ENABLE_ACL)
|
||||
set(ENABLE_GLIBCXX ON)
|
||||
else()
|
||||
set(MSLITE_ENABLE_TRAIN off)
|
||||
endif()
|
||||
|
||||
if(PLATFORM_ARM64)
|
||||
if(MSLITE_GPU_BACKEND STREQUAL "")
|
||||
set(MSLITE_GPU_BACKEND "opencl")
|
||||
endif()
|
||||
if((NOT MSLITE_GPU_BACKEND STREQUAL "opencl") AND (NOT MSLITE_GPU_BACKEND STREQUAL "off"))
|
||||
message("invalid MSLITE_GPU_BACKEND value ${MSLITE_GPU_BACKEND} for arm64, MSLITE_GPU_BACKEND is set to off.")
|
||||
set(MSLITE_GPU_BACKEND "off")
|
||||
endif()
|
||||
elseif(PLATFORM_ARM32)
|
||||
if((NOT MSLITE_GPU_BACKEND STREQUAL "opencl") AND (NOT MSLITE_GPU_BACKEND STREQUAL "off") AND
|
||||
(NOT MSLITE_GPU_BACKEND STREQUAL ""))
|
||||
message("invalid MSLITE_GPU_BACKEND value ${MSLITE_GPU_BACKEND} for arm32, MSLITE_GPU_BACKEND is set to off.")
|
||||
set(MSLITE_GPU_BACKEND "off")
|
||||
endif()
|
||||
elseif(WIN32)
|
||||
set(MSLITE_GPU_BACKEND "off")
|
||||
else()
|
||||
if(${MSLITE_REGISTRY_DEVICE} STREQUAL "SD3403")
|
||||
set(MSLITE_ENABLE_DPICO_ATC_ADAPTER on)
|
||||
endif()
|
||||
if(MSLITE_GPU_BACKEND STREQUAL "")
|
||||
set(MSLITE_GPU_BACKEND "off")
|
||||
endif()
|
||||
if((NOT MSLITE_GPU_BACKEND STREQUAL "tensorrt") AND (NOT MSLITE_GPU_BACKEND STREQUAL "off") AND
|
||||
(NOT MSLITE_GPU_BACKEND STREQUAL "cuda") AND (NOT MSLITE_GPU_BACKEND STREQUAL "opencl"))
|
||||
message("invalid MSLITE_GPU_BACKEND value ${MSLITE_GPU_BACKEND} for x86_64, MSLITE_GPU_BACKEND is set to off.")
|
||||
set(MSLITE_GPU_BACKEND "off")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(PLATFORM_ARM64 OR PLATFORM_ARM32)
|
||||
set(PLATFORM_ARM "on")
|
||||
set(MSLITE_ENABLE_SSE off)
|
||||
set(MSLITE_ENABLE_AVX off)
|
||||
set(MSLITE_ENABLE_AVX512 off)
|
||||
if(NOT MACHINE_LINUX_ARM64)
|
||||
set(MSLITE_ENABLE_CONVERTER off)
|
||||
endif()
|
||||
set(MSLITE_ENABLE_RUNTIME_GLOG off)
|
||||
set(MSLITE_ENABLE_RUNTIME_CONVERT off)
|
||||
#set for cross - compiling toolchain
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
|
||||
else()
|
||||
set(MSLITE_ENABLE_NPU off)
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_GLOG})
|
||||
set(MSLITE_ENABLE_RUNTIME_GLOG $ENV{MSLITE_ENABLE_RUNTIME_GLOG})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE})
|
||||
set(MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE $ENV{MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_BFC_MEMORY})
|
||||
set(MSLITE_ENABLE_BFC_MEMORY $ENV{MSLITE_ENABLE_BFC_MEMORY})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_PARALLEL_INFERENCE})
|
||||
set(MSLITE_ENABLE_PARALLEL_INFERENCE $ENV{MSLITE_ENABLE_PARALLEL_INFERENCE})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_SHARING_MODEL_WEIGHT})
|
||||
set(MSLITE_ENABLE_SHARING_MODEL_WEIGHT $ENV{MSLITE_ENABLE_SHARING_MODEL_WEIGHT})
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_SSE OR MSLITE_ENABLE_AVX OR MSLITE_ENABLE_AVX512 OR WIN32)
|
||||
set(MSLITE_ENABLE_RUNTIME_CONVERT off)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_TRAIN AND NOT MSLITE_ENABLE_WEIGHT_DECODE)
|
||||
message(FATAL_ERROR "If MSLITE_ENABLE_WEIGHT_DECODE use if configured as off, "
|
||||
"MSLITE_ENABLE_TRAIN must also be configured as off")
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_CONTROLFLOW AND NOT MSLITE_ENABLE_MINDRT)
|
||||
message(FATAL_ERROR "If MSLITE_ENABLE_MINDRT use if configured as off, "
|
||||
"MSLITE_ENABLE_CONTROLFLOW must also be configured as off")
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_RUNTIME_CONVERT)
|
||||
set(MSLITE_ENABLE_RUNTIME_GLOG on)
|
||||
set(MSLITE_ENABLE_CONVERTER on)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
||||
set(MSLITE_ENABLE_RUNTIME_GLOG on)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_TRAIN)
|
||||
set(SUPPORT_TRAIN on)
|
||||
if(NOT MSLITE_MINDDATA_IMPLEMENT STREQUAL "off" OR NOT PLATFORM_ARM)
|
||||
set(MSLITE_MINDDATA_IMPLEMENT full)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_NPU)
|
||||
set(SUPPORT_NPU on)
|
||||
if(NOT PLATFORM_ARM)
|
||||
message(FATAL_ERROR "NPU only support platform arm.")
|
||||
endif()
|
||||
if(DEFINED ENV{HWHIAI_DDK})
|
||||
message("HWHIAI_DDK=$ENV{HWHIAI_DDK}")
|
||||
else()
|
||||
message(FATAL_ERROR "please set HWHIAI_DDK, example: export HWHIAI_DDK=/root/usr/hwhiai-ddk-100.510.010.010/")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(TARGET_HIMIX OR TARGET_OHOS_LITE)
|
||||
set(MSLITE_ENABLE_MINDRT off)
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
set(MSLITE_ENABLE_CONVERTER off)
|
||||
endif()
|
||||
|
||||
if(MSLITE_GPU_BACKEND STREQUAL cuda)
|
||||
set(MSLITE_ENABLE_CONVERTER on)
|
||||
set(MSLITE_ENABLE_RUNTIME_GLOG on)
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_FP16 AND PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0 OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 12.0)
|
||||
message(STATUS "If you want to build fp16 in arm82_a32, please use android nkd r21e or r22b!")
|
||||
set(MSLITE_ENABLE_FP16 off)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "************MindSpore Lite Build Option:************")
|
||||
message(STATUS "\tMSLITE_GPU_BACKEND = \t${MSLITE_GPU_BACKEND}")
|
||||
message(STATUS "\tMSLITE_REGISTRY_DEVICE = \t${MSLITE_REGISTRY_DEVICE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_NPU = \t${MSLITE_ENABLE_NPU}")
|
||||
message(STATUS "\tMSLITE_ENABLE_TRAIN = \t${MSLITE_ENABLE_TRAIN}")
|
||||
message(STATUS "\tMSLITE_MICRO_PLATFORM = \t${MSLITE_MICRO_PLATFORM}")
|
||||
message(STATUS "\tMSLITE_ENABLE_SSE = \t${MSLITE_ENABLE_SSE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_AVX = \t${MSLITE_ENABLE_AVX}")
|
||||
message(STATUS "\tMSLITE_ENABLE_AVX512 = \t${MSLITE_ENABLE_AVX512}")
|
||||
message(STATUS "\tMSLITE_ENABLE_CONVERTER = \t${MSLITE_ENABLE_CONVERTER}")
|
||||
message(STATUS "\tMSLITE_ENABLE_TOOLS = \t${MSLITE_ENABLE_TOOLS}")
|
||||
message(STATUS "\tMSLITE_ENABLE_TESTCASES = \t${MSLITE_ENABLE_TESTCASES}")
|
||||
message(STATUS "\tMSLITE_ENABLE_HIGH_PERFORMANCE = \t${MSLITE_ENABLE_HIGH_PERFORMANCE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_RUNTIME_PASS = \t${MSLITE_ENABLE_RUNTIME_PASS}")
|
||||
message(STATUS "\tMSLITE_ENABLE_STRING_KERNEL = \t${MSLITE_ENABLE_STRING_KERNEL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_CONTROLFLOW = \t${MSLITE_ENABLE_CONTROLFLOW}")
|
||||
message(STATUS "\tMSLITE_ENABLE_AUTO_PARALLEL = \t${MSLITE_ENABLE_AUTO_PARALLEL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_WEIGHT_DECODE = \t${MSLITE_ENABLE_WEIGHT_DECODE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_CUSTOM_KERNEL = \t${MSLITE_ENABLE_CUSTOM_KERNEL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_MINDRT = \t${MSLITE_ENABLE_MINDRT}")
|
||||
message(STATUS "\tMSLITE_MINDDATA_IMPLEMENT = \t${MSLITE_MINDDATA_IMPLEMENT}")
|
||||
message(STATUS "\tMSLITE_ENABLE_DELEGATE = \t${MSLITE_ENABLE_DELEGATE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_ACL = \t${MSLITE_ENABLE_ACL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_FP16 = \t${MSLITE_ENABLE_FP16}")
|
||||
message(STATUS "\tMSLITE_ENABLE_INT8 = \t${MSLITE_ENABLE_INT8}")
|
||||
message(STATUS "\tMSLITE_ENABLE_MODEL_ENCRYPTION = \t${MSLITE_ENABLE_MODEL_ENCRYPTION}")
|
||||
message(STATUS "\tMSLITE_ENABLE_SPARSE_COMPUTE = \t${MSLITE_ENABLE_SPARSE_COMPUTE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_RUNTIME_CONVERT = \t${MSLITE_ENABLE_RUNTIME_CONVERT}")
|
||||
message(STATUS "\tMSLITE_ENABLE_RUNTIME_GLOG = \t${MSLITE_ENABLE_RUNTIME_GLOG}")
|
||||
message(STATUS "\tMSLITE_ENABLE_COVERAGE = \t${MSLITE_ENABLE_COVERAGE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_SERVER_INFERENCE = \t${MSLITE_ENABLE_SERVER_INFERENCE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE = \t${MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_BFC_MEMORY = \t${MSLITE_ENABLE_BFC_MEMORY}")
|
||||
message(STATUS "\tMSLITE_ENABLE_PARALLEL_INFERENCE = \t${MSLITE_ENABLE_PARALLEL_INFERENCE}")
|
||||
message(STATUS "\tMSLITE_ENABLE_SHARING_MODEL_WEIGHT = \t${MSLITE_ENABLE_SHARING_MODEL_WEIGHT}")
|
||||
message(STATUS "\tMSLITE_ENABLE_EXPERIMENTAL_KERNEL = \t${MSLITE_ENABLE_EXPERIMENTAL_KERNEL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_GRAPH_KERNEL = \t${MSLITE_ENABLE_GRAPH_KERNEL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_KERNEL_EXECUTOR = \t${MSLITE_ENABLE_KERNEL_EXECUTOR}")
|
||||
message(STATUS "\tMSLITE_ENABLE_CLOUD_FUSION_INFERENCE = \t${MSLITE_ENABLE_CLOUD_FUSION_INFERENCE}")
|
|
@ -8,6 +8,9 @@ set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC}
|
|||
|
||||
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
||||
set(ENABLE_CPU on)
|
||||
# set(ENABLE_IBVERBS "OFF")
|
||||
# set(ENABLE_DEBUGGER on)
|
||||
# add_compile_definitions(ENABLE_DEBUGGER)
|
||||
add_compile_definitions(USE_GLOG)
|
||||
string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
|
||||
string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
|
@ -29,8 +32,8 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC}
|
||||
${MINDIR_MODEL_SRC}
|
||||
${MINDIR_KERNEL_SRC}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/mindir_model/less_test_kernel_mod.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/mindir_model/kernel_mod_mock.cc)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/mindir_model/less_test_kernel_mod.cc)
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/mindir_model/kernel_mod_mock.cc)
|
||||
|
||||
set(FBS_FILES
|
||||
${CCSRC_DIR}/../schema/cipher.fbs
|
||||
|
@ -55,15 +58,68 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
|
||||
add_library(mindspore-lite-proto OBJECT ${MSLITE_PROTO_SRC})
|
||||
|
||||
# add_subdirectory(${CCSRC_DIR} mindspore_ccsrc)
|
||||
|
||||
set(ANF_ALG_SRC ${ANF_ALG_SRC}
|
||||
${CCSRC_DIR}/utils/anfalgo.cc
|
||||
${CCSRC_DIR}/utils/parallel_context.cc
|
||||
${CCSRC_DIR}/utils/convert_utils.cc)
|
||||
add_library(mindspore-infer-anfalgo OBJECT ${ANF_ALG_SRC})
|
||||
|
||||
set(KERNEL_GRAPH_SRC ${KERNEL_GRAPH_SRC}
|
||||
${CCSRC_DIR}/backend/common/session/kernel_graph.cc
|
||||
${CCSRC_DIR}/backend/common/session/anf_runtime_algorithm.cc
|
||||
${CCSRC_DIR}/backend/common/session/session_basic.cc
|
||||
${CCSRC_DIR}/backend/common/session/session_factory.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
|
||||
${CCSRC_DIR}/backend/common/session/executor_manager.cc
|
||||
${CCSRC_DIR}/backend/common/somas/somas.cc
|
||||
${CCSRC_DIR}/backend/common/somas/somas_tensor.cc
|
||||
${CCSRC_DIR}/backend/common/somas/somas_solver_pre.cc
|
||||
${CCSRC_DIR}/backend/common/somas/somas_solver_core.cc
|
||||
${CCSRC_DIR}/backend/common/somas/somas_solver_alg.cc
|
||||
# ${CCSRC_DIR}/backend/common/optimizer/helper.cc
|
||||
# ${CCSRC_DIR}/backend/common/optimizer/const_input_to_attr.cc
|
||||
# ${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc
|
||||
# ${CCSRC_DIR}/backend/common/optimizer/visit.cc
|
||||
# ${CCSRC_DIR}/backend/common/optimizer/common_backend_optimization.cc
|
||||
${CCSRC_DIR}/runtime/device/ms_device_shape_transfer.cc
|
||||
${CCSRC_DIR}/runtime/device/kernel_info.cc
|
||||
${CCSRC_DIR}/runtime/device/convert_tensor_utils.cc
|
||||
${CCSRC_DIR}/runtime/device/kernel_runtime_manager.cc
|
||||
${CCSRC_DIR}/runtime/device/bucket.cc
|
||||
${CCSRC_DIR}/runtime/device/kernel_runtime.cc
|
||||
${CCSRC_DIR}/runtime/device/memory_scheduler.cc
|
||||
${CCSRC_DIR}/runtime/device/memory_offload_strategy.cc
|
||||
${CCSRC_DIR}/runtime/device/memory_manager.cc
|
||||
${CCSRC_DIR}/runtime/pynative/op_executor.cc
|
||||
${CCSRC_DIR}/runtime/pynative/op_runtime_info.cc
|
||||
${CCSRC_DIR}/runtime/hardware/device_type.cc
|
||||
${CCSRC_DIR}/kernel/kernel_build_info.cc
|
||||
${CCSRC_DIR}/kernel/common_utils.cc
|
||||
${CCSRC_DIR}/kernel/kernel.cc
|
||||
${CCSRC_DIR}/kernel/kash/kernel_pack.cc
|
||||
${CCSRC_DIR}/kernel/oplib/oplib.cc
|
||||
${CCSRC_DIR}/common/debug/common.cc
|
||||
${CCSRC_DIR}/common/debug/env_config_parser.cc
|
||||
${CCSRC_DIR}/common/thread_pool.cc
|
||||
${CCSRC_DIR}/utils/scoped_long_running.cc
|
||||
${CCSRC_DIR}/utils/cse.cc
|
||||
${CCSRC_DIR}/utils/comm_manager.cc)
|
||||
add_library(mindspore-kernel-graph OBJECT ${KERNEL_GRAPH_SRC})
|
||||
add_dependencies(mindspore-kernel-graph mindspore-lite-proto)
|
||||
|
||||
add_library(mindspore-extendrt SHARED ${MSLITE_EXTEND_RUNTIME_SRC})
|
||||
add_dependencies(mindspore-extendrt fbs_inner_src)
|
||||
add_dependencies(mindspore-extendrt generated_fbs_files)
|
||||
add_dependencies(mindspore-extendrt mindspore-lite-proto)
|
||||
add_dependencies(mindspore-extendrt mindspore-infer-anfalgo)
|
||||
add_dependencies(mindspore-extendrt mindspore-kernel-graph)
|
||||
add_subdirectory(cxx_api)
|
||||
|
||||
add_subdirectory(${CCSRC_DIR}/transform/graph_ir graph_ir)
|
||||
add_subdirectory(${CCSRC_DIR}/backend/common/session common_session)
|
||||
add_subdirectory(cxx_api)
|
||||
|
||||
add_subdirectory(${CCSRC_DIR}/backend/common/pass common_pass)
|
||||
add_subdirectory(${CCSRC_DIR}/utils mindspore_ccsrc_utils)
|
||||
add_subdirectory(${CCSRC_DIR}/runtime/device mindspore_ccsrc_runtime_device)
|
||||
add_subdirectory(${CCSRC_DIR}/runtime/pynative mindspore_ccsrc_runtime_pynative)
|
||||
|
@ -77,7 +133,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
add_subdirectory(${CCSRC_DIR}/pybind_api mindspore_ccsrc_pybind_api)
|
||||
add_subdirectory(${CCSRC_DIR}/ps mindspore_ccsrc_ps)
|
||||
|
||||
target_link_libraries(mindspore-extendrt mindspore_shared_lib_obj)
|
||||
target_link_libraries(mindspore-extendrt mindspore_infer_shared_lib_obj)
|
||||
# target_link_libraries(mindspore-extendrt _mindspore_backend_common_session_obj _mindspore_transform_graph_ir_obj)
|
||||
# target_link_libraries(mindspore-extendrt _mindspore_backend_common_session_obj
|
||||
# _mindspore_backend_common_optimizer_obj _mindspore_runtime_device_obj
|
||||
|
@ -87,6 +143,10 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
|||
# _mindspore_common_obj _mindspore_common_mem_reuse_obj
|
||||
# _mindspore_plugin_device_cpu_kernel_obj
|
||||
# _mindspore_ps_obj ps_cache)
|
||||
|
||||
target_link_libraries(mindspore-extendrt mindspore-infer-anfalgo
|
||||
mindspore-kernel-graph _mindspore_backend_common_optimizer_obj
|
||||
_mindspore_backend_common_pass_obj)
|
||||
target_link_libraries(mindspore-extendrt mindspore_core mindspore::protobuf mindspore::pybind11_module)
|
||||
# target_link_libraries(mindspore-extendrt )
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ if(BUILD_LITE)
|
|||
${ACL_REMOVE_SRC})
|
||||
endif()
|
||||
|
||||
add_library(mindspore_shared_lib_obj OBJECT ${MSLIB_SRC})
|
||||
add_library(mindspore_infer_shared_lib_obj OBJECT ${MSLIB_SRC})
|
||||
# add_library(mindspore_shared_lib SHARED $<TARGET_OBJECTS:mindspore_shared_lib_obj>)
|
||||
# if(BUILD_LITE)
|
||||
# target_link_libraries(mindspore_shared_lib PRIVATE $<TARGET_OBJECTS:_mindspore_transform_graph_ir_obj>)
|
||||
|
|
|
@ -0,0 +1,490 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/common/session/executor.h"
|
||||
#include "backend/common/session/executor_manager.h"
|
||||
#include <algorithm>
|
||||
#include <exception>
|
||||
#include <set>
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "include/common/utils/comm_manager.h"
|
||||
#include "include/common/utils/scoped_long_running.h"
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
|
||||
#include "ps/ps_cache/ps_cache_manager.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
namespace {
|
||||
void GetNeedNotifyTensors(const VectorRef *outputs, std::set<TensorPtr> *result) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
for (auto &item : *outputs) {
|
||||
if (utils::isa<VectorRefPtr>(item)) {
|
||||
auto vector_ref = utils::cast<VectorRef>(item);
|
||||
GetNeedNotifyTensors(&vector_ref, result);
|
||||
} else if (utils::isa<tensor::TensorPtr>(item)) {
|
||||
auto tensor = utils::cast<tensor::TensorPtr>(item);
|
||||
result->emplace(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorInVector(const VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
for (auto &item : *outputs) {
|
||||
if (utils::isa<VectorRefPtr>(item)) {
|
||||
auto vector_ref = utils::cast<VectorRef>(item);
|
||||
if (TensorInVector(&vector_ref)) {
|
||||
return true;
|
||||
}
|
||||
} else if (utils::isa<tensor::TensorPtr>(item)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
|
||||
MS_EXCEPTION_IF_NULL(task);
|
||||
for (auto &input : task->input_need_wait_tensors_) {
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (input->NeedWait()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
auto session = task->session_;
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
auto graph = session->GetGraph(task->graph_id_);
|
||||
if (graph != nullptr) {
|
||||
return graph->IsPreGraphFinished();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Blocks until every non-graph-output input tensor of `task` is ready.
// Inputs that are themselves graph outputs are NOT waited on here; they are
// recorded in input_need_wait_tensors_ and re-checked later via IsTaskReady().
// Finally, all tensors in input_need_lock_tensors_ are marked as "need wait"
// so downstream consumers block until this graph run releases them.
void WaitLockedInputs(const std::shared_ptr<RunGraphTask> &task) {
  bool need_lock = false;
  // First pass: classify pending inputs without blocking.
  for (auto &tensor : task->input_tensors_) {
    if (tensor->NeedWait()) {
      if (tensor->IsGraphOutput()) {
        // Produced by a previous graph run; the pending-task scheduler polls these.
        task->input_need_wait_tensors_.emplace_back(tensor);
      } else {
        need_lock = true;
      }
    }
  }
  if (need_lock) {
    // NOTE(review): ScopedLongRunning presumably flags this thread as
    // long-running (e.g. to release front-end resources) while blocked — confirm.
    mindspore::ScopedLongRunning long_running;
    // Second pass: actually wait, re-checking for raised exceptions before
    // each blocking wait and once more after all waits complete.
    for (auto &input_tensor : task->input_tensors_) {
      if (input_tensor->NeedWait() && !input_tensor->IsGraphOutput()) {
        MsException::Instance().CheckException();
        input_tensor->Wait();
      }
    }
    MsException::Instance().CheckException();
  }
  // Need to lock input parameters for the optimizer: consumers must wait until
  // this run updates them.
  for (auto &need_lock_tensor : task->input_need_lock_tensors_) {
    need_lock_tensor->SetNeedWait(true);
  }
}
|
||||
} // namespace
|
||||
|
||||
// Compiles a graph segment's node list into a kernel graph; stores the
// resulting graph id in graph_id_ for the caller to read back.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(segment_);
  graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}
|
||||
|
||||
// Compiles a whole FuncGraph; stores the resulting graph id in graph_id_.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
|
||||
|
||||
// Builds (selects kernels / allocates resources for) the graph identified by
// graph_id_ on the owning session.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
|
||||
|
||||
// Executes one graph run end-to-end: loads inputs, runs the graph, updates
// output tensors, and finally releases every tensor that downstream consumers
// may be waiting on. Exceptions are captured (not rethrown) so the worker
// thread survives; they are re-raised later via MsException.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_LOG(INFO) << "Start run graph " << graph_id_;
  auto graph = session_->GetGraph(graph_id_);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Invalid graph id " << graph_id_;
    return;
  }
  graph->ResetGraphRunningStatus();
  if (device::KernelRuntime::UseMemScheduler()) {
    graph->SetOutputNodeToTensor(node_to_tensor_);
  }
  try {
    session_->LoadInputs(graph_id_, input_tensors_);
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
    session_->UpdateOutputTensors(&outputs_, tensor_to_node_, &new_to_old_device_address);
  } catch (const std::exception &e) {
    // Record the failure for the submitting thread; do not kill the worker.
    session_->ReportErrorMessage();
    ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
    MsException::Instance().SetException();
  }
  MS_LOG(INFO) << "End run graph " << graph_id_;
  graph->OnRunGraphFinished();
  // Unlock both the optimizer-locked inputs and every output tensor so that
  // anyone blocked in Tensor::Wait() can proceed — even after a failed run.
  std::set<TensorPtr> need_notify_tensors(input_need_lock_tensors_.begin(), input_need_lock_tensors_.end());
  GetNeedNotifyTensors(&outputs_, &need_notify_tensors);
  for (auto &tensor : need_notify_tensors) {
    if (tensor != nullptr) {
      tensor->SetNeedWait(false);
    }
  }
  ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished);
}
|
||||
|
||||
// Executes a single operator (PyNative-style) on the owning session.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_);
}
|
||||
|
||||
// Executes all operators of a compiled graph one by one (op-by-op mode).
void RunOpsInGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
|
||||
|
||||
// Creates a communication group synchronously; result_ holds success/failure.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
|
||||
|
||||
// Destroys a communication group; result_ holds success/failure.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
|
||||
|
||||
// Binds the executor to a device and immediately starts the single worker
// thread that drains ready_tasks_ (see WorkerLoop).
Executor::Executor(const std::string &device_name, uint32_t device_id) {
  device_name_ = device_name;
  device_id_ = device_id;
  worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
}
|
||||
|
||||
// Shuts the worker thread down. All exceptions are swallowed and logged:
// throwing from a destructor would terminate the process.
Executor::~Executor() {
  try {
    WorkerJoin();
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Executor call destructor failed: " << e.what();
  } catch (...) {
    MS_LOG(ERROR) << "Executor call destructor failed.";
  }
}
|
||||
|
||||
// Posts an ExitTask to the worker queue and joins the worker thread.
void Executor::WorkerJoin() {
  // Avoid worker thread joining itself, which would cause a deadlock
  // (WorkerJoin can be reached from a task running on the worker).
  if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) {
    {
      // Enqueue the poison-pill task under the queue lock.
      std::lock_guard<std::mutex> lock(task_mutex_);
      auto task = std::make_shared<ExitTask>();
      ready_tasks_.push(task);
      task_cond_var_.notify_all();
    }
    worker_->join();
  }
}
|
||||
|
||||
// Worker-thread main loop: pops tasks from ready_tasks_, runs them, moves the
// finished task to done_tasks_, and wakes any synchronous submitter. Exits
// only on a kExit task. Exceptions from a task are recorded, never rethrown.
void Executor::WorkerLoop() {
  while (true) {
    std::shared_ptr<Task> task;
    {
      // Block until at least one task is queued.
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    MS_EXCEPTION_IF_NULL(task);
    // Capture type/flag before Run(), since task is moved away afterwards.
    enum TaskType task_type = task->type_;
    bool task_sync_flag = task->sync_run_;
    if (task_type == kExit) {
      OnWorkerExit();
      return;
    }
    try {
      if (task->session_ != nullptr) {
        task->session_->SetThreadContext();
      }
      task->Run();
      if (task->session_ != nullptr) {
        task->session_->ReportWarningMessage();
      }
    } catch (const std::exception &e) {
      if (task->session_ != nullptr) {
        task->session_->ReportErrorMessage();
      }
      ExecutorManager::Instance().OnEvent(ExecutorEvent::kException);
      MsException::Instance().SetException();
    }
    {
      // Keep the finished task alive until the submitter clears it; this also
      // defers destruction off the hot path.
      std::lock_guard<std::mutex> lock(done_task_mutex_);
      done_tasks_.emplace_back(std::move(task));
    }
    // Async kRunGraph tasks signal completion via kRunGraphFinished instead.
    if (task_type != kRunGraph || task_sync_flag) {
      std::lock_guard<std::mutex> lock(task_mutex_);
      sync_run_task_finished_ = true;
      sync_cond_var_.notify_all();
    }
  }
}
|
||||
|
||||
std::vector<std::shared_ptr<RunGraphTask>> Executor::GetReadyTasksFromPendingList() {
|
||||
std::vector<std::shared_ptr<RunGraphTask>> ready_tasks;
|
||||
std::lock_guard<std::mutex> lock(pending_task_mutex_);
|
||||
for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
|
||||
auto task = *iter;
|
||||
if (IsTaskReady(task)) {
|
||||
(void)ready_tasks.emplace_back(task);
|
||||
iter = pending_tasks_.erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
return ready_tasks;
|
||||
}
|
||||
|
||||
void Executor::OnEvent(const ExecutorEvent &event) {
|
||||
if (event == ExecutorEvent::kRunGraphFinished) {
|
||||
OnRunGraphFinished();
|
||||
} else if (event == ExecutorEvent::kClear) {
|
||||
OnClear();
|
||||
} else if (event == ExecutorEvent::kException) {
|
||||
OnException();
|
||||
}
|
||||
}
|
||||
|
||||
// Stops the worker thread and drops all finished tasks.
void Executor::OnClear() {
  {
    // NOTE(review): ScopedLongRunning presumably releases front-end resources
    // while blocked in join — confirm.
    mindspore::ScopedLongRunning long_running;
    WorkerJoin();
  }
  ClearDoneTasks();
}
|
||||
|
||||
// On a task failure: drain both the ready queue and the pending list into
// done_tasks_ so no stale task runs after the exception, while keeping the
// task objects alive for their submitters.
void Executor::OnException() {
  std::vector<std::shared_ptr<Task>> done_tasks;
  {
    std::lock_guard<std::mutex> lock(task_mutex_);
    while (!ready_tasks_.empty()) {
      (void)done_tasks.emplace_back(ready_tasks_.front());
      ready_tasks_.pop();
    }
  }
  {
    std::lock_guard<std::mutex> lock(pending_task_mutex_);
    (void)std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(done_tasks));
    pending_tasks_.clear();
  }
  {
    // Locks are taken one at a time (never nested) to avoid lock-order issues.
    std::lock_guard<std::mutex> lock(done_task_mutex_);
    (void)done_tasks_.insert(done_tasks_.end(), done_tasks.begin(), done_tasks.end());
  }
}
|
||||
|
||||
// Called when an async graph run completes: promotes newly-runnable pending
// tasks to the ready queue and wakes both the worker and any thread blocked
// in RunGraphAsync waiting for a post-graph to finish.
void Executor::OnRunGraphFinished() {
  auto ready_tasks = GetReadyTasksFromPendingList();
  std::lock_guard<std::mutex> lock(task_mutex_);
  for (auto &task : ready_tasks) {
    ready_tasks_.push(task);
  }
  if (!ready_tasks.empty()) {
    task_cond_var_.notify_all();
  }
  // Wake waiters on reenter_cond_var_ regardless; they re-check their predicate.
  reenter_cond_var_.notify_all();
}
|
||||
|
||||
// Releases all finished task objects (and whatever resources they hold).
void Executor::ClearDoneTasks() {
  std::lock_guard<std::mutex> lock(done_task_mutex_);
  done_tasks_.clear();
}
|
||||
|
||||
void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run) {
|
||||
if (sync) {
|
||||
ClearDoneTasks();
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(task_mutex_);
|
||||
sync_run_task_finished_ = false;
|
||||
ready_tasks_.push(task);
|
||||
}
|
||||
task_cond_var_.notify_all();
|
||||
if (sync && !sync_run_task_finished_) {
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
if (sync && long_run) {
|
||||
mindspore::ScopedLongRunning long_running;
|
||||
sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
|
||||
} else {
|
||||
sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
|
||||
}
|
||||
}
|
||||
ClearDoneTasks();
|
||||
MsException::Instance().CheckException();
|
||||
}
|
||||
|
||||
// Synchronously compiles a graph segment on the worker thread and returns the
// new graph id.
GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
                               const AnfNodePtrList &outputs) {
  auto task = std::make_shared<CompileNodesTask>();
  task->session_ = session;
  task->segment_ = segment;
  task->output_nodes_ = outputs;
  RunTask(task, true);
  return task->graph_id_;
}
|
||||
|
||||
// Synchronously compiles a whole FuncGraph on the worker thread and returns
// the new graph id.
GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  auto task = std::make_shared<CompileGraphTask>();
  task->session_ = session;
  task->func_graph_ = func_graph.get();
  RunTask(task, true);
  return task->graph_id_;
}
|
||||
|
||||
// Synchronously builds a previously compiled graph on the worker thread.
void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  auto task = std::make_shared<BuildGraphTask>();
  task->session_ = session;
  task->graph_id_ = graphId;
  RunTask(task, true);
}
|
||||
|
||||
// Runs a compiled graph synchronously. Output tensors are created up front so
// `outputs` is fully populated when this returns.
void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
                        const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_);
  // The task keeps its own copy of the output structure.
  task->outputs_ = *outputs;
  task->sync_run_ = true;
  RunTask(task, true, true);
}
|
||||
|
||||
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
|
||||
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(session);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
auto task = std::make_shared<RunGraphTask>();
|
||||
task->session_ = session;
|
||||
task->graph_id_ = graph_id;
|
||||
task->input_tensors_ = inputs;
|
||||
task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
|
||||
auto graph = session->GetGraph(task->graph_id_);
|
||||
if (graph != nullptr && !graph->IsPostGraphFinished()) {
|
||||
mindspore::ScopedLongRunning long_running;
|
||||
std::unique_lock<std::mutex> lock(reenter_mutex_);
|
||||
reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); });
|
||||
MsException::Instance().CheckException();
|
||||
}
|
||||
session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_);
|
||||
// maintain a copy of output vector
|
||||
task->outputs_ = *outputs;
|
||||
|
||||
// Run graph synchronously when the graph require gil.
|
||||
if (graph != nullptr && graph->is_need_gil()) {
|
||||
std::unique_lock<std::mutex> lock(reenter_mutex_);
|
||||
reenter_cond_var_.wait(lock, [&graph] { return graph->IsPreGraphFinished(); });
|
||||
MsException::Instance().CheckException();
|
||||
task->sync_run_ = true;
|
||||
RunTask(task, true, true);
|
||||
return;
|
||||
}
|
||||
|
||||
// sync run graph without output tensor(int dataset graph)
|
||||
if ((!TensorInVector(outputs) && !graph->HasPostGraph())) {
|
||||
task->sync_run_ = true;
|
||||
RunTask(task, true, true);
|
||||
return;
|
||||
}
|
||||
WaitLockedInputs(task);
|
||||
for (auto &tensor_node : task->tensor_to_node_) {
|
||||
tensor_node.first->SetNeedWait(true);
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(pending_task_mutex_);
|
||||
if (!IsTaskReady(task)) {
|
||||
ClearDoneTasks();
|
||||
pending_tasks_.push_back(task);
|
||||
return;
|
||||
}
|
||||
}
|
||||
RunTask(task, false);
|
||||
}
|
||||
|
||||
// Runs a single operator. On GPU the op runs directly on the calling thread;
// on other targets it is posted to the worker thread and awaited. In both
// paths every pending input tensor is waited on first.
void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                     std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                     const std::vector<int64_t> &tensors_mask) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(op_run_info);
  auto ms_context = MsContext::GetInstance();
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (target == kGPUDevice) {
    // GPU path: run inline, no worker-thread hop.
    for (auto &tensor : *input_tensors) {
      if (tensor->NeedWait()) {
        tensor->Wait();
      }
    }
    session->RunOpImpl(graph_info, op_run_info, input_tensors, outputs, tensors_mask);
  } else {
    auto task = std::make_shared<RunOpTask>();
    task->session_ = session;
    task->op_run_info_ = op_run_info;
    task->graph_info_ = graph_info;
    task->input_tensors_ = input_tensors;
    task->tensors_mask_ = tensors_mask;
    for (auto &tensor : *input_tensors) {
      if (tensor->NeedWait()) {
        tensor->Wait();
      }
    }
    RunTask(task, true, true);
    *outputs = task->outputs_;
  }
}
|
||||
|
||||
// Synchronously runs a compiled graph op-by-op on the worker thread and
// copies the task's outputs back to the caller.
void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunOpsInGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  RunTask(task, true, true);
  *outputs = task->outputs_;
}
|
||||
|
||||
// Synchronously creates a communication group on the worker thread.
// Returns true on success.
bool Executor::CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks) {
  auto task = std::make_shared<CreateCommGroupTask>();
  task->group_name_ = group_name;
  task->ranks_ = ranks;
  RunTask(task, true);
  return task->result_;
}
|
||||
|
||||
// Synchronously destroys a communication group on the worker thread.
// Returns true on success.
bool Executor::DestroyCommGroup(const std::string &group_name) {
  auto task = std::make_shared<DestroyCommGroupTask>();
  task->group_name_ = group_name;
  RunTask(task, true);
  return task->result_;
}
|
||||
|
||||
// Worker-thread teardown hook: on Ascend the kernel runtime must be released
// from the worker thread itself before it exits.
void Executor::OnWorkerExit() {
  if (device_name_ == kAscendDevice) {
    device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  }
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,849 @@
|
|||
/**
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/extendrt/utils/kernel_graph_utils.h"
|
||||
#include "ir/graph_utils.h"
|
||||
|
||||
namespace mindspore::infer {
|
||||
// Translates a front-end FuncGraph (and, recursively, every child FuncGraph
// referenced by a ValueNode) into backend KernelGraphs. Each constructed
// graph is appended to *all_out_graph; the graph for `func_graph` itself is
// returned. front_backend_graph_map_ caches finished translations so a child
// graph is converted only once.
KernelGraphPtr KernelGraphUtils::ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                      std::vector<KernelGraphPtr> *all_out_graph,
                                                      DeviceType device_target) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_out_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  // Register before recursing so cyclic references find the entry.
  front_backend_graph_map_[func_graph.get()] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  graph->set_device_target(device_target);
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    // Create parameter
    if (node->isa<Parameter>()) {
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendMapAdd(node, new_parameter);
      continue;
    }
    // Create value node
    if (node->isa<ValueNode>()) {
      // Create common value node
      if (!IsValueNode<FuncGraph>(node)) {
        (void)CreateNewValueNode(node, graph.get());
        continue;
      }
      // Create child kernel graph according to ValueNode<FuncGraph>
      FuncGraphPtr child_graph = common::AnfAlgo::GetValueNodeFuncGraph(node);
      if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
        (void)ConstructKernelGraph(child_graph, all_out_graph, device_target);
      }
      (void)CreateValueNodeKernelGraph(node, graph.get());
      continue;
    }
    // Create cnode
    if (!CreateCNodeOfKernelGraph(node, graph.get())) {
#ifdef ENABLE_DUMP_IR
      DumpIR("construct_kernel_graph_fail.ir", func_graph);
#endif
      MS_LOG(EXCEPTION) << "Construct func graph " << func_graph->ToString() << " failed."
                        << trace::DumpSourceLines(node);
    }
  }

  // Finalize graph-level bookkeeping: inputs, usage flags, execution order.
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  FuncGraphManagerPtr manager = MakeManager({graph});
  graph->SetInputNodes();
  SetInputNodeUsage(graph, manager);
  graph->SetExecOrderByDefault();

#ifndef ENABLE_SECURITY
  if (KernelGraphUtils::ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
#endif

  all_out_graph->push_back(graph);
  return graph;
}
|
||||
|
||||
// Allocates an empty KernelGraph, assigns it the next sequential graph id,
// and registers it in graphs_.
KernelGraphPtr KernelGraphUtils::NewKernelGraph() {
  auto kernel_graph = std::make_shared<KernelGraph>();
  kernel_graph->set_graph_id(graph_sum_);
  graphs_[graph_sum_] = kernel_graph;
  ++graph_sum_;
  return kernel_graph;
}
|
||||
|
||||
// Creates (or reuses) a backend Parameter for the front-end parameter `anf`
// inside `graph`, and bumps the parameter's used-graph count.
ParameterPtr KernelGraphUtils::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }

  auto param_value = GetParamDefaultValue(anf);
  ParameterPtr new_parameter = nullptr;
  // If the parameter's python parameter already has a backend parameter, reuse it.
  if (param_value != nullptr) {
    new_parameter = param_value->parameter();
    if (new_parameter == nullptr) {
      // Preserve the front node's debug/trace info on the new backend node.
      TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
      new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
      param_value->set_parameter(new_parameter);
    }
  } else {
    TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  }

  new_parameter->IncreaseUsedGraphCount();

  return new_parameter;
}
|
||||
|
||||
// Mirrors a front-end ValueNode into `graph`, registering the front/backend
// mapping. Returns nullptr for None values, which carry no data.
ValueNodePtr KernelGraphUtils::CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  auto front_value_node = anf->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(front_value_node);
  auto value = front_value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<None>()) {
    return nullptr;
  }
  auto backend_value_node = graph->NewValueNode(front_value_node);
  graph->FrontBackendMapAdd(anf, backend_value_node);
  graph->AddValueNodeToGraph(backend_value_node);
  return backend_value_node;
}
|
||||
|
||||
// Creates (or reuses) a backend Parameter for front-end parameter `anf` and
// appends it to `graph`'s input/valid-input lists. A dedicated cache
// (default_param_map_) is used for the multi-target CPU case so the same
// front parameter maps to one backend parameter per graph.
ParameterPtr KernelGraphUtils::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_value = GetParamDefaultValue(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  auto func_graph = anf->func_graph();
  // Multi-target + CPU: cache per front node instead of via param_value.
  if (func_graph->manager() != nullptr && func_graph->exist_multi_target() &&
      graph->device_target() == device::DeviceType::kCPU) {
    auto iter = default_param_map_.find(anf);
    if (iter != default_param_map_.end()) {
      new_parameter = iter->second;
    }
    if (new_parameter != nullptr) {
      return new_parameter;
    }
    TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    graph_inputs->push_back(new_parameter);
    valid_inputs->push_back(true);
    default_param_map_[anf] = new_parameter;
    return new_parameter;
  }
  // If the parameter's python parameter already has a backend parameter, reuse it.
  if (param_value != nullptr) {
    new_parameter = param_value->parameter();
  }
  if (new_parameter == nullptr) {
    TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());

    // Parameters that stand in for a partial's input link to the producing
    // node's internal output.
    auto input_node_iter = partial_parameters_map_.find(anf);
    if (input_node_iter != partial_parameters_map_.end()) {
      InitInternalOutputParameter(input_node_iter->second, new_parameter);
    }

    if (param_value != nullptr) {
      param_value->set_parameter(new_parameter);
    }
  }
  new_parameter->IncreaseUsedGraphCount();
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(true);
  return new_parameter;
}
|
||||
|
||||
// Returns the ParamInfo of `node` when it is a Parameter with a default
// value; nullptr otherwise (including for nullptr input).
ParamInfoPtr KernelGraphUtils::GetParamDefaultValue(const AnfNodePtr &node) {
  if (node == nullptr) {
    return nullptr;
  }
  auto param = node->cast<ParameterPtr>();
  const bool has_default = (param != nullptr) && param->has_default();
  return has_default ? param->param_info() : nullptr;
}
|
||||
|
||||
// When `parameter` consumes the internal output `out_node` of a previously
// built graph, copies that output's kernel build info and device address onto
// the parameter so the data is shared instead of re-transferred. Silently
// returns when any prerequisite (graph, ref node, kernel info, address) is
// missing or the output is not a unique-target internal output.
void KernelGraphUtils::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
  auto graph_id = GetGraphIdByNode(out_node);
  if (graph_id == kInvalidGraphId) {
    return;
  }
  auto node_graph = GetGraph(graph_id);
  if (node_graph == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
  if (ref_node == nullptr) {
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  // For tuple_getitem, pick the element index it selects.
  size_t output_idx = 0;
  if (common::AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
    output_idx = common::AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
  }
  auto real_kernel = common::AnfAlgo::VisitKernel(ref_node, output_idx);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
      MS_LOG(INFO) << "No kernel info";
      return;
    }
    if (!common::AnfAlgo::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    // Mirror format/type/address of the producing output onto the parameter.
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
    auto d_kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(d_kernel_info);
    parameter->set_kernel_info(d_kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsDeviceType({type});
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
    auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
                                                               parameter->Shape()->cast<abstract::BaseShapePtr>());
    parameter->set_abstract(abstract);
  }
}
|
||||
|
||||
// Creates backend parameter(s) standing in for the output of the front-end
// cnode `anf`. A Load of a Parameter is unwrapped to that Parameter; all
// other cnodes go through the tuple-expanding path.
AnfNodePtr KernelGraphUtils::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  if (IsPrimitiveCNode(anf, prim::kPrimLoad)) {
    auto input = common::AnfAlgo::GetInputNode(anf->cast<CNodePtr>(), 0);
    MS_EXCEPTION_IF_NULL(input);
    if (input->isa<Parameter>()) {
      auto new_param = CreateNewParameterFromParameter(input, graph);
      auto context_ptr = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context_ptr);
      // With MindRT the parameter must be linkable back to its front node.
      if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) == true) {
        graph->CacheInternalParameterToFrontNode(new_param, {anf, 0});
      }
      return new_param;
    }
  }
  return CreateParameterFromTuple(anf, graph);
}
|
||||
|
||||
// Replaces a ValueNode<FuncGraph> with a ValueNode holding the already
// constructed child KernelGraph. The child graph must have been translated
// earlier (see ConstructKernelGraph); otherwise this throws.
ValueNodePtr KernelGraphUtils::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  auto value_node = anf->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto sub_func_graph = common::AnfAlgo::GetValueNodeFuncGraph(anf);
  MS_EXCEPTION_IF_NULL(sub_func_graph);
  if (front_backend_graph_map_.find(sub_func_graph.get()) == front_backend_graph_map_.end()) {
    MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  }
  auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph.get()];

  ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  new_value_node->set_abstract(value_node->abstract());
  // Create new kernel_info of new value_node.
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_value_node->set_kernel_info(kernel_info);
  // Create kernel_build_info for the new value node.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());

  graph->FrontBackendMapAdd(anf, new_value_node);

  return new_value_node;
}
|
||||
|
||||
// Mirrors a front-end cnode into `graph`: builds the backend cnode, carries
// over abstract/scope, derives a fullname, registers the front/backend
// mapping, and handles Return specially. Returns false when the backend
// cnode could not be created.
bool KernelGraphUtils::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Create a new cnode object.
  auto new_cnode = CreateNewCNode(cnode, graph);
  if (new_cnode == nullptr) {
    return false;
  }
  new_cnode->set_abstract(cnode->abstract());
  // Pick the most descriptive fullname: a cnode callee's name, the loaded
  // value's name for Load, or the cnode's own name.
  std::string fullname;
  if (cnode->input(kAnfPrimitiveIndex)->isa<CNode>()) {
    fullname = cnode->input(kAnfPrimitiveIndex)->fullname_with_scope();
  } else if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
    fullname = cnode->input(kFirstDataInputIndex)->fullname_with_scope();
  } else {
    fullname = cnode->fullname_with_scope();
  }
  new_cnode->set_fullname_with_scope(fullname);
  new_cnode->set_scope(cnode->scope());
  graph->FrontBackendMapAdd(node, new_cnode);
  SetReturnNode(new_cnode, graph);
  return true;
}
|
||||
|
||||
// Rebuilds `graph`'s input list from the front-end parameter list, preserving
// the front-end order. Parameters with no backend counterpart (i.e. unused in
// the graph body) still get a fresh backend parameter so the signature stays
// intact.
void KernelGraphUtils::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  graph_inputs->clear();
  for (auto &parameter : parameters) {
    MS_EXCEPTION_IF_NULL(parameter);
    auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
    if (backend_parameter == nullptr) {
      // For example "def f(x,y,z) {return x + y}": parameter z is unused.
      auto new_parameter = CreateNewParameter(parameter, graph);
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendMapAdd(parameter, new_parameter);
      MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
      continue;
    }
    graph_inputs->push_back(backend_parameter);
  }
}
|
||||
|
||||
// Annotates every parameter input of `graph`: marks inputs not consumed by
// any real kernel, and flags dynamic-shape inputs.
void KernelGraphUtils::SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(manager);
  auto input_nodes = graph->input_nodes();
  for (auto &input_node : input_nodes) {
    if (input_node->isa<Parameter>()) {
      auto node_ptr = input_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(node_ptr);
      if (!IsUsedByRealKernel(manager, input_node, graph->graph_id())) {
        node_ptr->SetNotUsedByRealKernelInGraph(graph->graph_id());
      }
      auto shape = node_ptr->Shape();
      if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
        node_ptr->set_has_dynamic_shape(true);
      }
    }
  }
}
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
bool KernelGraphUtils::ExistSummaryNode(const KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto ret = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto all_nodes = DeepLinkedGraphSearch(ret);
|
||||
for (auto &n : all_nodes) {
|
||||
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
|
||||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Looks up which registered kernel graph contains a backend counterpart of
// `front_anf`; returns kInvalidGraphId when no graph does.
GraphId KernelGraphUtils::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
  for (const auto &[graph_id, graph] : graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    // If front_anf is a parameter, the backend parameter may have two
    // counterparts; any hit identifies the graph.
    if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
      return graph_id;
    }
  }
  MS_EXCEPTION_IF_NULL(front_anf);
  MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
  return kInvalidGraphId;
}
|
||||
|
||||
// Fetches a registered kernel graph by id; logs and returns nullptr when the
// id is unknown.
KernelGraphPtr KernelGraphUtils::GetGraph(mindspore::GraphId graph_id) const {
  const auto iter = graphs_.find(graph_id);
  if (iter != graphs_.end()) {
    return iter->second;
  }
  MS_LOG(INFO) << "Can't find graph " << graph_id;
  return nullptr;
}
|
||||
|
||||
// Creates backend parameters for the (possibly tuple-valued) output of front
// node `node`: builds a parameter matching the node's abstract, expands it to
// a make_tuple of element parameters, registers each element as a graph
// input, and links each element to the producing output for internal-output
// sharing. Returns the (possibly make_tuple) parameter node.
AnfNodePtr KernelGraphUtils::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
  auto parameters = common::AnfAlgo::GetAllOutput(new_parameter);
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, its input0 is a cnode too, so it doesn't have a primitive.
  if (!pre_graph_out.empty() && !AnfUtils::IsRealKernel(node)) {
    pre_graph_out = common::AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
  }

  for (size_t i = 0; i < parameters.size(); ++i) {
    const auto &parameter = parameters[i];
    auto context_ptr = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context_ptr);
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) == true) {
      // In control flow, if the input of the cnode is a call node, it will be
      // processed as a make_tuple input, which needs to be linked when
      // processing the internal node.
      graph->CacheInternalParameterToFrontNode(parameter, {node, i});
    }
    auto valid_inputs = graph->MutableValidInputs();
    MS_EXCEPTION_IF_NULL(valid_inputs);
    auto graph_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(graph_inputs);
    valid_inputs->push_back(true);
    graph_inputs->push_back(parameter);
  }
  // Pair each producing output with the parameter that consumes it.
  size_t param_index = 0;
  for (const auto &out_node : pre_graph_out) {
    size_t output_size = common::AnfAlgo::GetOutputTensorNum(out_node);
    for (size_t i = 0; i < output_size; i++) {
      if (param_index >= parameters.size()) {
        MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << "out of range.Node:" << node->DebugString()
                          << ",out_node:" << out_node->DebugString();
      }
      InitInternalOutputParameter(out_node, parameters[param_index++]);
    }
  }
  return new_parameter;
}
|
||||
|
||||
// Clones a frontend CNode into the backend kernel graph. Input[0] decides the shape of the
// backend node: a FuncGraph value node becomes a call, a CNode input[0] becomes a
// partial/switch/switch_layer call, anything else is an ordinary primitive node.
// Returns nullptr when switch/partial translation fails.
CNodePtr KernelGraphUtils::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (IsValueNode<FuncGraph>(attr_input)) {
    // cnode is a graph or a call
    cnode_inputs = CreateValueNode(cnode, graph);
  } else if (attr_input->isa<CNode>()) {
    // cnode is a call (partial/switch/switch_layer)
    // 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph
    // 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
    if (cnode_inputs.empty()) {
      MS_LOG_ERROR << "Create switch or partial failed, cnode:" << cnode->DebugString();
      return nullptr;
    }
  } else {
    // get primitive of old node
    auto prim = common::AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  // handle inputs of cnode except primitive
  CreateCNodeInputs(cnode, graph, &cnode_inputs);
  TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
  // if the cnode is call switch, remove call
  if (new_cnode->inputs().size() > 1) {
    auto first_input = new_cnode->input(kFirstDataInputIndex);
    MS_EXCEPTION_IF_NULL(first_input);
    if (common::AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
        common::AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
      new_cnode = first_input->cast<CNodePtr>();
    }
    // NOTE(review): the Switch branch above may have already reassigned new_cnode, in which case
    // this check no longer sees kPrimCall on it; the two ifs are therefore mutually exclusive.
    // Confirm that falling through after the Switch rewrite is the intended behavior.
    if (common::AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
        common::AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
      auto abstract = cnode->abstract();
      new_cnode = first_input->cast<CNodePtr>();
      new_cnode->set_abstract(abstract);
    }
  }
  return new_cnode;
}
|
||||
|
||||
// If the given node is a Return primitive, installs it as the graph's return node.
// When the returned value is a tuple-typed ValueNode, it is converted to a make_tuple
// here because the usual pass cannot reach a value that feeds Return directly.
void KernelGraphUtils::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);

  if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
    constexpr auto kReturnInputIdx = 1;
    auto return_node = node->cast<CNodePtr>();
    graph->set_return(return_node);
    auto graph_output = return_node->input(kReturnInputIdx);
    MS_EXCEPTION_IF_NULL(graph_output);

    // If return's input is value node, then the graph has no kernel, and the pass 'trans tuple to make_tuple' cannot
    // match this pattern because that pass begin with output node but return node. So we add transform value tuple
    // to make_tuple here.
    if (common::AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
      return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
    }
  }
}
|
||||
|
||||
bool KernelGraphUtils::IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node,
|
||||
const uint32_t graph_id) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_users = manager->node_users()[node];
|
||||
// filter nodes not in current graph
|
||||
for (auto iter = node_users.begin(); iter != node_users.end();) {
|
||||
auto func_graph = iter->first->func_graph();
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
if (kernel_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "func graph cast kernel graph failed, related node is: " << iter->first->DebugString();
|
||||
}
|
||||
if (kernel_graph->graph_id() != graph_id) {
|
||||
iter = node_users.erase(iter);
|
||||
} else {
|
||||
iter++;
|
||||
}
|
||||
}
|
||||
|
||||
size_t idx = 0;
|
||||
if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
|
||||
return RecursiveCheck(manager, kernel, &idx);
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// A shape is dynamic when any dimension is negative (unknown until runtime).
// A null shape is treated as static.
bool KernelGraphUtils::IsShapeDynamic(const abstract::ShapePtr &shape) {
  if (shape == nullptr) {
    return false;
  }
  for (const auto dim : shape->shape()) {
    if (dim < 0) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Builds the input list for a cnode whose input[0] is a FuncGraph value node.
// Graph-kernel cnodes get a cloned func graph as input[0]; ordinary graph calls become
// a Call primitive followed by a ValueNode<KernelGraph>.
std::vector<AnfNodePtr> KernelGraphUtils::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (common::AnfAlgo::IsGraphKernel(cnode)) {
    // Clone the sub func graph so the backend owns an independent copy.
    auto fg = common::AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  } else {
    // create primitive of cnode:call
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // create a ValueNode<KernelGraph> as input of cnode:call
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      // NOTE(review): a nullptr result is silently skipped here, yielding a Call with no
      // callee input — confirm callers tolerate that.
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  }
  return cnode_inputs;
}
|
||||
|
||||
// Builds call inputs for a cnode whose input[0] is itself a CNode: a partial, switch,
// or switch_layer. Returns an empty vector on failure (caller logs and aborts the node).
std::vector<AnfNodePtr> KernelGraphUtils::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  // create primitive of cnode:call(partial or switch or switch_layer)
  std::vector<AnfNodePtr> cnode_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  if (cnode_input == nullptr) {
    MS_LOG(ERROR) << "CNode input[0] is CNode:" << attr_input->DebugString() << ", but input[0] has not been created.";
    return {};
  }
  // if the node is partial, insert the inputs of partial to the call
  if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
    auto partial_node = attr_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(partial_node);
    auto partial_inputs = partial_node->inputs();
    // Append the backend counterpart of every partial argument (skipping the primitive slot).
    (void)std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
                         std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
                           MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
                           return graph->GetBackendAnfByFrontAnf(node);
                         });
    return cnode_inputs;
  } else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
    return CreateCallSwitchInputs(cnode, graph);
  } else if (common::AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
    return CreateCallSwitchLayerInputs(cnode, graph);
  }
  MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]" << cnode_input->DebugString()
                << "must be partial or switch or switch_layer.";
  return {};
}
|
||||
|
||||
// Appends the backend counterparts of cnode's data inputs to cnode_inputs.
// Switch nodes get special handling: the condition is mapped directly and each branch is
// wrapped via CreateSwitchInput. Inputs of None are skipped; an unmapped input is fatal.
void KernelGraphUtils::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
                                         std::vector<AnfNodePtr> *cnode_inputs) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
    // Condition input first, then each branch converted to a partial.
    (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
    for (size_t index = kSwitchTrueBranchIndex; index < cnode->inputs().size(); index++) {
      auto node_input = cnode->input(index);
      auto switch_input = CreateSwitchInput(cnode, node_input, graph);
      (void)cnode_inputs->emplace_back(switch_input);
    }
  } else {
    for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) {
      auto anf = cnode->input(input_idx);
      MS_EXCEPTION_IF_NULL(anf);
      // anf has been created before
      if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
        (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
        continue;
      } else if (IsValueNode<None>(anf)) {
        continue;
      }
      MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
    }
  }
}
|
||||
|
||||
// Recursively decides whether `kernel` (a node-user pair) eventually reaches a real
// kernel. `idx` is a shared depth counter across the whole traversal; recursion stops
// once it exceeds max_depth (declared elsewhere — not visible in this chunk; confirm value).
bool KernelGraphUtils::RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodePtr, int64_t> &kernel,
                                      size_t *idx) {
  auto node = kernel.first;
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(node);
  // Depend/Load users past the first input don't constitute a real use.
  if (kernel.second > 1 && (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) ||
                            common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
    return false;
  }
  if (AnfUtils::IsRealKernel(node) && !common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
    return true;
  }
  (*idx) += 1;
  // max recursion depth
  if (*idx <= max_depth) {
    auto users = manager->node_users()[node];
    if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
          return RecursiveCheck(manager, kernel, idx);
        })) {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
std::vector<AnfNodePtr> KernelGraphUtils::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> cnode_inputs = {
|
||||
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
||||
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
||||
MS_EXCEPTION_IF_NULL(attr_input);
|
||||
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
||||
auto switch_cnode = cnode_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_cnode);
|
||||
if (cnode->inputs().size() <= 1) {
|
||||
cnode_inputs = switch_cnode->inputs();
|
||||
return cnode_inputs;
|
||||
}
|
||||
std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
|
||||
switch_cnode->input(kFirstDataInputIndex)};
|
||||
for (size_t index = kSwitchTrueBranchIndex; index < switch_cnode->inputs().size(); index++) {
|
||||
auto node = switch_cnode->input(index);
|
||||
// there is real input in call, should put it to true and false branch in switch
|
||||
if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
|
||||
auto partial_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_node);
|
||||
std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
|
||||
// Put all call args at the end of partial inputs.
|
||||
for (size_t i = kFirstDataInputIndex; i < cnode->size(); ++i) {
|
||||
(void)partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(i)));
|
||||
}
|
||||
auto new_partial = graph->NewCNode(partial_inputs);
|
||||
(void)switch_inputs.emplace_back(new_partial);
|
||||
}
|
||||
}
|
||||
if (switch_inputs.size() < kSwitchInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
|
||||
}
|
||||
auto switch_node = graph->NewCNode(switch_inputs);
|
||||
(void)cnode_inputs.emplace_back(switch_node);
|
||||
return cnode_inputs;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> KernelGraphUtils::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> cnode_inputs = {
|
||||
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
||||
auto attr_input = cnode->input(kAnfPrimitiveIndex);
|
||||
MS_EXCEPTION_IF_NULL(attr_input);
|
||||
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
|
||||
auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(switch_layer_cnode);
|
||||
std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
|
||||
switch_layer_cnode->input(kFirstDataInputIndex)};
|
||||
auto make_tuple_node = switch_layer_cnode->input(kSwitchLayerBranchesIndex);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_node);
|
||||
auto node = make_tuple_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto make_tuple_inputs = node->inputs();
|
||||
// there are real inputs in call, should put it to make_tuple in switch_layer
|
||||
std::vector<AnfNodePtr> real_inputs;
|
||||
for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) {
|
||||
real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx)));
|
||||
}
|
||||
std::vector<AnfNodePtr> new_make_tuple_inputs = {
|
||||
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
|
||||
for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
|
||||
auto partial_idx = make_tuple_inputs[idx];
|
||||
MS_EXCEPTION_IF_NULL(cnode->abstract());
|
||||
std::vector<AnfNodePtr> new_partial_inputs;
|
||||
KernelGraphPtr partial_kernel_graph;
|
||||
// switch_layer node input is partial cnode
|
||||
if (common::AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
|
||||
auto partial_node = partial_idx->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_node);
|
||||
auto partial_input = partial_node->input(kFirstDataInputIndex);
|
||||
partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
|
||||
new_partial_inputs = partial_node->inputs();
|
||||
} else if (IsValueNode<KernelGraph>(partial_idx)) { // switch_layer node input is kernel graph value node
|
||||
new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
|
||||
new_partial_inputs.emplace_back(partial_idx);
|
||||
partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_idx);
|
||||
}
|
||||
// when branch in swich_layer return function
|
||||
MS_EXCEPTION_IF_NULL(partial_kernel_graph);
|
||||
auto ret = partial_kernel_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto return_input = ret->input(kFirstDataInputIndex);
|
||||
if (common::AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
|
||||
ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
|
||||
}
|
||||
// partial node add input args
|
||||
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
|
||||
// create new partial node
|
||||
auto new_partial = graph->NewCNode(new_partial_inputs);
|
||||
new_make_tuple_inputs.emplace_back(new_partial);
|
||||
}
|
||||
auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
|
||||
auto abstract = make_tuple_node->abstract();
|
||||
if (abstract == nullptr) {
|
||||
abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
|
||||
}
|
||||
new_make_tuple->set_abstract(abstract);
|
||||
switch_layer_inputs.emplace_back(new_make_tuple);
|
||||
auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
|
||||
cnode_inputs.emplace_back(new_switch_layer);
|
||||
return cnode_inputs;
|
||||
}
|
||||
|
||||
// Normalizes one Switch branch input into a partial cnode. A partial maps straight to its
// backend node; a func-graph value node is wrapped in a new partial; any other value is
// wrapped in a fresh single-parameter kernel graph that simply returns it.
CNodePtr KernelGraphUtils::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
  // Fix: cnode was never null-checked in the original although it is dereferenced below.
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(node_input);
  MS_EXCEPTION_IF_NULL(graph);
  // switch input generalizes partial
  std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
  if (common::AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
    auto backend_node = graph->GetBackendAnfByFrontAnf(node_input);
    // Fix: the lookup can return nullptr; the original cast it unchecked.
    MS_EXCEPTION_IF_NULL(backend_node);
    return backend_node->cast<CNodePtr>();
  } else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
    partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  } else {
    // Wrap a plain value in a trivial kernel graph: parameter -> return.
    KernelGraphPtr kernel_graph = NewKernelGraph();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(cnode->abstract());
    auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
    auto return_node = kernel_graph->NewCNode({primitive, parameter});
    return_node->set_abstract(cnode->abstract());
    kernel_graph->set_return(return_node);
    partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
    partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  }
  auto partial_node = graph->NewCNode(partial_inputs);
  return partial_node;
}
|
||||
|
||||
// Rewrites a graph whose return value is itself a function (partial / kernel graph /
// value node) so that the return becomes a concrete call of that function with the
// given real inputs. Skips nodes whose abstract is still AbstractFunction (not the
// final application in a chained call like func2 = func1(p1); out = func2(p2)).
void KernelGraphUtils::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
                                          const std::vector<AnfNodePtr> &real_inputs) {
  MS_EXCEPTION_IF_NULL(cnode);
  // func1 =switch(branch1, branch2)
  // func2 = func1(param1)
  // out = func2(param2)
  // process the last cnode(func2), not func1 which abstract is AbstractFunction
  // NOTE(review): cnode->abstract() is dereferenced without a null check here — confirm
  // callers guarantee a non-null abstract.
  if (cnode->abstract()->isa<abstract::AbstractFunction>()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto ret = graph->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  auto return_input = ret->input(kFirstDataInputIndex);
  // return node is a function
  std::vector<AnfNodePtr> call_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  if (common::AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
    // Inline the partial: its bound arguments become leading call arguments.
    auto return_input_cnode = return_input->cast<CNodePtr>();
    auto partial_inputs = return_input_cnode->inputs();
    call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
  } else if (IsValueNode<KernelGraph>(return_input)) {  // return node is kernel graph
    call_inputs.emplace_back(return_input);
  } else {  // return node is value node
    // Wrap the value in a new kernel graph taking one parameter per real input.
    KernelGraphPtr kernel_graph = NewKernelGraph();
    auto valid_inputs = kernel_graph->MutableValidInputs();
    MS_EXCEPTION_IF_NULL(valid_inputs);
    auto graph_inputs = kernel_graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(graph_inputs);
    std::vector<AnfNodePtr> cnode_inputs = {return_input};
    for (auto &real_input : real_inputs) {
      auto new_parameter = kernel_graph->NewParameter(real_input->abstract());
      valid_inputs->push_back(true);
      graph_inputs->push_back(new_parameter);
      cnode_inputs.push_back(new_parameter);
    }
    auto new_cnode = kernel_graph->NewCNode(cnode_inputs);
    new_cnode->set_abstract(cnode->abstract());
    std::vector<AnfNodePtr> return_inputs = {
      kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))), new_cnode};
    auto return_node = kernel_graph->NewCNode(return_inputs);
    return_node->set_abstract(cnode->abstract());
    kernel_graph->set_return(return_node);
    call_inputs.push_back(std::make_shared<ValueNode>(kernel_graph));
  }

  // new call node inputs
  for (auto &input_node : real_inputs) {
    auto parameter_for_input = CreateNewParameterFromCNode(input_node, graph);
    call_inputs.emplace_back(parameter_for_input);
  }

  auto call_node = graph->NewCNode(call_inputs);
  call_node->set_abstract(cnode->abstract());
  // update return input
  ret->set_input(kFirstDataInputIndex, call_node);
}
|
||||
} // namespace mindspore::infer
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_UTILS_KERNEL_GRAPH_UTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_UTILS_KERNEL_GRAPH_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "backend/common/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore::infer {
|
||||
using GraphId = uint32_t;
|
||||
class KernelGraphUtils {
|
||||
public:
|
||||
KernelGraphUtils() = default;
|
||||
virtual ~KernelGraphUtils() = default;
|
||||
|
||||
static KernelGraphUtils &Instance() {
|
||||
static KernelGraphUtils instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
KernelGraphPtr ConstructKernelGraph(const FuncGraphPtr &func_graph, std::vector<KernelGraphPtr> *all_out_graph,
|
||||
DeviceType device_target);
|
||||
KernelGraphPtr NewKernelGraph();
|
||||
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node);
|
||||
void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter);
|
||||
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
|
||||
bool CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph);
|
||||
void AddParameterToGraphInputs(const std::vector<AnfNodePtr> ¶meters, KernelGraph *graph);
|
||||
void SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager);
|
||||
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const;
|
||||
KernelGraphPtr GetGraph(mindspore::GraphId graph_id) const;
|
||||
AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);
|
||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph);
|
||||
bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, const uint32_t graph_id);
|
||||
bool IsShapeDynamic(const abstract::ShapePtr &shape);
|
||||
std::vector<AnfNodePtr> CreateValueNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||
std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
|
||||
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodePtr, int64_t> &kernel, size_t *idx);
|
||||
std::vector<AnfNodePtr> CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph);
|
||||
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
|
||||
CNodePtr CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph);
|
||||
void ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph, const std::vector<AnfNodePtr> &real_inputs);
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
static bool KernelGraphUtils::ExistSummaryNode(const KernelGraph *graph);
|
||||
#endif
|
||||
|
||||
private:
|
||||
mindspore::HashMap<FuncGraph *, KernelGraphPtr> front_backend_graph_map_;
|
||||
static GraphId graph_sum_;
|
||||
};
|
||||
} // namespace mindspore::infer
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_KERNEL_GRAPH_UTILS_H_
|
|
@ -171,8 +171,8 @@ set(MODEL_LOADER_FRAMEWORK_SRC
|
|||
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
||||
add_compile_definitions(ENABLE_CLOUD_FUSION_INFERENCE)
|
||||
|
||||
string(REPLACE "-Werror" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
|
||||
string(REPLACE "-Werror" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
# string(REPLACE "-Werror" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
|
||||
# string(REPLACE "-Werror" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
|
||||
set(MINDIR_MODEL_SRC
|
||||
${MINDIR_MODEL_SRC}
|
||||
|
|
|
@ -15,9 +15,11 @@ add_compile_definitions(MINDIR_EXPORT_TENSOR_LAYOUT_CLIP)
|
|||
add_compile_definitions(COMMON_DLL)
|
||||
|
||||
if(NOT ENABLE_CLOUD_AND_LITE)
|
||||
set(MINDIR_EXPORT_DIR ${CCSRC_DIR}/transform/express_ir)
|
||||
add_subdirectory(${MINDIR_EXPORT_DIR} mindir_exporter)
|
||||
add_dependencies(_mindspore_transform_express_ir_obj mindir_proto_mid)
|
||||
# if(NOT MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
||||
set(MINDIR_EXPORT_DIR ${CCSRC_DIR}/transform/express_ir)
|
||||
add_subdirectory(${MINDIR_EXPORT_DIR} mindir_exporter)
|
||||
add_dependencies(_mindspore_transform_express_ir_obj mindir_proto_mid)
|
||||
# endif()
|
||||
endif()
|
||||
|
||||
add_library(mindir_serializer_mid OBJECT
|
||||
|
|
Loading…
Reference in New Issue