From 5c6a70d8adecaa1730b43befb74101508985f363 Mon Sep 17 00:00:00 2001 From: wangpingan2 Date: Mon, 28 Mar 2022 15:35:58 +0800 Subject: [PATCH] [MS][LITE] add KernelExecutor to run single operator. --- cmake/package_lite.cmake | 28 + mindspore/core/base/complex_storage.h | 4 +- mindspore/core/utils/log_adapter.cc | 50 +- mindspore/core/utils/log_adapter.h | 5 +- mindspore/lite/CMakeLists.txt | 15 +- mindspore/lite/src/CMakeLists.txt | 7 +- mindspore/lite/src/common/log_adapter.h | 2 +- .../lite/src/common/primitive_t_utils.cc | 74 + mindspore/lite/src/common/primitive_t_utils.h | 32 + .../cxx_api/kernel_executor/CMakeLists.txt | 33 + .../kernel_executor/kernel_executor.cc | 61 + .../cxx_api/kernel_executor/kernel_executor.h | 73 + .../kernel_executor/kernel_executor_impl.cc | 253 ++++ .../kernel_executor/kernel_executor_impl.h | 56 + mindspore/lite/src/ops/anf_utils.cc | 44 + mindspore/lite/src/ops/anf_utils.h | 29 + mindspore/lite/src/ops/ops_utils.cc | 1307 +++-------------- mindspore/lite/src/ops/ops_utils.h | 24 +- .../lite/tools/anf_exporter/anf_exporter.cc | 2 +- .../lite/tools/anf_exporter/fetch_content.cc | 3 +- mindspore/lite/tools/common/node_util.cc | 11 - mindspore/lite/tools/common/node_util.h | 2 - mindspore/lite/tools/converter/CMakeLists.txt | 2 + mindspore/lite/tools/converter/converter.cc | 1 + .../legacy_optimizer/graph/infershape_pass.cc | 1 + .../tools/optimizer/const_fold/fold_utils.cc | 2 +- .../tools/optimizer/fusion/norm_fusion.cc | 2 +- .../tools/optimizer/graph/node_infershape.cc | 3 +- 28 files changed, 1003 insertions(+), 1123 deletions(-) create mode 100644 mindspore/lite/src/common/primitive_t_utils.cc create mode 100644 mindspore/lite/src/common/primitive_t_utils.h create mode 100644 mindspore/lite/src/cxx_api/kernel_executor/CMakeLists.txt create mode 100644 mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.cc create mode 100644 mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h create mode 100644 mindspore/lite/src/cxx_api/kernel_executor/kernel_executor_impl.cc create mode 100644 mindspore/lite/src/cxx_api/kernel_executor/kernel_executor_impl.h create mode 100644 mindspore/lite/src/ops/anf_utils.cc create mode 100644 mindspore/lite/src/ops/anf_utils.h diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 73a7782be2b..b581fba3bb6 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -642,6 +642,20 @@ if(PLATFORM_ARM64) endif() endif() endif() + if(MSLITE_ENABLE_KERNEL_EXECUTOR) + install(DIRECTORY ${TOP_DIR}/mindspore/core/ops/ DESTINATION ${RUNTIME_INC_DIR}/core/ops + COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") + install(DIRECTORY ${TOP_DIR}/mindspore/core/mindapi/ DESTINATION ${RUNTIME_INC_DIR}/core/mindapi + COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") + install(FILES ${TOP_DIR}/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h DESTINATION + ${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(TARGETS kernel_executor DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_CONVERTER) + install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${GLOG_DIR} RENAME libglog.so.0 + COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() + endif() elseif(PLATFORM_ARM32) if(SUPPORT_NPU) install(FILES ${DDK_LIB_PATH}/libhiai.so DESTINATION ${RUNTIME_DIR}/third_party/hiai_ddk/lib @@ -998,6 +1012,20 @@ else() DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() endif() + if(MSLITE_ENABLE_KERNEL_EXECUTOR) + install(DIRECTORY ${TOP_DIR}/mindspore/core/ops/ DESTINATION ${RUNTIME_INC_DIR}/core/ops + COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") + install(DIRECTORY ${TOP_DIR}/mindspore/core/mindapi/ DESTINATION ${RUNTIME_INC_DIR}/core/mindapi + COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") + install(FILES ${TOP_DIR}/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h DESTINATION + ${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(TARGETS kernel_executor DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_CONVERTER) + install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${GLOG_DIR} RENAME libglog.so.0 + COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() + endif() endif() if(CMAKE_SYSTEM_NAME MATCHES "Windows") diff --git a/mindspore/core/base/complex_storage.h b/mindspore/core/base/complex_storage.h index 189fa0f1be2..cd79044c45e 100644 --- a/mindspore/core/base/complex_storage.h +++ b/mindspore/core/base/complex_storage.h @@ -35,9 +35,9 @@ struct alignas(sizeof(T) * 2) ComplexStorage { ComplexStorage &operator=(ComplexStorage &&other) noexcept = default; inline constexpr ComplexStorage(const T &real, const T &imag = T()) : real_(real), imag_(imag) {} - +#ifndef ENABLE_ARM inline explicit constexpr ComplexStorage(const float16 &real) : real_(static_cast(real)), imag_(T()) {} - +#endif template explicit ComplexStorage(const std::enable_if_t::value, ComplexStorage> &other) : real_(other.real_), imag_(other.imag_) {} diff --git a/mindspore/core/utils/log_adapter.cc b/mindspore/core/utils/log_adapter.cc index 507dfab813d..a9eba06c40c 100644 --- a/mindspore/core/utils/log_adapter.cc +++ b/mindspore/core/utils/log_adapter.cc @@ -25,10 +25,18 @@ #include #include #include "utils/convert_utils_base.h" +#ifdef ENABLE_ARM +#if defined(__ANDROID__) || defined(ANDROID) +#include +#endif +#endif // namespace to support utils module definition namespace mindspore { constexpr int kNameMaxLength = 18; +#if defined(__ANDROID__) || defined(ANDROID) +constexpr const char *ANDROID_LOG_TAG = "MS_LITE"; +#endif std::map acl_handle_map; // set default log level to WARNING for all sub modules int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING}; @@ -100,6 +108,38 @@ static int GetThresholdLevel(const std::string &threshold) { } } #undef google +#elif defined(BUILD_CORE_RUNTIME) +const char *EnumStrForMsLogLevel(MsLogLevel level) { + if (level == MsLogLevel::DEBUG) { + return "DEBUG"; + } else if (level == MsLogLevel::INFO) { + return "INFO"; + } else if (level == MsLogLevel::WARNING) { + return "WARNING"; + } else if (level == MsLogLevel::ERROR) { + return "ERROR"; + } else { + return "NO_LEVEL"; + } +} +#ifdef ENABLE_ARM +#if defined(__ANDROID__) || defined(ANDROID) +static int GetAndroidLogLevel(MsLogLevel level) { + switch (level) { + case MsLogLevel::DEBUG: + return ANDROID_LOG_DEBUG; + case MsLogLevel::INFO: + return ANDROID_LOG_INFO; + case MsLogLevel::WARNING: + return ANDROID_LOG_WARN; + case MsLogLevel::ERROR: + default: + return ANDROID_LOG_ERROR; + } +} +#endif +#endif + #else #undef Dlog @@ -153,6 +193,14 @@ void LogWriter::OutputLog(const std::ostringstream &msg) const { << std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " " << "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl; #undef google +#elif defined(BUILD_CORE_RUNTIME) +#if defined(ENABLE_ARM) && (defined(__ANDROID__) || defined(ANDROID)) + __android_log_print(GetAndroidLogLevel(log_level_), ANDROID_LOG_TAG, "[%s:%d] %s] %s", location_.file_, + location_.line_, location_.func_, msg.str().c_str()); +#else + printf("%s [%s:%d] %s] %s\n", EnumStrForMsLogLevel(log_level_), location_.file_, location_.line_, location_.func_, + msg.str().c_str()); +#endif #else auto str_msg = msg.str(); auto slog_module_id = (submodule_ == SM_MD ? MD : ME); @@ -166,7 +214,7 @@ void LogWriter::operator<(const LogStream &stream) const noexcept { msg << stream.sstream_->rdbuf(); OutputLog(msg); } -#ifndef BUILD_LITE_INFERENCE +#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME) void LogWriter::operator^(const LogStream &stream) const { std::ostringstream msg; msg << stream.sstream_->rdbuf(); diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index 6df6e3819f7..985eabce0cd 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -34,6 +34,7 @@ #define google mindspore_private #include "glog/logging.h" #undef google +#elif defined(BUILD_CORE_RUNTIME) #else #include "toolchain/slog.h" #endif @@ -238,7 +239,7 @@ class MS_CORE_API LogWriter { /// \param[in] stream The input log stream. void operator<(const LogStream &stream) const noexcept; -#ifndef BUILD_LITE_INFERENCE +#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME) /// \brief Output log message from the input log stream and then throw exception. /// /// \param[in] stream The input log stream. @@ -266,7 +267,7 @@ class MS_CORE_API LogWriter { : mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), level, SUBMODULE_ID, \ excp_type) < mindspore::LogStream() -#ifndef BUILD_LITE_INFERENCE +#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME) #define MSLOG_THROW(excp_type) \ mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), mindspore::EXCEPTION, SUBMODULE_ID, \ excp_type) ^ \ diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index c52e1346c8f..71642e16756 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -56,6 +56,7 @@ option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off) option(MSLITE_ENABLE_EXPERIMENTAL_KERNEL "enable experimental kernel" on) option(MSLITE_ENABLE_GRAPH_KERNEL "enable graph kernel" off) option(MSLITE_ENABLE_CONVERT_PYTORCH_MODEL "enable to convert pytorch model" off) +option(MSLITE_ENABLE_KERNEL_EXECUTOR "enable kernel executor" off) #Option that can be configured through manually option(ENABLE_VERBOSE "" off) @@ -175,6 +176,9 @@ endif() if(DEFINED ENV{MSLITE_ENABLE_SERVING}) set(MSLITE_ENABLE_SERVING $ENV{MSLITE_ENABLE_SERVING}) endif() +if(DEFINED ENV{MSLITE_ENABLE_KERNEL_EXECUTOR}) + set(MSLITE_ENABLE_KERNEL_EXECUTOR $ENV{MSLITE_ENABLE_KERNEL_EXECUTOR}) +endif() if(DEFINED ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL} AND DEFINED ENV{LIB_TORCH_PATH}) set(ENABLE_CONVERT_PYTORCH_MODEL $ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL}) @@ -404,6 +408,7 @@ message(STATUS "\tMSLITE_ENABLE_PARALLEL_INFERENCE = \t${MSLITE_ENABLE message(STATUS "\tMSLITE_ENABLE_SHARING_MODEL_WEIGHT = \t${MSLITE_ENABLE_SHARING_MODEL_WEIGHT}") message(STATUS "\tMSLITE_ENABLE_EXPERIMENTAL_KERNEL = \t${MSLITE_ENABLE_EXPERIMENTAL_KERNEL}") message(STATUS "\tMSLITE_ENABLE_GRAPH_KERNEL = \t${MSLITE_ENABLE_GRAPH_KERNEL}") +message(STATUS "\tMSLITE_ENABLE_KERNEL_EXECUTOR = \t${MSLITE_ENABLE_KERNEL_EXECUTOR}") if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TESTCASES) AND ( NOT MSLITE_ENABLE_MINDRT @@ -531,7 +536,7 @@ if(MSLITE_GPU_BACKEND STREQUAL opencl) endif() if(MSLITE_ENABLE_CONVERTER OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper" - OR MSLITE_ENABLE_TOOLS) + OR MSLITE_ENABLE_TOOLS OR MSLITE_ENABLE_KERNEL_EXECUTOR) if(NOT ENABLE_CLOUD_AND_LITE) include(${TOP_DIR}/cmake/external_libs/json.cmake) endif() @@ -635,6 +640,13 @@ function(find_required_package pkg_name) endif() endfunction() +if(MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_KERNEL_EXECUTOR) + find_required_package(Patch) + if(NOT ENABLE_CLOUD_AND_LITE) + include(${TOP_DIR}/cmake/external_libs/protobuf.cmake) + endif() +endif() + if(MSLITE_ENABLE_CONVERTER) if(ENABLE_FAST_HASH_TABLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_FAST_HASH_TABLE=1") @@ -645,7 +657,6 @@ if(MSLITE_ENABLE_CONVERTER) if(NOT ENABLE_CLOUD_AND_LITE) include(${TOP_DIR}/cmake/external_libs/opencv.cmake) include(${TOP_DIR}/cmake/external_libs/eigen.cmake) - include(${TOP_DIR}/cmake/external_libs/protobuf.cmake) endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) endif() diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 87e2eff5688..d81ecf40a71 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -187,7 +187,7 @@ if(MSLITE_ENABLE_RUNTIME_GLOG) add_definitions(-DUSE_GLOG) string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS}) string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - if(NOT MSLITE_ENABLE_RUNTIME_CONVERT) + if(NOT MSLITE_ENABLE_RUNTIME_CONVERT AND NOT MSLITE_ENABLE_KERNEL_EXECUTOR) set(LITE_SRC ${LITE_SRC} ${CORE_DIR}/utils/log_adapter.cc) endif() @@ -199,6 +199,7 @@ if(MSLITE_ENABLE_RUNTIME_CONVERT) file(GLOB RUNTIME_CONVERT_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/ops_def.cc ${CMAKE_CURRENT_SOURCE_DIR}/ops/ops_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/ops/anf_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_convert.cc) set(LITE_SRC ${LITE_SRC} ${RUNTIME_CONVERT_SRC}) @@ -466,6 +467,10 @@ if(SUPPORT_TRAIN) endif() endif() +if(MSLITE_ENABLE_KERNEL_EXECUTOR) + add_subdirectory(cxx_api/kernel_executor) +endif() + ########################## build optimize and float16 library ################################# if(PLATFORM_ARM) if(PLATFORM_ARM64 AND NOT TARGET_HIMIX AND NOT MACHINE_LINUX_ARM64) diff --git a/mindspore/lite/src/common/log_adapter.h b/mindspore/lite/src/common/log_adapter.h index 1da3de29e80..4178bb94b8d 100644 --- a/mindspore/lite/src/common/log_adapter.h +++ b/mindspore/lite/src/common/log_adapter.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_ #define MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_ -#ifdef USE_GLOG +#if defined(USE_GLOG) || defined(BUILD_CORE_RUNTIME) #include "utils/log_adapter.h" #else #include "src/common/log.h" diff --git a/mindspore/lite/src/common/primitive_t_utils.cc b/mindspore/lite/src/common/primitive_t_utils.cc new file mode 100644 index 00000000000..ddae86e392a --- /dev/null +++ b/mindspore/lite/src/common/primitive_t_utils.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/primitive_t_utils.h" +#include "src/ops/ops_utils.h" +#include "ops/primitive_c.h" + +namespace mindspore { +namespace lite { +constexpr size_t INITIAL_SIZE = 1024; +const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) { + if (primitive_t == nullptr || fbb == nullptr) { + MS_LOG(ERROR) << "primitiveT or fbb is nullptr."; + return nullptr; + } + auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t); + fbb->Finish(prim_offset); + auto prim_buf = fbb->GetBufferPointer(); + return flatbuffers::GetRoot(prim_buf); +} + +OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t) { + flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE); + auto primitive = ConvertToPrimitive(primitive_t, &fbb); + fbb.Clear(); + auto prim_type = GetPrimitiveType(primitive, SCHEMA_VERSION::SCHEMA_CUR); + auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(prim_type, SCHEMA_VERSION::SCHEMA_CUR); + if (parame_gen == nullptr) { + MS_LOG(ERROR) << "parameter generator is nullptr."; + return nullptr; + } + auto parameter = parame_gen(primitive); + if (parameter == nullptr) { + MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " + << GetPrimitiveTypeName(primitive, SCHEMA_VERSION::SCHEMA_CUR); + } + return parameter; +} + +std::unique_ptr GetPrimitiveT(const std::shared_ptr &op) { + if (op == nullptr) { + MS_LOG(DEBUG) << "base operator is nullptr"; + return nullptr; + } + + if (op->name().empty()) { + MS_LOG(ERROR) << "the name of operator is null"; + return nullptr; + } + + MS_LOG(DEBUG) << "export operator: " << op->name(); + auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(op->name()); + if (creator != nullptr) { + return creator(op->GetPrim()); + } else { + MS_LOG(WARNING) << "can not find SingleOpRegistry for operator: " << op->name(); + return nullptr; + } +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/common/primitive_t_utils.h b/mindspore/lite/src/common/primitive_t_utils.h new file mode 100644 index 00000000000..d5a88ba1407 --- /dev/null +++ b/mindspore/lite/src/common/primitive_t_utils.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_ +#define MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_ + +#include +#include "schema/inner/model_generated.h" +#include "src/ops/populate/populate_register.h" +#include "ops/base_operator.h" + +namespace mindspore { +namespace lite { +const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb); +OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t); +std::unique_ptr GetPrimitiveT(const std::shared_ptr &op); +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_ diff --git a/mindspore/lite/src/cxx_api/kernel_executor/CMakeLists.txt b/mindspore/lite/src/cxx_api/kernel_executor/CMakeLists.txt new file mode 100644 index 00000000000..516f411894a --- /dev/null +++ b/mindspore/lite/src/cxx_api/kernel_executor/CMakeLists.txt @@ -0,0 +1,33 @@ +add_compile_definitions(BUILD_CORE_RUNTIME) +add_definitions(-DPRIMITIVE_WRITEABLE) +if(MSLITE_ENABLE_RUNTIME_GLOG) + set(USE_GLOG on) + add_definitions(-DUSE_GLOG) +endif() +string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS}) +string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) +string(REPLACE "-fno-exceptions" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS}) +string(REPLACE "-fno-exceptions" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) +if(NOT MSLITE_ENABLE_CONVERTER) + add_subdirectory(${CORE_DIR} mindspore_core) +endif() + +add_library(kernel_executor SHARED + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_executor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/kernel_executor_impl.cc + ${TOP_DIR}/mindspore/lite/src/ops/ops_utils.cc + ${TOP_DIR}/mindspore/lite/src/common/primitive_t_utils.cc + ${TOP_DIR}/mindspore/lite/src/ops/ops_def.cc) + +add_dependencies(kernel_executor fbs_inner_src fbs_src mindspore_core) + +target_link_libraries(kernel_executor + mindspore-lite + mindspore_core + mindspore::json + mindspore::protobuf + mindspore::flatbuffers) + +if(USE_GLOG) + target_link_libraries(kernel_executor mindspore::glog) +endif() diff --git a/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.cc b/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.cc new file mode 100644 index 00000000000..13ff57d4109 --- /dev/null +++ b/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/cxx_api/kernel_executor/kernel_executor.h" +#include "src/cxx_api/kernel_executor/kernel_executor_impl.h" + +namespace mindspore { +Status KernelExecutor::Build(const std::shared_ptr &op, const std::vector &inputs, + const std::vector &outputs, const std::shared_ptr &ms_context) { + if (impl_ == nullptr) { + impl_ = std::make_shared(); + if (impl_ == nullptr) { + MS_LOG(ERROR) << "implement is null."; + return kLiteNullptr; + } + } + + Status ret = impl_->Build(op, inputs, outputs, ms_context); + if (ret != kSuccess) { + return ret; + } + return kSuccess; +} + +Status KernelExecutor::ReSize(const std::vector &inputs, const std::vector &outputs) { + if (impl_ == nullptr) { + MS_LOG(ERROR) << "implement is null."; + return kLiteNullptr; + } + return impl_->ReSize(inputs, outputs); +} + +Status KernelExecutor::Infer(std::vector *outputs) { + if (impl_ == nullptr) { + MS_LOG(ERROR) << "implement is null."; + return kLiteNullptr; + } + return impl_->Infer(outputs); +} + +Status KernelExecutor::Execute(const std::vector &inputs, const std::vector &outputs) { + if (impl_ == nullptr) { + MS_LOG(ERROR) << "implement is null."; + return kLiteNullptr; + } + return impl_->Execute(inputs, outputs); +} +} // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h b/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h new file mode 100644 index 00000000000..7cf8e3377fa --- /dev/null +++ b/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h @@ -0,0 +1,73 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H +#define MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H + +#include +#include +#include "include/api/types.h" +#include "include/api/status.h" +#include "include/api/context.h" +#include "ops/base_operator.h" + +namespace mindspore { +class KernelExecutorImpl; + +class MS_API KernelExecutor { + public: + KernelExecutor() = default; + ~KernelExecutor() = default; + + /// \brief Build a single operator so that it can run on a device. + /// + /// \param[in] op Define an operator pointer. + /// \param[in] ms_context Define the context used to store options during execution. + /// \param[in] inputs A vector where single operator inputs are arranged in sequence. + /// \param[in] outputs A vector where single operator outputs are arranged in sequence. + /// + /// \return Status. + Status Build(const std::shared_ptr &op, const std::vector &inputs, + const std::vector &outputs, const std::shared_ptr &ms_context); + + /// \brief ReSize KernelExecutor. + /// + /// \param[in] inputs A vector where single operator inputs are arranged in sequence. + /// \param[in] outputs A vector where single operator outputs are arranged in sequence. + /// + /// \return Status. + Status ReSize(const std::vector &inputs, const std::vector &outputs); + + /// \brief set outputs infer shape info. + /// + /// \param[in] outputs A vector where single operator outputs are arranged in sequence. + /// + /// \return Status. + Status Infer(std::vector *outputs); + + /// \brief ReSize KernelExecutor. + /// + /// \param[in] inputs A vector where single operator inputs are arranged in sequence. + /// \param[in] outputs A vector where single operator outputs are arranged in sequence. + /// + /// \return Status. + Status Execute(const std::vector &inputs, const std::vector &outputs); + + private: + std::shared_ptr impl_ = nullptr; +}; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H diff --git a/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor_impl.cc b/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor_impl.cc new file mode 100644 index 00000000000..4077a7a563b --- /dev/null +++ b/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor_impl.cc @@ -0,0 +1,253 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/ops/ops_utils.h" +#include "src/cxx_api/converters.h" +#include "src/common/prim_util.h" +#include "src/ops/populate/populate_register.h" +#include "src/common/primitive_t_utils.h" +#include "schema/inner/model_generated.h" +#include "src/runtime/infer_manager.h" +#include "src/kernel_registry.h" +#include "src/cxx_api/kernel_executor/kernel_executor_impl.h" + +namespace mindspore { +constexpr size_t INITIAL_SIZE = 1024; + +KernelExecutorImpl::~KernelExecutorImpl() { + if (context_ != nullptr) { + delete context_; + context_ = nullptr; + } + + if (kernel_ != nullptr) { + delete kernel_; + kernel_ = nullptr; + } + FreeInOutTensor(); +} + +Status KernelExecutorImpl::Build(const std::shared_ptr &op, const std::vector &inputs, + const std::vector &outputs, const std::shared_ptr &ms_context) { + data_type_ = static_cast(inputs[FIRST_INPUT].DataType()); + std::unique_ptr prim_t = lite::GetPrimitiveT(op); + flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE); + primitive_ = lite::ConvertToPrimitive(prim_t.get(), &fbb); + fbb.Clear(); + if (primitive_ == nullptr) { + MS_LOG(ERROR) << "convert to primitive nullptr."; + return kLiteNullptr; + } + prim_type_ = lite::GetPrimitiveType(primitive_, schema_version_); + + context_ = ContextUtils::Convert(ms_context.get()); + if (context_ == nullptr) { + MS_LOG(ERROR) << "failed to convert Context to LiteContext."; + return kLiteNullptr; + } + int ret = context_->Init(); + if (ret != RET_OK) { + return static_cast(ret); + } + + Status status = InitInOutTensor(inputs, outputs); + if (status != kSuccess) { + MS_LOG(ERROR) << "InitInOutTensor error."; + return status; + } + + if (prim_type_ == schema::PrimitiveType_Custom) { + status = GetCustomKernel(ms_context); + } else { + status = GetCpuKernel(ms_context); + } + + if (status != kSuccess) { + MS_LOG(ERROR) << "get kernel error."; + return status; + } + ret = kernel_->Prepare(); + return static_cast(ret); +} + +Status KernelExecutorImpl::ReSize(const std::vector &inputs, const std::vector &outputs) { + Status status = InitInOutTensor(inputs, outputs); + if (status != kSuccess) { + MS_LOG(ERROR) << "InitInOutTensor error."; + return status; + } + kernel_->set_in_tensors(inputs_); + kernel_->set_out_tensors(outputs_); + int ret; + if (kernel_->type() == schema::PrimitiveType_Custom) { + ret = KernelInferShape(inputs_, outputs_, primitive_, context_->GetProviders(), schema_version_); + } else { + ret = KernelInferShape(inputs_, outputs_, parameter_); + } + if (ret != RET_OK) { + MS_LOG(ERROR) << "do infer shape error."; + return static_cast(ret); + } + ret = kernel_->ReSize(); + return static_cast(ret); +} +Status KernelExecutorImpl::Infer(std::vector *outputs) { + for (size_t i = 0; i < outputs->size(); ++i) { + auto user_output = outputs->at(i); + auto output = outputs_[i]; + user_output.SetFormat(output->format()); + auto output_shape = output->shape(); + std::vector shape; + std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(shape), + [](auto s) { return static_cast(s); }); + user_output.SetShape(shape); + } + return kSuccess; +} + +Status KernelExecutorImpl::Execute(const std::vector &inputs, const std::vector &outputs) { + for (size_t i = 0; i < inputs.size(); ++i) { + auto user_input = inputs[i]; + auto input = inputs_[i]; + input->set_data(user_input.MutableData()); + } + + for (size_t i = 0; i < outputs.size(); ++i) { + auto user_output = outputs[i]; + auto output = outputs_[i]; + output->set_data(user_output.MutableData()); + } + int ret = kernel_->Execute(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "execute error."; + return static_cast(ret); + } + + return kSuccess; +} + +Status KernelExecutorImpl::GetOpParameter() { + auto parame_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim_type_, schema_version_); + if (parame_gen == nullptr) { + MS_LOG(ERROR) << "parameter generator is nullptr."; + return kLiteNullptr; + } + parameter_ = parame_gen(primitive_); + if (parameter_ == nullptr) { + MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " + << lite::GetPrimitiveTypeName(primitive_, schema_version_); + return kLiteNullptr; + } + return kSuccess; +} + +Status KernelExecutorImpl::GetCustomKernel(const std::shared_ptr &ms_context) { + int get_kernel = lite::RET_ERROR; + + // find kernel match arch, data_type, kernel_arch and provider + for (auto &&device : context_->device_list_) { + if (!device.provider_.empty() && !device.provider_device_.empty()) { + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_, device.provider_device_, + device.provider_}; + get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc, + nullptr, &kernel_, primitive_); + } + } + + // find kernel only match arch and data_type + if (get_kernel != RET_OK) { + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_, "", ""}; + get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc, + nullptr, &kernel_, primitive_); + } + + // if found kernel, do infershape + if (get_kernel == RET_OK) { + int ret = KernelInferShape(inputs_, outputs_, primitive_, context_->GetProviders(), schema_version_); + return static_cast(ret); + } + + return static_cast(get_kernel); +} + +Status KernelExecutorImpl::GetCpuKernel(const std::shared_ptr &ms_context) { + Status status = GetOpParameter(); + if (status != kSuccess) { + return status; + } + + kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_}; + int get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc, + parameter_, &kernel_); + if (get_kernel == RET_OK) { + int ret = KernelInferShape(inputs_, outputs_, parameter_); + return static_cast(ret); + } + + return static_cast(get_kernel); +} + +void KernelExecutorImpl::FreeInOutTensor() { + for (auto &input : inputs_) { + if (input != nullptr) { + delete input; + input = nullptr; + } + } + inputs_.clear(); + for (auto &output : outputs_) { + if (output != nullptr) { + delete output; + output = nullptr; + } + } + outputs_.clear(); +} + +Status KernelExecutorImpl::InitInOutTensor(const std::vector &inputs, const std::vector &outputs) { + FreeInOutTensor(); + for (auto input : inputs) { + auto input_shape = input.Shape(); + std::vector shape; + std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shape), + [](auto s) { return static_cast(s); }); + lite::Tensor *input_tensor = new (std::nothrow) + lite::Tensor(static_cast(input.DataType()), shape, input.format(), lite::Category::GRAPH_INPUT); + if (input_tensor == nullptr) { + delete input_tensor; + return kLiteNullptr; + } + input_tensor->set_data(input.MutableData()); + inputs_.emplace_back(input_tensor); + } + + for (auto output : outputs) { + auto output_shape = output.Shape(); + std::vector shape; + std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(shape), + [](auto s) { return static_cast(s); }); + lite::Tensor *output_tensor = + new (std::nothrow) lite::Tensor(static_cast(output.DataType()), shape, output.format()); + if (output_tensor == nullptr) { + delete output_tensor; + return kLiteNullptr; + } + outputs_.emplace_back(output_tensor); + } + return kSuccess; +} +} // namespace mindspore diff --git a/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor_impl.h b/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor_impl.h new file mode 100644 index 00000000000..28cd1cb038d --- /dev/null +++ b/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor_impl.h @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H +#define MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H + +#include +#include +#include "src/cxx_api/kernel_executor/kernel_executor.h" +#include "src/kernel_exec.h" +#include "common/version_manager.h" + +namespace mindspore { +class KernelExecutorImpl { + public: + KernelExecutorImpl() = default; + ~KernelExecutorImpl(); + Status Build(const std::shared_ptr &op, const std::vector &inputs, + const std::vector &outputs, const std::shared_ptr &ms_context); + Status ReSize(const std::vector &inputs, const std::vector &outputs); + Status Infer(std::vector *outputs); + Status Execute(const std::vector &inputs, const std::vector &outputs); + + protected: + Status GetCustomKernel(const std::shared_ptr &ms_context); + Status GetCpuKernel(const std::shared_ptr &ms_context); + Status GetOpParameter(); + Status InitInOutTensor(const std::vector &inputs, const std::vector &outputs); + void FreeInOutTensor(); + + private: + const schema::Primitive *primitive_ = nullptr; + int prim_type_; + OpParameter *parameter_ = nullptr; + lite::InnerContext *context_ = nullptr; + TypeId data_type_; + kernel::KernelExec *kernel_ = nullptr; + std::vector inputs_; + std::vector outputs_; + int schema_version_ = lite::SCHEMA_VERSION::SCHEMA_CUR; +}; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H diff --git a/mindspore/lite/src/ops/anf_utils.cc b/mindspore/lite/src/ops/anf_utils.cc new file mode 100644 index 00000000000..b1e3f3c5357 --- /dev/null +++ b/mindspore/lite/src/ops/anf_utils.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/anf_utils.h" +#ifdef PRIMITIVE_WRITEABLE +namespace mindspore { +namespace lite { +std::unique_ptr GetPrimitiveT(const AnfNodePtr &node) { + auto prim = GetValueNode>(node); + if (prim == nullptr) { + MS_LOG(DEBUG) << "primitive is nullptr"; + return nullptr; + } + + if (prim->name().empty()) { + MS_LOG(ERROR) << "the name of primitive is null"; + return nullptr; + } + + MS_LOG(DEBUG) << "export prim: " << prim->name(); + auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(prim->name()); + if (creator != nullptr) { + return creator(prim); + } else { + MS_LOG(WARNING) << "can not find MSOpsRegistry for op: " << prim->name(); + return nullptr; + } +} +} // namespace lite +} // namespace mindspore +#endif diff --git a/mindspore/lite/src/ops/anf_utils.h b/mindspore/lite/src/ops/anf_utils.h new file mode 100644 index 00000000000..2382b151d72 --- /dev/null +++ b/mindspore/lite/src/ops/anf_utils.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_ +#define MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_ + +#include +#include "src/ops/ops_utils.h" +#ifdef PRIMITIVE_WRITEABLE +#include "abstract/ops/primitive_infer_map.h" +namespace mindspore { +namespace lite { +std::unique_ptr GetPrimitiveT(const mindspore::AnfNodePtr &node); +} +} // namespace mindspore +#endif +#endif // MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_ diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index 16f673b9dfe..14c28df3685 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -18,862 +18,22 @@ #include #include "src/ops/ops_utils.h" #include "mindapi/base/shared_ptr.h" - #ifdef PRIMITIVE_WRITEABLE -#include "mindspore/core/ir/anf.h" +#include "ops/primitive_c.h" namespace mindspore { namespace lite { -std::unique_ptr GetPrimitiveT(const AnfNodePtr &node) { - auto prim = GetValueNode>(node); - if (prim == nullptr) { - MS_LOG(DEBUG) << "primitive is nullptr"; - return nullptr; - } - - if (prim->name().empty()) { - MS_LOG(ERROR) << "the name of primitive is null"; - return nullptr; - } - - MS_LOG(DEBUG) << "export prim: " << prim->name(); - auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(prim->name()); - if (creator != nullptr) { - return creator(node); - } else { - MS_LOG(WARNING) << "can not find MSOpsRegistry for op: " << prim->name(); - return nullptr; - } -} - -template -api::SharedPtr GetOperator(const AnfNodePtr &node) { - auto prim = GetValueNode(node); - if (prim == nullptr) { - return nullptr; - } - return api::MakeShared(prim); -} - -std::unique_ptr AbsPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AbsGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ActivationPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ActivationGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AdamPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AdderFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AddFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AddGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AddNPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AllPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ApplyMomentumPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ArgMaxFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ArgMinFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AssertPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AssignPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AssignAddPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AudioSpectrogramPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AvgPoolFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr AvgPoolGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr BatchNormPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr BatchToSpacePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr BatchToSpaceNDPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr BiasAddPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr BiasAddGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr BNGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr BroadcastToPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr CastPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr CeilPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ClipPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ConcatPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ConstantOfShapePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr Conv2DBackpropFilterFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr Conv2DBackpropInputFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr Conv2DFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr Conv2dTransposeFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr CosPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr CropPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr CropAndResizePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr CustomExtractFeaturesPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr CustomNormalizePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr CustomPredictPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr DependPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr DepthToSpacePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr DetectionPostProcessPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr DivFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr DivGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr DropoutPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr DropoutGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr GRUPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr EltwisePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr EluPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr EmbeddingLookupFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr EqualPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ExpandDimsPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ExpFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FftImagPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FftRealPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FillPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FlattenPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FlattenGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FloorPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FloorDivPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FloorModPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FullConnectionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr FusedBatchNormPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr GatherPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr GatherNdPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr GreaterPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr GreaterEqualPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr HashtableLookupPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr InstanceNormPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr InvertPermutationPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LayerNormFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LayerNormGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LeakyReluPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LessPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LessEqualPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LogPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LogGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LogicalAndPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LogicalNotPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LogicalOrPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LrnPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LpNormalizationPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LshProjectionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LSTMPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LSTMGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LSTMGradDataPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr LSTMGradWeightPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr L2NormalizeFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MatMulFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MaximumPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MaximumGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MaxPoolFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MaxPoolGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SwitchLayerPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MfccPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MinimumPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MinimumGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ModPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MulFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr MulGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr NegPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr NegGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr NotEqualPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr NonMaxSuppressionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr OneHotPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr OnesLikePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr PadFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr PartialFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr PowerGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr PowFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr PReLUFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr QuantDTypeCastPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RaggedRangePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RangePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RandomStandardNormalPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RankPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RealDivPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ReciprocalPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ReduceFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ReshapePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ResizePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ResizeGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ReverseV2PrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ReverseSequencePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RfftPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ROIPoolingPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RoundPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RsqrtPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr RsqrtGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ScaleFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ScatterNdPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SelectPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SGDPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ShapePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SigmoidCrossEntropyWithLogitsPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SigmoidCrossEntropyWithLogitsGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SinPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SizePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SkipGramPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SliceFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SmoothL1LossPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SmoothL1LossGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SoftmaxPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SoftmaxCrossEntropyWithLogitsPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SpaceToBatchPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SpaceToBatchNDPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SpaceToDepthPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SparseSoftmaxCrossEntropyWithLogitsPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SparseToDensePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SplitPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SqrtPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SqrtGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SquarePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SquaredDifferencePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SqueezePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr StackPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr StridedSlicePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr StridedSliceGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SubFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SubGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr SwitchPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr TensorListFromTensorPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr TensorListGetItemPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr TensorListReservePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr TensorListSetItemPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr TensorListStackPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr TileFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr TopKFusionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr TransposePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr UniquePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr UnstackPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr UnsortedSegmentSumPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr UnsqueezePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr WherePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ZerosLikePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} -std::unique_ptr ErfPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr SplicePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr LogSoftmaxPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr CallPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr CumSumPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr SplitWithOverlapPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr GluPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr TensorArrayPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr TensorArrayReadPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr TensorArrayWritePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr AffinePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr AttentionPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr ScatterNdUpdatePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr AllGatherPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr ReduceScatterPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr DynamicQuantPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr RandomNormalPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr NLLLossPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr NLLLossGradPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr FormatTransposePrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); - return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; -} - -std::unique_ptr CustomPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); +namespace ops { +std::unique_ptr MSOp2SchemaOp(const mindspore::ops::Custom *op) { auto schema_op = std::make_unique(); if (schema_op == nullptr) { return nullptr; } - if (ms_primc->GetAttr("type") != nullptr) { - schema_op->type = ms_primc->get_type(); + if (op->GetAttr("type") != nullptr) { + schema_op->type = op->get_type(); } - if (ms_primc->GetAttr("attr") != nullptr) { - auto attr_map = ms_primc->get_attr(); + if (op->GetAttr("attr") != nullptr) { + auto attr_map = op->get_attr(); for (const auto &attr_item : attr_map) { auto attr = std::make_unique(); if (attr == nullptr) { @@ -884,7 +44,6 @@ std::unique_ptr CustomPrimitiveCreator(const AnfNodePtr &nod schema_op->attr.emplace_back(std::move(attr)); } } - auto prim = std::make_unique(); if (prim == nullptr) { return nullptr; @@ -893,250 +52,222 @@ std::unique_ptr CustomPrimitiveCreator(const AnfNodePtr &nod prim->value.type = schema::PrimitiveType_Custom; return prim; } +} // namespace ops -std::unique_ptr UniformRealPrimitiveCreator(const AnfNodePtr &node) { - auto ms_primc = GetOperator(node); +template +std::unique_ptr PrimitiveCreator(const PrimitivePtr &primitive) { + auto ms_primc = api::MakeShared(primitive); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } -RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); -RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); -RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); -RegistryMSOps g_activationGradPrimitiveCreatorRegistry("ActivationGrad", ActivationGradPrimitiveCreator); -RegistryMSOps g_reluGradPrimitiveCreatorRegistry("ReluGrad", ActivationGradPrimitiveCreator); // ? -RegistryMSOps g_addPrimitiveCreatorRegistry("Add", AddFusionPrimitiveCreator); -RegistryMSOps g_addFusionPrimitiveCreatorRegistry("AddFusion", AddFusionPrimitiveCreator); -RegistryMSOps g_addGradPrimitiveCreatorRegistry("AddGrad", AddGradPrimitiveCreator); -RegistryMSOps g_adamPrimitiveCreatorRegistry("Adam", AdamPrimitiveCreator); -RegistryMSOps g_adderPrimitiveCreatorRegistry("Adder", AdderFusionPrimitiveCreator); -RegistryMSOps g_adderFusionPrimitiveCreatorRegistry("AdderFusion", AdderFusionPrimitiveCreator); -RegistryMSOps g_addNPrimitiveCreatorRegistry("AddN", AddNPrimitiveCreator); -RegistryMSOps g_allPrimitiveCreatorRegistry("All", AllPrimitiveCreator); -RegistryMSOps g_applyMomentumPrimitiveCreatorRegistry("ApplyMomentum", ApplyMomentumPrimitiveCreator); -RegistryMSOps g_argMaxPrimitiveCreatorRegistry("ArgMax", ArgMaxFusionPrimitiveCreator); -RegistryMSOps g_argMaxFusionPrimitiveCreatorRegistry("ArgMaxFusion", ArgMaxFusionPrimitiveCreator); -RegistryMSOps g_argMinPrimitiveCreatorRegistry("ArgMin", ArgMinFusionPrimitiveCreator); -RegistryMSOps g_argMinFusionPrimitiveCreatorRegistry("ArgMinFusion", ArgMinFusionPrimitiveCreator); -RegistryMSOps g_assertPrimitiveCreatorRegistry("Assert", AssertPrimitiveCreator); -RegistryMSOps g_assignPrimitiveCreatorRegistry("Assign", AssignPrimitiveCreator); -RegistryMSOps g_assignAddPrimitiveCreatorRegistry("AssignAdd", AssignAddPrimitiveCreator); -RegistryMSOps g_audioSpectrogramPrimitiveCreatorRegistry("AudioSpectrogram", AudioSpectrogramPrimitiveCreator); -RegistryMSOps g_avgPoolPrimitiveCreatorRegistry("AvgPool", AvgPoolFusionPrimitiveCreator); -RegistryMSOps g_avgPoolFusionPrimitiveCreatorRegistry("AvgPoolFusion", AvgPoolFusionPrimitiveCreator); -RegistryMSOps g_avgPoolGradPrimitiveCreatorRegistry("AvgPoolGrad", AvgPoolGradPrimitiveCreator); -RegistryMSOps g_avgPoolGradGpuPrimitiveCreatorRegistry("AvgPoolGradGpu", AvgPoolGradPrimitiveCreator); -RegistryMSOps g_avgPoolGradCpuPrimitiveCreatorRegistry("AvgPoolGradCpu", AvgPoolGradPrimitiveCreator); -RegistryMSOps g_batchNormPrimitiveCreatorRegistry("BatchNorm", BatchNormPrimitiveCreator); -RegistryMSOps g_batchToSpacePrimitiveCreatorRegistry("BatchToSpace", BatchToSpacePrimitiveCreator); -RegistryMSOps g_batchToSpaceNDPrimitiveCreatorRegistry("BatchToSpaceND", BatchToSpaceNDPrimitiveCreator); -RegistryMSOps g_biasAddPrimitiveCreatorRegistry("BiasAdd", BiasAddPrimitiveCreator); -RegistryMSOps g_biasGradAddPrimitiveCreatorRegistry("BiasGrad", BiasAddGradPrimitiveCreator); -RegistryMSOps g_biasAddGradAddPrimitiveCreatorRegistry("BiasAddGrad", BiasAddGradPrimitiveCreator); -RegistryMSOps g_bNGradPrimitiveCreatorRegistry("BatchNormGrad", BNGradPrimitiveCreator); -RegistryMSOps g_broadcastToPrimitiveCreatorRegistry("BroadcastTo", BroadcastToPrimitiveCreator); -RegistryMSOps g_castPrimitiveCreatorRegistry("Cast", CastPrimitiveCreator); -RegistryMSOps g_ceilPrimitiveCreatorRegistry("Ceil", CeilPrimitiveCreator); -RegistryMSOps g_clipPrimitiveCreatorRegistry("Clip", ClipPrimitiveCreator); -RegistryMSOps g_concatPrimitiveCreatorRegistry("Concat", ConcatPrimitiveCreator); -RegistryMSOps g_conv2DBackpropFilterFusionPrimitiveCreatorRegistry("Conv2DBackpropFilterFusion", - Conv2DBackpropFilterFusionPrimitiveCreator); -RegistryMSOps g_conv2DBackpropInputFusionPrimitiveCreatorRegistry("Conv2DBackpropInputFusion", - Conv2DBackpropInputFusionPrimitiveCreator); -RegistryMSOps g_conv2DPrimitiveCreatorRegistry("Conv2D", Conv2DFusionPrimitiveCreator); -RegistryMSOps g_conv2DFusionPrimitiveCreatorRegistry("Conv2DFusion", Conv2DFusionPrimitiveCreator); -RegistryMSOps g_conv2dTransposePrimitiveCreatorRegistry("Conv2dTranspose", Conv2dTransposeFusionPrimitiveCreator); -RegistryMSOps g_conv2dTransposeFusionPrimitiveCreatorRegistry("Conv2dTransposeFusion", - Conv2dTransposeFusionPrimitiveCreator); -RegistryMSOps g_constantOfShapePrimitiveCreatorRegistry("ConstantOfShape", ConstantOfShapePrimitiveCreator); -RegistryMSOps g_cosPrimitiveCreatorRegistry("Cos", CosPrimitiveCreator); -RegistryMSOps g_cropPrimitiveCreatorRegistry("Crop", CropPrimitiveCreator); -RegistryMSOps g_cropAndResizePrimitiveCreatorRegistry("CropAndResize", CropAndResizePrimitiveCreator); -RegistryMSOps g_customExtractFeaturesPrimitiveCreatorRegistry("CustomExtractFeatures", - CustomExtractFeaturesPrimitiveCreator); -RegistryMSOps g_customNormalizePrimitiveCreatorRegistry("CustomNormalize", CustomNormalizePrimitiveCreator); -RegistryMSOps g_customPredictPrimitiveCreatorRegistry("CustomPredict", CustomPredictPrimitiveCreator); -RegistryMSOps g_dependPrimitiveCreatorRegistry("Depend", DependPrimitiveCreator); -RegistryMSOps g_depthToSpacePrimitiveCreatorRegistry("DepthToSpace", DepthToSpacePrimitiveCreator); -RegistryMSOps g_detectionPostProcessPrimitiveCreatorRegistry("DetectionPostProcess", - DetectionPostProcessPrimitiveCreator); -RegistryMSOps g_divPrimitiveCreatorRegistry("Div", DivFusionPrimitiveCreator); -RegistryMSOps g_divFusionPrimitiveCreatorRegistry("DivFusion", DivFusionPrimitiveCreator); -RegistryMSOps g_divGradPrimitiveCreatorRegistry("DivGrad", DivGradPrimitiveCreator); -RegistryMSOps g_dropoutPrimitiveCreatorRegistry("Dropout", DropoutPrimitiveCreator); -RegistryMSOps g_dropoutGradPrimitiveCreatorRegistry("DropoutGrad", DropoutGradPrimitiveCreator); -RegistryMSOps g_eltwisePrimitiveCreatorRegistry("Eltwise", EltwisePrimitiveCreator); -RegistryMSOps g_eluPrimitiveCreatorRegistry("Elu", EluPrimitiveCreator); -RegistryMSOps g_eluGradPrimitiveCreatorRegistry("EluGrad", ActivationGradPrimitiveCreator); -RegistryMSOps g_equalPrimitiveCreatorRegistry("Equal", EqualPrimitiveCreator); -RegistryMSOps g_embeddingLookupFusionPrimitiveCreatorRegistry("EmbeddingLookupFusion", - EmbeddingLookupFusionPrimitiveCreator); -RegistryMSOps g_expandDimsPrimitiveCreatorRegistry("ExpandDims", ExpandDimsPrimitiveCreator); -RegistryMSOps g_expPrimitiveCreatorRegistry("Exp", ExpFusionPrimitiveCreator); -RegistryMSOps g_expFusionPrimitiveCreatorRegistry("ExpFusion", ExpFusionPrimitiveCreator); -RegistryMSOps g_fftImagPrimitiveCreatorRegistry("FftImag", FftImagPrimitiveCreator); -RegistryMSOps g_fftRealPrimitiveCreatorRegistry("FftReal", FftRealPrimitiveCreator); -RegistryMSOps g_fillPrimitiveCreatorRegistry("Fill", FillPrimitiveCreator); -RegistryMSOps g_flattenPrimitiveCreatorRegistry("Flatten", FlattenPrimitiveCreator); -RegistryMSOps g_flattenGradPrimitiveCreatorRegistry("FlattenGrad", FlattenGradPrimitiveCreator); -RegistryMSOps g_floorPrimitiveCreatorRegistry("Floor", FloorPrimitiveCreator); -RegistryMSOps g_floorDivPrimitiveCreatorRegistry("FloorDiv", FloorDivPrimitiveCreator); -RegistryMSOps g_floorModPrimitiveCreatorRegistry("FloorMod", FloorModPrimitiveCreator); -RegistryMSOps g_fullConnectionPrimitiveCreatorRegistry("FullConnection", FullConnectionPrimitiveCreator); -RegistryMSOps g_fusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator); -RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator); -RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator); -RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator); -RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator); -RegistryMSOps g_gRUPrimitiveCreatorRegistry("GRU", GRUPrimitiveCreator); -RegistryMSOps g_hashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator); -RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator); -RegistryMSOps g_invertPermutationPrimitiveCreatorRegistry("InvertPermutation", InvertPermutationPrimitiveCreator); -RegistryMSOps g_layerNormPrimitiveCreatorRegistry("LayerNorm", LayerNormFusionPrimitiveCreator); -RegistryMSOps g_layerNormFusionPrimitiveCreatorRegistry("LayerNormFusion", LayerNormFusionPrimitiveCreator); -RegistryMSOps g_layerNormGradPrimitiveCreatorRegistry("LayerNormGrad", LayerNormGradPrimitiveCreator); -RegistryMSOps g_leakyReluPrimitiveCreatorRegistry("LeakyRelu", LeakyReluPrimitiveCreator); -RegistryMSOps g_lessPrimitiveCreatorRegistry("Less", LessPrimitiveCreator); -RegistryMSOps g_lessEqualPrimitiveCreatorRegistry("LessEqual", LessEqualPrimitiveCreator); -RegistryMSOps g_logPrimitiveCreatorRegistry("Log", LogPrimitiveCreator); -RegistryMSOps g_logGradPrimitiveCreatorRegistry("LogGrad", LogGradPrimitiveCreator); -RegistryMSOps g_logicalAndPrimitiveCreatorRegistry("LogicalAnd", LogicalAndPrimitiveCreator); -RegistryMSOps g_logicalNotPrimitiveCreatorRegistry("LogicalNot", LogicalNotPrimitiveCreator); -RegistryMSOps g_logicalOrPrimitiveCreatorRegistry("LogicalOr", LogicalOrPrimitiveCreator); -RegistryMSOps g_lpNormalizationPrimitiveCreatorRegistry("LpNormalization", LpNormalizationPrimitiveCreator); -RegistryMSOps g_lrnPrimitiveCreatorRegistry("LRN", LrnPrimitiveCreator); -RegistryMSOps g_lshProjectionPrimitiveCreatorRegistry("LshProjection", LshProjectionPrimitiveCreator); -RegistryMSOps g_lSTMPrimitiveCreatorRegistry("LSTM", LSTMPrimitiveCreator); -RegistryMSOps g_lSTMGradPrimitiveCreatorRegistry("LSTMGrad", LSTMGradPrimitiveCreator); -RegistryMSOps g_lSTMGradDataPrimitiveCreatorRegistry("LSTMGradData", LSTMGradDataPrimitiveCreator); -RegistryMSOps g_lSTMGradWeightPrimitiveCreatorRegistry("LSTMGradWeight", LSTMGradWeightPrimitiveCreator); -RegistryMSOps g_l2NormalizeFusionPrimitiveCreatorRegistry("L2NormalizeFusion", L2NormalizeFusionPrimitiveCreator); -RegistryMSOps g_matMulFusionPrimitiveCreatorRegistry("MatMulFusion", MatMulFusionPrimitiveCreator); -RegistryMSOps g_matMulPrimitiveCreatorRegistry("MatMul", MatMulFusionPrimitiveCreator); -RegistryMSOps g_maximumPrimitiveCreatorRegistry("Maximum", MaximumPrimitiveCreator); -RegistryMSOps g_maximumGradPrimitiveCreatorRegistry("MaximumGrad", MaximumGradPrimitiveCreator); -RegistryMSOps g_maxPoolPrimitiveCreatorRegistry("MaxPool", MaxPoolFusionPrimitiveCreator); -RegistryMSOps g_maxPoolFusionPrimitiveCreatorRegistry("MaxPoolFusion", MaxPoolFusionPrimitiveCreator); -RegistryMSOps g_maxPoolGradPrimitiveCreatorRegistry("MaxPoolGrad", MaxPoolGradPrimitiveCreator); -RegistryMSOps g_mergePrimitiveCreatorRegistry("switch_layer", SwitchLayerPrimitiveCreator); -RegistryMSOps g_mfccPrimitiveCreatorRegistry("Mfcc", MfccPrimitiveCreator); -RegistryMSOps g_minimumPrimitiveCreatorRegistry("Minimum", MinimumPrimitiveCreator); -RegistryMSOps g_minimumGradPrimitiveCreatorRegistry("MinimumGrad", MinimumGradPrimitiveCreator); -RegistryMSOps g_modPrimitiveCreatorRegistry("Mod", ModPrimitiveCreator); -RegistryMSOps g_mulPrimitiveCreatorRegistry("Mul", MulFusionPrimitiveCreator); -RegistryMSOps g_mulMulFusionPrimitiveCreatorRegistry("MulFusion", MulFusionPrimitiveCreator); -RegistryMSOps g_mulGradPrimitiveCreatorRegistry("MulGrad", MulGradPrimitiveCreator); -RegistryMSOps g_negPrimitiveCreatorRegistry("Neg", NegPrimitiveCreator); -RegistryMSOps g_negGradPrimitiveCreatorRegistry("NegGrad", NegGradPrimitiveCreator); -RegistryMSOps g_nonMaxSuppressionPrimitiveCreatorRegistry("NonMaxSuppression", NonMaxSuppressionPrimitiveCreator); -RegistryMSOps g_notEqualPrimitiveCreatorRegistry("NotEqual", NotEqualPrimitiveCreator); -RegistryMSOps g_oneHotPrimitiveCreatorRegistry("OneHot", OneHotPrimitiveCreator); -RegistryMSOps g_onesLikePrimitiveCreatorRegistry("OnesLike", OnesLikePrimitiveCreator); -RegistryMSOps g_padPrimitiveCreatorRegistry("Pad", PadFusionPrimitiveCreator); -RegistryMSOps g_padFusionPrimitiveCreatorRegistry("PadFusion", PadFusionPrimitiveCreator); -RegistryMSOps g_partialFusionPrimitiveCreatorRegistry("PartialFusion", PartialFusionPrimitiveCreator); -RegistryMSOps g_powerGradPrimitiveCreatorRegistry("PowerGrad", PowerGradPrimitiveCreator); -RegistryMSOps g_powFusionPrimitiveCreatorRegistry("PowFusion", PowFusionPrimitiveCreator); -RegistryMSOps g_pReLUFusionPrimitiveCreatorRegistry("PReLUFusion", PReLUFusionPrimitiveCreator); -RegistryMSOps g_RandomStandardNormalPrimitiveCreatorRegistry("RandomStandardNormal", - RandomStandardNormalPrimitiveCreator); -RegistryMSOps g_StandardNormalPrimitiveCreatorRegistry("StandardNormal", RandomStandardNormalPrimitiveCreator); -RegistryMSOps g_raggedRangePrimitiveCreatorRegistry("RaggedRange", RaggedRangePrimitiveCreator); -RegistryMSOps g_rangePrimitiveCreatorRegistry("Range", RangePrimitiveCreator); -RegistryMSOps g_rankPrimitiveCreatorRegistry("Rank", RankPrimitiveCreator); -RegistryMSOps g_reciprocalPrimitiveCreatorRegistry("Reciprocal", ReciprocalPrimitiveCreator); -RegistryMSOps g_realDivPrimitiveCreatorRegistry("RealDiv", RealDivPrimitiveCreator); -RegistryMSOps g_reducePrimitiveCreatorRegistry("Reduce", ReduceFusionPrimitiveCreator); -RegistryMSOps g_reduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusionPrimitiveCreator); -RegistryMSOps g_reshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator); -RegistryMSOps g_resizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator); -RegistryMSOps g_resizeGradPrimitiveCreatorRegistry("ResizeGrad", ResizeGradPrimitiveCreator); -RegistryMSOps g_resizeBilinearGradPrimitiveCreatorRegistry("ResizeBilinearGrad", ResizeGradPrimitiveCreator); -RegistryMSOps g_resizeNearestNeighborGradPrimitiveCreatorRegistry("ResizeNearestNeighborGrad", - ResizeGradPrimitiveCreator); -RegistryMSOps g_reverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator); -RegistryMSOps g_reverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator); -RegistryMSOps g_rfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator); -RegistryMSOps g_rOIPoolingPrimitiveCreatorRegistry("ROIPooling", ROIPoolingPrimitiveCreator); -RegistryMSOps g_roundPrimitiveCreatorRegistry("Round", RoundPrimitiveCreator); -RegistryMSOps g_rsqrtPrimitiveCreatorRegistry("Rsqrt", RsqrtPrimitiveCreator); -RegistryMSOps g_rsqrtGradPrimitiveCreatorRegistry("RsqrtGrad", RsqrtGradPrimitiveCreator); -RegistryMSOps g_quantDTypeCastPrimitiveCreatorRegistry("QuantDTypeCast", QuantDTypeCastPrimitiveCreator); -RegistryMSOps g_scalePrimitiveCreatorRegistry("Scale", ScaleFusionPrimitiveCreator); -RegistryMSOps g_scaleFusionPrimitiveCreatorRegistry("ScaleFusion", ScaleFusionPrimitiveCreator); -RegistryMSOps g_scatterNdPrimitiveCreatorRegistry("ScatterNd", ScatterNdPrimitiveCreator); -RegistryMSOps g_selectPrimitiveCreatorRegistry("Select", SelectPrimitiveCreator); -RegistryMSOps g_SGDPrimitiveCreatorRegistry("SGD", SGDPrimitiveCreator); -RegistryMSOps g_shapePrimitiveCreatorRegistry("Shape", ShapePrimitiveCreator); -RegistryMSOps g_sigmoidCrossEntropyWithLogitsPrimitiveCreatorRegistry("SigmoidCrossEntropyWithLogits", - SigmoidCrossEntropyWithLogitsPrimitiveCreator); -RegistryMSOps g_sigmoidCrossEntropyWithLogitsGradPrimitiveCreatorRegistry( - "SigmoidCrossEntropyWithLogitsGrad", SigmoidCrossEntropyWithLogitsGradPrimitiveCreator); -RegistryMSOps g_sinPrimitiveCreatorRegistry("Sin", SinPrimitiveCreator); -RegistryMSOps g_sizePrimitiveCreatorRegistry("Size", SizePrimitiveCreator); -RegistryMSOps g_skipGramPrimitiveCreatorRegistry("SkipGram", SkipGramPrimitiveCreator); -RegistryMSOps g_sliceFusionPrimitiveCreatorRegistry("SliceFusion", SliceFusionPrimitiveCreator); -RegistryMSOps g_smoothL1LossPrimitiveCreatorRegistry("SmoothL1Loss", SmoothL1LossPrimitiveCreator); -RegistryMSOps g_smoothL1LossGradPrimitiveCreatorRegistry("SmoothL1LossGrad", SmoothL1LossGradPrimitiveCreator); -RegistryMSOps g_softmaxPrimitiveCreatorRegistry("Softmax", SoftmaxPrimitiveCreator); -RegistryMSOps g_softmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry("SoftmaxCrossEntropyWithLogits", - SoftmaxCrossEntropyWithLogitsPrimitiveCreator); -RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator); -RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator); -RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator); -RegistryMSOps g_sparseSoftmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry( - "SparseSoftmaxCrossEntropyWithLogits", SparseSoftmaxCrossEntropyWithLogitsPrimitiveCreator); -RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator); -RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator); -RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator); -RegistryMSOps g_sqrtGradPrimitiveCreatorRegistry("SqrtGrad", SqrtGradPrimitiveCreator); -RegistryMSOps g_squeezePrimitiveCreatorRegistry("Squeeze", SqueezePrimitiveCreator); -RegistryMSOps g_squarePrimitiveCreatorRegistry("Square", SquarePrimitiveCreator); -RegistryMSOps g_squaredDifferencePrimitiveCreatorRegistry("SquaredDifference", SquaredDifferencePrimitiveCreator); -RegistryMSOps g_stackPrimitiveCreatorRegistry("Stack", StackPrimitiveCreator); -RegistryMSOps g_stridedSlicePrimitiveCreatorRegistry("StridedSlice", StridedSlicePrimitiveCreator); -RegistryMSOps g_stridedSliceGradPrimitiveCreatorRegistry("StridedSliceGrad", StridedSliceGradPrimitiveCreator); -RegistryMSOps g_subPrimitiveCreatorRegistry("Sub", SubFusionPrimitiveCreator); -RegistryMSOps g_subFusionPrimitiveCreatorRegistry("SubFusion", SubFusionPrimitiveCreator); -RegistryMSOps g_subGradPrimitiveCreatorRegistry("SubGrad", SubGradPrimitiveCreator); -RegistryMSOps g_switchPrimitiveCreatorRegistry("Switch", SwitchPrimitiveCreator); -RegistryMSOps g_tensorListFromTensorPrimitiveCreatorRegistry("TensorListFromTensor", - TensorListFromTensorPrimitiveCreator); -RegistryMSOps g_tensorListGetItemPrimitiveCreatorRegistry("TensorListGetItem", TensorListGetItemPrimitiveCreator); -RegistryMSOps g_tensorListReservePrimitiveCreatorRegistry("TensorListReserve", TensorListReservePrimitiveCreator); -RegistryMSOps g_tensorListSetItemPrimitiveCreatorRegistry("TensorListSetItem", TensorListSetItemPrimitiveCreator); -RegistryMSOps g_tensorListStackPrimitiveCreatorRegistry("TensorListStack", TensorListStackPrimitiveCreator); -RegistryMSOps g_tileFusionPrimitiveCreatorRegistry("TileFusion", TileFusionPrimitiveCreator); -RegistryMSOps g_topKPrimitiveCreatorRegistry("TopK", TopKFusionPrimitiveCreator); -RegistryMSOps g_topKFusionPrimitiveCreatorRegistry("TopKFusion", TopKFusionPrimitiveCreator); -RegistryMSOps g_transposePrimitiveCreatorxRegistry("Transpose", TransposePrimitiveCreator); -RegistryMSOps g_uniquePrimitiveCreatorRegistry("Unique", UniquePrimitiveCreator); -RegistryMSOps g_unstackPrimitiveCreatorRegistry("Unstack", UnstackPrimitiveCreator); -RegistryMSOps g_unsortedSegmentSumPrimitiveCreatorRegistry("UnsortedSegmentSum", UnsortedSegmentSumPrimitiveCreator); -RegistryMSOps g_unsqueezePrimitiveCreatorRegistry("Unsqueeze", UnsqueezePrimitiveCreator); -RegistryMSOps g_wherePrimitiveCreatorRegistry("Where", WherePrimitiveCreator); -RegistryMSOps g_zerosLikePrimitiveCreatorRegistry("ZerosLike", ZerosLikePrimitiveCreator); -RegistryMSOps g_erfPrimitiveCreatorRegistry("Erf", ErfPrimitiveCreator); -RegistryMSOps g_SplicePrimitiveCreatorRegistry("Splice", SplicePrimitiveCreator); -RegistryMSOps g_LogSoftmaxPrimitiveCreatorRegistry("LogSoftmax", LogSoftmaxPrimitiveCreator); -RegistryMSOps g_CallPrimitiveCreatorRegistry("call", CallPrimitiveCreator); -RegistryMSOps g_CumSumPrimitiveCreatorRegistry("CumSum", CumSumPrimitiveCreator); -RegistryMSOps g_SplitWithOverlapCreatorRegistry("SplitWithOverlap", SplitWithOverlapPrimitiveCreator); -RegistryMSOps g_GluCreatorRegistry("GLU", GluPrimitiveCreator); -RegistryMSOps g_TensorArrayCreatorRegistry("TensorArray", TensorArrayPrimitiveCreator); -RegistryMSOps g_TensorArrayReadCreatorRegistry("TensorArrayRead", TensorArrayReadPrimitiveCreator); -RegistryMSOps g_TensorArrayWriteCreatorRegistry("TensorArrayWrite", TensorArrayWritePrimitiveCreator); -RegistryMSOps g_AffineCreatorRegistry("Affine", AffinePrimitiveCreator); -RegistryMSOps g_AttentionCreatorRegistry("Attention", AttentionPrimitiveCreator); -RegistryMSOps g_ScatterNdUpdateCreatorRegistry("ScatterNdUpdate", ScatterNdUpdatePrimitiveCreator); -RegistryMSOps g_AllGatherCreatorRegistry("AllGather", AllGatherPrimitiveCreator); -RegistryMSOps g_ReduceScatterCreatorRegistry("ReduceScatter", ReduceScatterPrimitiveCreator); -RegistryMSOps g_DynamicQuantCreatorRegistry("DynamicQuant", DynamicQuantPrimitiveCreator); -RegistryMSOps g_RandomNormalCreatorRegistry("RandomNormal", RandomNormalPrimitiveCreator); -RegistryMSOps g_NLLLossCreatorRegistry("NLLLoss", NLLLossPrimitiveCreator); -RegistryMSOps g_NLLLossGradCreatorRegistry("NLLLossGrad", NLLLossGradPrimitiveCreator); -RegistryMSOps g_CustomPrimitiveCreatorRegistry("Custom", CustomPrimitiveCreator); -RegistryMSOps g_UniformRealPrimitiveCreatorRegistry("UniformReal", UniformRealPrimitiveCreator); -RegistryMSOps g_FormatTransposePrimitiveCreatorRegistry("FormatTranspose", FormatTransposePrimitiveCreator); +REG_MINDSPORE_OPERATOR(Abs) +REG_MINDSPORE_OPERATOR(Activation) +REG_MINDSPORE_OPERATOR(ActivationGrad) +REG_MINDSPORE_OPERATOR(Adam) +REG_MINDSPORE_OPERATOR(AddFusion) +REG_MINDSPORE_OPERATOR(AdderFusion) +REG_MINDSPORE_OPERATOR(AddGrad) +REG_MINDSPORE_OPERATOR(AddN) +REG_MINDSPORE_OPERATOR(All) +REG_MINDSPORE_OPERATOR(ApplyMomentum) +REG_MINDSPORE_OPERATOR(ArgMaxFusion) +REG_MINDSPORE_OPERATOR(ArgMinFusion) +REG_MINDSPORE_OPERATOR(Assert) +REG_MINDSPORE_OPERATOR(Assign) +REG_MINDSPORE_OPERATOR(AssignAdd) +REG_MINDSPORE_OPERATOR(AudioSpectrogram) +REG_MINDSPORE_OPERATOR(AvgPoolFusion) +REG_MINDSPORE_OPERATOR(AvgPoolGrad) +REG_MINDSPORE_OPERATOR(BatchNorm) +REG_MINDSPORE_OPERATOR(BatchNormGrad) +REG_MINDSPORE_OPERATOR(BatchToSpace) +REG_MINDSPORE_OPERATOR(BatchToSpaceND) +REG_MINDSPORE_OPERATOR(BiasAdd) +REG_MINDSPORE_OPERATOR(BinaryCrossEntropy) +REG_MINDSPORE_OPERATOR(BinaryCrossEntropyGrad) +REG_MINDSPORE_OPERATOR(BiasAddGrad) +REG_MINDSPORE_OPERATOR(BroadcastTo) +REG_MINDSPORE_OPERATOR(Cast) +REG_MINDSPORE_OPERATOR(Ceil) +REG_MINDSPORE_OPERATOR(Clip) +REG_MINDSPORE_OPERATOR(Concat) +REG_MINDSPORE_OPERATOR(Attention) +REG_MINDSPORE_OPERATOR(Conv2DBackpropFilterFusion) +REG_MINDSPORE_OPERATOR(Conv2DBackpropInputFusion) +REG_MINDSPORE_OPERATOR(Conv2DFusion) +REG_MINDSPORE_OPERATOR(Conv2dTransposeFusion) +REG_MINDSPORE_OPERATOR(Cos) +REG_MINDSPORE_OPERATOR(ConstantOfShape) +REG_MINDSPORE_OPERATOR(Crop) +REG_MINDSPORE_OPERATOR(CustomExtractFeatures) +REG_MINDSPORE_OPERATOR(CustomNormalize) +REG_MINDSPORE_OPERATOR(CustomPredict) +REG_MINDSPORE_OPERATOR(DeConv2DGradFilter) +REG_MINDSPORE_OPERATOR(Depend) +REG_MINDSPORE_OPERATOR(DepthToSpace) +REG_MINDSPORE_OPERATOR(DetectionPostProcess) +REG_MINDSPORE_OPERATOR(DivFusion) +REG_MINDSPORE_OPERATOR(DivGrad) +REG_MINDSPORE_OPERATOR(Dropout) +REG_MINDSPORE_OPERATOR(DropoutGrad) +REG_MINDSPORE_OPERATOR(Elu) +REG_MINDSPORE_OPERATOR(Eltwise) +REG_MINDSPORE_OPERATOR(Equal) +REG_MINDSPORE_OPERATOR(EmbeddingLookupFusion) +REG_MINDSPORE_OPERATOR(ExpFusion) +REG_MINDSPORE_OPERATOR(ExpandDims) +REG_MINDSPORE_OPERATOR(FakeQuantWithMinMaxVars) +REG_MINDSPORE_OPERATOR(FakeQuantWithMinMaxVarsPerChannel) +REG_MINDSPORE_OPERATOR(FftReal) +REG_MINDSPORE_OPERATOR(FftImag) +REG_MINDSPORE_OPERATOR(Flatten) +REG_MINDSPORE_OPERATOR(FlattenGrad) +REG_MINDSPORE_OPERATOR(Floor) +REG_MINDSPORE_OPERATOR(FloorDiv) +REG_MINDSPORE_OPERATOR(FloorMod) +REG_MINDSPORE_OPERATOR(Fill) +REG_MINDSPORE_OPERATOR(FullConnection) +REG_MINDSPORE_OPERATOR(FusedBatchNorm) +REG_MINDSPORE_OPERATOR(Gather) +REG_MINDSPORE_OPERATOR(GatherNd) +REG_MINDSPORE_OPERATOR(Greater) +REG_MINDSPORE_OPERATOR(GreaterEqual) +REG_MINDSPORE_OPERATOR(HashtableLookup) +REG_MINDSPORE_OPERATOR(InstanceNorm) +REG_MINDSPORE_OPERATOR(LayerNormFusion) +REG_MINDSPORE_OPERATOR(LeakyRelu) +REG_MINDSPORE_OPERATOR(Less) +REG_MINDSPORE_OPERATOR(LessEqual) +REG_MINDSPORE_OPERATOR(Log) +REG_MINDSPORE_OPERATOR(LogGrad) +REG_MINDSPORE_OPERATOR(LogicalAnd) +REG_MINDSPORE_OPERATOR(LogicalNot) +REG_MINDSPORE_OPERATOR(LogicalOr) +REG_MINDSPORE_OPERATOR(LpNormalization) +REG_MINDSPORE_OPERATOR(LRN) +REG_MINDSPORE_OPERATOR(LshProjection) +REG_MINDSPORE_OPERATOR(LSTM) +REG_MINDSPORE_OPERATOR(L2NormalizeFusion) +REG_MINDSPORE_OPERATOR(MatMulFusion) +REG_MINDSPORE_OPERATOR(Maximum) +REG_MINDSPORE_OPERATOR(MaximumGrad) +REG_MINDSPORE_OPERATOR(MaxPoolFusion) +REG_MINDSPORE_OPERATOR(MaxPoolGrad) +REG_MINDSPORE_OPERATOR(SwitchLayer) +REG_MINDSPORE_OPERATOR(Mfcc) +REG_MINDSPORE_OPERATOR(Minimum) +REG_MINDSPORE_OPERATOR(MinimumGrad) +REG_MINDSPORE_OPERATOR(Mod) +REG_MINDSPORE_OPERATOR(MulFusion) +REG_MINDSPORE_OPERATOR(MulGrad) +REG_MINDSPORE_OPERATOR(Neg) +REG_MINDSPORE_OPERATOR(NegGrad) +REG_MINDSPORE_OPERATOR(NotEqual) +REG_MINDSPORE_OPERATOR(NonMaxSuppression) +REG_MINDSPORE_OPERATOR(OneHot) +REG_MINDSPORE_OPERATOR(OnesLike) +REG_MINDSPORE_OPERATOR(PadFusion) +REG_MINDSPORE_OPERATOR(PartialFusion) +REG_MINDSPORE_OPERATOR(PowerGrad) +REG_MINDSPORE_OPERATOR(PowFusion) +REG_MINDSPORE_OPERATOR(PriorBox) +REG_MINDSPORE_OPERATOR(PReLUFusion) +REG_MINDSPORE_OPERATOR(QuantDTypeCast) +REG_MINDSPORE_OPERATOR(Rank) +REG_MINDSPORE_OPERATOR(Range) +REG_MINDSPORE_OPERATOR(Reciprocal) +REG_MINDSPORE_OPERATOR(RealDiv) +REG_MINDSPORE_OPERATOR(ReduceFusion) +REG_MINDSPORE_OPERATOR(Reshape) +REG_MINDSPORE_OPERATOR(Resize) +REG_MINDSPORE_OPERATOR(ReverseSequence) +REG_MINDSPORE_OPERATOR(ReverseV2) +REG_MINDSPORE_OPERATOR(Rfft) +REG_MINDSPORE_OPERATOR(ROIPooling) +REG_MINDSPORE_OPERATOR(Round) +REG_MINDSPORE_OPERATOR(Rsqrt) +REG_MINDSPORE_OPERATOR(ScaleFusion) +REG_MINDSPORE_OPERATOR(ScatterNd) +REG_MINDSPORE_OPERATOR(SGD) +REG_MINDSPORE_OPERATOR(Shape) +REG_MINDSPORE_OPERATOR(SigmoidCrossEntropyWithLogits) +REG_MINDSPORE_OPERATOR(SigmoidCrossEntropyWithLogitsGrad) +REG_MINDSPORE_OPERATOR(Sin) +REG_MINDSPORE_OPERATOR(SkipGram) +REG_MINDSPORE_OPERATOR(SliceFusion) +REG_MINDSPORE_OPERATOR(SmoothL1Loss) +REG_MINDSPORE_OPERATOR(SmoothL1LossGrad) +REG_MINDSPORE_OPERATOR(Softmax) +REG_MINDSPORE_OPERATOR(SoftmaxCrossEntropyWithLogits) +REG_MINDSPORE_OPERATOR(SpaceToBatch) +REG_MINDSPORE_OPERATOR(SpaceToBatchND) +REG_MINDSPORE_OPERATOR(SpaceToDepth) +REG_MINDSPORE_OPERATOR(SparseSoftmaxCrossEntropyWithLogits) +REG_MINDSPORE_OPERATOR(SparseToDense) +REG_MINDSPORE_OPERATOR(Split) +REG_MINDSPORE_OPERATOR(Sqrt) +REG_MINDSPORE_OPERATOR(Squeeze) +REG_MINDSPORE_OPERATOR(Square) +REG_MINDSPORE_OPERATOR(SquaredDifference) +REG_MINDSPORE_OPERATOR(Stack) +REG_MINDSPORE_OPERATOR(StridedSlice) +REG_MINDSPORE_OPERATOR(SubFusion) +REG_MINDSPORE_OPERATOR(SubGrad) +REG_MINDSPORE_OPERATOR(Switch) +REG_MINDSPORE_OPERATOR(TensorListFromTensor) +REG_MINDSPORE_OPERATOR(TensorListGetItem) +REG_MINDSPORE_OPERATOR(TensorListReserve) +REG_MINDSPORE_OPERATOR(TensorListSetItem) +REG_MINDSPORE_OPERATOR(TensorListStack) +REG_MINDSPORE_OPERATOR(TileFusion) +REG_MINDSPORE_OPERATOR(TopKFusion) +REG_MINDSPORE_OPERATOR(Transpose) +REG_MINDSPORE_OPERATOR(Unique) +REG_MINDSPORE_OPERATOR(UnsortedSegmentSum) +REG_MINDSPORE_OPERATOR(Unsqueeze) +REG_MINDSPORE_OPERATOR(Unstack) +REG_MINDSPORE_OPERATOR(LSTMGrad) +REG_MINDSPORE_OPERATOR(Where) +REG_MINDSPORE_OPERATOR(ZerosLike) +REG_MINDSPORE_OPERATOR(Select) +REG_MINDSPORE_OPERATOR(ScatterNdUpdate) +REG_MINDSPORE_OPERATOR(GRU) +REG_MINDSPORE_OPERATOR(NonZero) +REG_MINDSPORE_OPERATOR(InvertPermutation) +REG_MINDSPORE_OPERATOR(Size) +REG_MINDSPORE_OPERATOR(RandomStandardNormal) +REG_MINDSPORE_OPERATOR(CropAndResize) +REG_MINDSPORE_OPERATOR(Erf) +REG_MINDSPORE_OPERATOR(StridedSliceGrad) +REG_MINDSPORE_OPERATOR(IsFinite) +REG_MINDSPORE_OPERATOR(LinSpace) +REG_MINDSPORE_OPERATOR(UniformReal) +REG_MINDSPORE_OPERATOR(AbsGrad) +REG_MINDSPORE_OPERATOR(RsqrtGrad) +REG_MINDSPORE_OPERATOR(SqrtGrad) +REG_MINDSPORE_OPERATOR(LayerNormGrad) +REG_MINDSPORE_OPERATOR(ResizeGrad) +REG_MINDSPORE_OPERATOR(Splice) +REG_MINDSPORE_OPERATOR(LogSoftmax) +REG_MINDSPORE_OPERATOR(Call) +REG_MINDSPORE_OPERATOR(Custom) +REG_MINDSPORE_OPERATOR(CumSum) +REG_MINDSPORE_OPERATOR(SplitWithOverlap) +REG_MINDSPORE_OPERATOR(RaggedRange) +REG_MINDSPORE_OPERATOR(GLU) +REG_MINDSPORE_OPERATOR(TensorArray) +REG_MINDSPORE_OPERATOR(TensorArrayRead) +REG_MINDSPORE_OPERATOR(TensorArrayWrite) +REG_MINDSPORE_OPERATOR(Affine) +REG_MINDSPORE_OPERATOR(AllGather) +REG_MINDSPORE_OPERATOR(ReduceScatter) +REG_MINDSPORE_OPERATOR(DynamicQuant) +REG_MINDSPORE_OPERATOR(LSTMGradData) +REG_MINDSPORE_OPERATOR(LSTMGradWeight) +REG_MINDSPORE_OPERATOR(RandomNormal) +REG_MINDSPORE_OPERATOR(NLLLoss) +REG_MINDSPORE_OPERATOR(NLLLossGrad) +REG_MINDSPORE_OPERATOR(FormatTranspose) } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/ops_utils.h b/mindspore/lite/src/ops/ops_utils.h index 69a96e05600..7f5d48bba4a 100644 --- a/mindspore/lite/src/ops/ops_utils.h +++ b/mindspore/lite/src/ops/ops_utils.h @@ -20,14 +20,14 @@ #include #include #include +#include #include "src/ops/ops_func_declare.h" - #ifdef PRIMITIVE_WRITEABLE -#include "abstract/ops/primitive_infer_map.h" +#include "src/common/log_adapter.h" namespace mindspore { namespace lite { -typedef std::unique_ptr (*PrimitiveTCreator)(const AnfNodePtr &node); +typedef std::unique_ptr (*PrimitiveTCreator)(const PrimitivePtr &primitive); class MSOpsRegistry { public: @@ -35,12 +35,19 @@ class MSOpsRegistry { static MSOpsRegistry registry; return ®istry; } - void InsertPrimitiveTMap(const std::string &name, PrimitiveTCreator creator) { primitive_creators[name] = creator; } + void InsertPrimitiveTMap(const std::string &name, PrimitiveTCreator creator) { + std::string lower_name = name; + std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower); + primitive_creators[lower_name] = creator; + } PrimitiveTCreator GetPrimitiveCreator(const std::string &name) { - if (primitive_creators.find(name) != primitive_creators.end()) { - return primitive_creators[name]; + std::string lower_name = name; + std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower); + lower_name.erase(std::remove(lower_name.begin(), lower_name.end(), '_'), lower_name.end()); + if (primitive_creators.find(lower_name) != primitive_creators.end()) { + return primitive_creators[lower_name]; } else { - MS_LOG(WARNING) << "Unsupported primitive type in Create: " << name; + MS_LOG(ERROR) << "Unsupported primitive type in Create: " << name; return nullptr; } } @@ -57,7 +64,8 @@ class RegistryMSOps { ~RegistryMSOps() = default; }; -std::unique_ptr GetPrimitiveT(const mindspore::AnfNodePtr &node); +#define REG_MINDSPORE_OPERATOR(OP) \ + static RegistryMSOps g_##OP##PrimitiveCreatorRegistry(#OP, PrimitiveCreator); } // namespace lite } // namespace mindspore #endif diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index c2cea8e6a7b..560c8abe3e2 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -41,7 +41,7 @@ #include "src/common/utils.h" #include "tools/common/graph_util.h" #include "tools/common/meta_graph_utils.h" -#include "src/ops/ops_utils.h" +#include "src/ops/anf_utils.h" #include "src/weight_decoder.h" #include "tools/common/node_util.h" #include "src/common/log_util.h" diff --git a/mindspore/lite/tools/anf_exporter/fetch_content.cc b/mindspore/lite/tools/anf_exporter/fetch_content.cc index a49bb144410..3d927a42f7b 100644 --- a/mindspore/lite/tools/anf_exporter/fetch_content.cc +++ b/mindspore/lite/tools/anf_exporter/fetch_content.cc @@ -29,8 +29,9 @@ #include "tools/optimizer/common/format_utils.h" #include "nnacl/op_base.h" #include "tools/common/node_util.h" -#include "src/ops/ops_utils.h" +#include "src/ops/anf_utils.h" #include "src/ops/populate/populate_register.h" +#include "src/common/primitive_t_utils.h" #include "mindapi/base/format.h" #include "ops/op_utils.h" diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index d1bee1b0c28..187b0c0ab12 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -46,17 +46,6 @@ std::vector GetInputCNode(const CNodePtr &cnode) { return inputs; } -const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) { - if (primitive_t == nullptr || fbb == nullptr) { - MS_LOG(ERROR) << "primitiveT or fbb is nullptr."; - return nullptr; - } - auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t); - fbb->Finish(prim_offset); - auto prim_buf = fbb->GetBufferPointer(); - return flatbuffers::GetRoot(prim_buf); -} - STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector &src_dims, mindspore::schema::Format dst_format, std::vector *dst_dims) { MS_ASSERT(dst_dims != nullptr); diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 2bb826372ff..5c7d5583903 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -81,8 +81,6 @@ std::vector Getfp32FullOpList(); std::vector GetUint8NhwcOpList(); -const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb); - size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode); class NodeUtils { diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index d18483f2b41..457229bca74 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -67,6 +67,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${SRC_DIR}/common/dynamic_library_loader.cc ${SRC_DIR}/train/train_populate_parameter.cc ${SRC_DIR}/common/config_file.cc + ${SRC_DIR}/common/primitive_t_utils.cc ../optimizer/*.cc ) @@ -117,6 +118,7 @@ set(LITE_SRC ${API_SRC} ${RUNTIME_PASS_SRCS} ${SRC_DIR}/ops/ops_def.cc ${SRC_DIR}/ops/ops_utils.cc + ${SRC_DIR}/ops/anf_utils.cc ${SRC_DIR}/common/utils.cc ${SRC_DIR}/common/file_utils.cc ${SRC_DIR}/common/context_util.cc diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 8c651ec4da3..d9b9d8e8c39 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -41,6 +41,7 @@ #include "tools/common/tensor_util.h" #include "include/api/model.h" #include "tools/mindir_serializer/mindir_serializer.h" +#include "src/common/primitive_t_utils.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc index 3431b187556..f5997acade2 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc @@ -26,6 +26,7 @@ #include "src/common/prim_util.h" #include "src/ops/populate/populate_register.h" #include "src/runtime/infer_manager.h" +#include "src/common/primitive_t_utils.h" #include "tools/common/node_util.h" #include "tools/converter/converter_flags.h" #include "src/common/string_utils.h" diff --git a/mindspore/lite/tools/optimizer/const_fold/fold_utils.cc b/mindspore/lite/tools/optimizer/const_fold/fold_utils.cc index 8fab46300a8..8b11fabbc73 100644 --- a/mindspore/lite/tools/optimizer/const_fold/fold_utils.cc +++ b/mindspore/lite/tools/optimizer/const_fold/fold_utils.cc @@ -31,7 +31,7 @@ #include "src/kernel_registry.h" #include "src/inner_context.h" #include "src/tensor.h" -#include "src/ops/ops_utils.h" +#include "src/ops/anf_utils.h" #include "src/runtime/infer_manager.h" #include "tools/optimizer/graph/lite_tensor_extractor.h" diff --git a/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc b/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc index 06362b9b42d..d7288eb3229 100644 --- a/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/norm_fusion.cc @@ -25,7 +25,7 @@ #include "tools/optimizer/common/gllo_utils.h" #include "securec/include/securec.h" #include "nnacl/op_base.h" -#include "src/ops/ops_utils.h" +#include "src/ops/anf_utils.h" #include "ops/op_utils.h" namespace mindspore { diff --git a/mindspore/lite/tools/optimizer/graph/node_infershape.cc b/mindspore/lite/tools/optimizer/graph/node_infershape.cc index ebbc3602f42..eb6de64e0e1 100644 --- a/mindspore/lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore/lite/tools/optimizer/graph/node_infershape.cc @@ -19,11 +19,12 @@ #include #include #include +#include "src/common/primitive_t_utils.h" #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" #include "src/common/utils.h" #include "src/ops/populate/populate_register.h" -#include "src/ops/ops_utils.h" +#include "src/ops/anf_utils.h" #include "src/runtime/infer_manager.h" #include "src/tensorlist.h" #include "src/registry/kernel_interface_registry.h"