[MS][LITE] add KernelExecutor to run single operator.
This commit is contained in:
parent
c525beaaff
commit
5c6a70d8ad
|
@ -642,6 +642,20 @@ if(PLATFORM_ARM64)
|
|||
endif()
|
||||
endif()
|
||||
endif()
|
||||
if(MSLITE_ENABLE_KERNEL_EXECUTOR)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/ops/ DESTINATION ${RUNTIME_INC_DIR}/core/ops
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/mindapi/ DESTINATION ${RUNTIME_INC_DIR}/core/mindapi
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h DESTINATION
|
||||
${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS kernel_executor DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(MSLITE_ENABLE_CONVERTER)
|
||||
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${GLOG_DIR} RENAME libglog.so.0
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
elseif(PLATFORM_ARM32)
|
||||
if(SUPPORT_NPU)
|
||||
install(FILES ${DDK_LIB_PATH}/libhiai.so DESTINATION ${RUNTIME_DIR}/third_party/hiai_ddk/lib
|
||||
|
@ -998,6 +1012,20 @@ else()
|
|||
DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
if(MSLITE_ENABLE_KERNEL_EXECUTOR)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/ops/ DESTINATION ${RUNTIME_INC_DIR}/core/ops
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/mindapi/ DESTINATION ${RUNTIME_INC_DIR}/core/mindapi
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h DESTINATION
|
||||
${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS kernel_executor DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
if(MSLITE_ENABLE_CONVERTER)
|
||||
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${GLOG_DIR} RENAME libglog.so.0
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
|
|
|
@ -35,9 +35,9 @@ struct alignas(sizeof(T) * 2) ComplexStorage {
|
|||
ComplexStorage &operator=(ComplexStorage<T> &&other) noexcept = default;
|
||||
|
||||
inline constexpr ComplexStorage(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
|
||||
|
||||
#ifndef ENABLE_ARM
|
||||
inline explicit constexpr ComplexStorage(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
|
||||
|
||||
#endif
|
||||
template <typename U = T>
|
||||
explicit ComplexStorage(const std::enable_if_t<std::is_same<U, float>::value, ComplexStorage<double>> &other)
|
||||
: real_(other.real_), imag_(other.imag_) {}
|
||||
|
|
|
@ -25,10 +25,18 @@
|
|||
#include <thread>
|
||||
#include <vector>
|
||||
#include "utils/convert_utils_base.h"
|
||||
#ifdef ENABLE_ARM
|
||||
#if defined(__ANDROID__) || defined(ANDROID)
|
||||
#include <android/log.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// namespace to support utils module definition
|
||||
namespace mindspore {
|
||||
constexpr int kNameMaxLength = 18;
|
||||
#if defined(__ANDROID__) || defined(ANDROID)
|
||||
constexpr const char *ANDROID_LOG_TAG = "MS_LITE";
|
||||
#endif
|
||||
std::map<void **, std::thread *> acl_handle_map;
|
||||
// set default log level to WARNING for all sub modules
|
||||
int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING};
|
||||
|
@ -100,6 +108,38 @@ static int GetThresholdLevel(const std::string &threshold) {
|
|||
}
|
||||
}
|
||||
#undef google
|
||||
#elif defined(BUILD_CORE_RUNTIME)
|
||||
const char *EnumStrForMsLogLevel(MsLogLevel level) {
|
||||
if (level == MsLogLevel::DEBUG) {
|
||||
return "DEBUG";
|
||||
} else if (level == MsLogLevel::INFO) {
|
||||
return "INFO";
|
||||
} else if (level == MsLogLevel::WARNING) {
|
||||
return "WARNING";
|
||||
} else if (level == MsLogLevel::ERROR) {
|
||||
return "ERROR";
|
||||
} else {
|
||||
return "NO_LEVEL";
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_ARM
|
||||
#if defined(__ANDROID__) || defined(ANDROID)
|
||||
static int GetAndroidLogLevel(MsLogLevel level) {
|
||||
switch (level) {
|
||||
case MsLogLevel::DEBUG:
|
||||
return ANDROID_LOG_DEBUG;
|
||||
case MsLogLevel::INFO:
|
||||
return ANDROID_LOG_INFO;
|
||||
case MsLogLevel::WARNING:
|
||||
return ANDROID_LOG_WARN;
|
||||
case MsLogLevel::ERROR:
|
||||
default:
|
||||
return ANDROID_LOG_ERROR;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#else
|
||||
|
||||
#undef Dlog
|
||||
|
@ -153,6 +193,14 @@ void LogWriter::OutputLog(const std::ostringstream &msg) const {
|
|||
<< std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " "
|
||||
<< "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl;
|
||||
#undef google
|
||||
#elif defined(BUILD_CORE_RUNTIME)
|
||||
#if defined(ENABLE_ARM) && (defined(__ANDROID__) || defined(ANDROID))
|
||||
__android_log_print(GetAndroidLogLevel(log_level_), ANDROID_LOG_TAG, "[%s:%d] %s] %s", location_.file_,
|
||||
location_.line_, location_.func_, msg.str().c_str());
|
||||
#else
|
||||
printf("%s [%s:%d] %s] %s\n", EnumStrForMsLogLevel(log_level_), location_.file_, location_.line_, location_.func_,
|
||||
msg.str().c_str());
|
||||
#endif
|
||||
#else
|
||||
auto str_msg = msg.str();
|
||||
auto slog_module_id = (submodule_ == SM_MD ? MD : ME);
|
||||
|
@ -166,7 +214,7 @@ void LogWriter::operator<(const LogStream &stream) const noexcept {
|
|||
msg << stream.sstream_->rdbuf();
|
||||
OutputLog(msg);
|
||||
}
|
||||
#ifndef BUILD_LITE_INFERENCE
|
||||
#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME)
|
||||
void LogWriter::operator^(const LogStream &stream) const {
|
||||
std::ostringstream msg;
|
||||
msg << stream.sstream_->rdbuf();
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#define google mindspore_private
|
||||
#include "glog/logging.h"
|
||||
#undef google
|
||||
#elif defined(BUILD_CORE_RUNTIME)
|
||||
#else
|
||||
#include "toolchain/slog.h"
|
||||
#endif
|
||||
|
@ -238,7 +239,7 @@ class MS_CORE_API LogWriter {
|
|||
/// \param[in] stream The input log stream.
|
||||
void operator<(const LogStream &stream) const noexcept;
|
||||
|
||||
#ifndef BUILD_LITE_INFERENCE
|
||||
#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME)
|
||||
/// \brief Output log message from the input log stream and then throw exception.
|
||||
///
|
||||
/// \param[in] stream The input log stream.
|
||||
|
@ -266,7 +267,7 @@ class MS_CORE_API LogWriter {
|
|||
: mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), level, SUBMODULE_ID, \
|
||||
excp_type) < mindspore::LogStream()
|
||||
|
||||
#ifndef BUILD_LITE_INFERENCE
|
||||
#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME)
|
||||
#define MSLOG_THROW(excp_type) \
|
||||
mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), mindspore::EXCEPTION, SUBMODULE_ID, \
|
||||
excp_type) ^ \
|
||||
|
|
|
@ -56,6 +56,7 @@ option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off)
|
|||
option(MSLITE_ENABLE_EXPERIMENTAL_KERNEL "enable experimental kernel" on)
|
||||
option(MSLITE_ENABLE_GRAPH_KERNEL "enable graph kernel" off)
|
||||
option(MSLITE_ENABLE_CONVERT_PYTORCH_MODEL "enable to convert pytorch model" off)
|
||||
option(MSLITE_ENABLE_KERNEL_EXECUTOR "enable kernel executor" off)
|
||||
|
||||
#Option that can be configured through manually
|
||||
option(ENABLE_VERBOSE "" off)
|
||||
|
@ -175,6 +176,9 @@ endif()
|
|||
if(DEFINED ENV{MSLITE_ENABLE_SERVING})
|
||||
set(MSLITE_ENABLE_SERVING $ENV{MSLITE_ENABLE_SERVING})
|
||||
endif()
|
||||
if(DEFINED ENV{MSLITE_ENABLE_KERNEL_EXECUTOR})
|
||||
set(MSLITE_ENABLE_KERNEL_EXECUTOR $ENV{MSLITE_ENABLE_KERNEL_EXECUTOR})
|
||||
endif()
|
||||
|
||||
if(DEFINED ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL} AND DEFINED ENV{LIB_TORCH_PATH})
|
||||
set(ENABLE_CONVERT_PYTORCH_MODEL $ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL})
|
||||
|
@ -404,6 +408,7 @@ message(STATUS "\tMSLITE_ENABLE_PARALLEL_INFERENCE = \t${MSLITE_ENABLE
|
|||
message(STATUS "\tMSLITE_ENABLE_SHARING_MODEL_WEIGHT = \t${MSLITE_ENABLE_SHARING_MODEL_WEIGHT}")
|
||||
message(STATUS "\tMSLITE_ENABLE_EXPERIMENTAL_KERNEL = \t${MSLITE_ENABLE_EXPERIMENTAL_KERNEL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_GRAPH_KERNEL = \t${MSLITE_ENABLE_GRAPH_KERNEL}")
|
||||
message(STATUS "\tMSLITE_ENABLE_KERNEL_EXECUTOR = \t${MSLITE_ENABLE_KERNEL_EXECUTOR}")
|
||||
|
||||
if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TESTCASES) AND (
|
||||
NOT MSLITE_ENABLE_MINDRT
|
||||
|
@ -531,7 +536,7 @@ if(MSLITE_GPU_BACKEND STREQUAL opencl)
|
|||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_CONVERTER OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper"
|
||||
OR MSLITE_ENABLE_TOOLS)
|
||||
OR MSLITE_ENABLE_TOOLS OR MSLITE_ENABLE_KERNEL_EXECUTOR)
|
||||
if(NOT ENABLE_CLOUD_AND_LITE)
|
||||
include(${TOP_DIR}/cmake/external_libs/json.cmake)
|
||||
endif()
|
||||
|
@ -635,6 +640,13 @@ function(find_required_package pkg_name)
|
|||
endif()
|
||||
endfunction()
|
||||
|
||||
if(MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_KERNEL_EXECUTOR)
|
||||
find_required_package(Patch)
|
||||
if(NOT ENABLE_CLOUD_AND_LITE)
|
||||
include(${TOP_DIR}/cmake/external_libs/protobuf.cmake)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_CONVERTER)
|
||||
if(ENABLE_FAST_HASH_TABLE)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_FAST_HASH_TABLE=1")
|
||||
|
@ -645,7 +657,6 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
if(NOT ENABLE_CLOUD_AND_LITE)
|
||||
include(${TOP_DIR}/cmake/external_libs/opencv.cmake)
|
||||
include(${TOP_DIR}/cmake/external_libs/eigen.cmake)
|
||||
include(${TOP_DIR}/cmake/external_libs/protobuf.cmake)
|
||||
endif()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
|
||||
endif()
|
||||
|
|
|
@ -187,7 +187,7 @@ if(MSLITE_ENABLE_RUNTIME_GLOG)
|
|||
add_definitions(-DUSE_GLOG)
|
||||
string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
|
||||
string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
if(NOT MSLITE_ENABLE_RUNTIME_CONVERT)
|
||||
if(NOT MSLITE_ENABLE_RUNTIME_CONVERT AND NOT MSLITE_ENABLE_KERNEL_EXECUTOR)
|
||||
set(LITE_SRC ${LITE_SRC}
|
||||
${CORE_DIR}/utils/log_adapter.cc)
|
||||
endif()
|
||||
|
@ -199,6 +199,7 @@ if(MSLITE_ENABLE_RUNTIME_CONVERT)
|
|||
file(GLOB RUNTIME_CONVERT_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops/ops_def.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops/ops_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops/anf_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_convert.cc)
|
||||
|
||||
set(LITE_SRC ${LITE_SRC} ${RUNTIME_CONVERT_SRC})
|
||||
|
@ -466,6 +467,10 @@ if(SUPPORT_TRAIN)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
if(MSLITE_ENABLE_KERNEL_EXECUTOR)
|
||||
add_subdirectory(cxx_api/kernel_executor)
|
||||
endif()
|
||||
|
||||
########################## build optimize and float16 library #################################
|
||||
if(PLATFORM_ARM)
|
||||
if(PLATFORM_ARM64 AND NOT TARGET_HIMIX AND NOT MACHINE_LINUX_ARM64)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_
|
||||
#define MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_
|
||||
|
||||
#ifdef USE_GLOG
|
||||
#if defined(USE_GLOG) || defined(BUILD_CORE_RUNTIME)
|
||||
#include "utils/log_adapter.h"
|
||||
#else
|
||||
#include "src/common/log.h"
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/common/primitive_t_utils.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
constexpr size_t INITIAL_SIZE = 1024;
|
||||
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
if (primitive_t == nullptr || fbb == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveT or fbb is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t);
|
||||
fbb->Finish(prim_offset);
|
||||
auto prim_buf = fbb->GetBufferPointer();
|
||||
return flatbuffers::GetRoot<schema::Primitive>(prim_buf);
|
||||
}
|
||||
|
||||
OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t) {
|
||||
flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
|
||||
auto primitive = ConvertToPrimitive(primitive_t, &fbb);
|
||||
fbb.Clear();
|
||||
auto prim_type = GetPrimitiveType(primitive, SCHEMA_VERSION::SCHEMA_CUR);
|
||||
auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(prim_type, SCHEMA_VERSION::SCHEMA_CUR);
|
||||
if (parame_gen == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter generator is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto parameter = parame_gen(primitive);
|
||||
if (parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
||||
<< GetPrimitiveTypeName(primitive, SCHEMA_VERSION::SCHEMA_CUR);
|
||||
}
|
||||
return parameter;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const std::shared_ptr<mindspore::ops::BaseOperator> &op) {
|
||||
if (op == nullptr) {
|
||||
MS_LOG(DEBUG) << "base operator is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (op->name().empty()) {
|
||||
MS_LOG(ERROR) << "the name of operator is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "export operator: " << op->name();
|
||||
auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(op->name());
|
||||
if (creator != nullptr) {
|
||||
return creator(op->GetPrim());
|
||||
} else {
|
||||
MS_LOG(WARNING) << "can not find SingleOpRegistry for operator: " << op->name();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "ops/base_operator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
|
||||
OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t);
|
||||
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const std::shared_ptr<mindspore::ops::BaseOperator> &op);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_
|
|
@ -0,0 +1,33 @@
|
|||
add_compile_definitions(BUILD_CORE_RUNTIME)
|
||||
add_definitions(-DPRIMITIVE_WRITEABLE)
|
||||
if(MSLITE_ENABLE_RUNTIME_GLOG)
|
||||
set(USE_GLOG on)
|
||||
add_definitions(-DUSE_GLOG)
|
||||
endif()
|
||||
string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
|
||||
string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
string(REPLACE "-fno-exceptions" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
|
||||
string(REPLACE "-fno-exceptions" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
if(NOT MSLITE_ENABLE_CONVERTER)
|
||||
add_subdirectory(${CORE_DIR} mindspore_core)
|
||||
endif()
|
||||
|
||||
add_library(kernel_executor SHARED
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_executor.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_executor_impl.cc
|
||||
${TOP_DIR}/mindspore/lite/src/ops/ops_utils.cc
|
||||
${TOP_DIR}/mindspore/lite/src/common/primitive_t_utils.cc
|
||||
${TOP_DIR}/mindspore/lite/src/ops/ops_def.cc)
|
||||
|
||||
add_dependencies(kernel_executor fbs_inner_src fbs_src mindspore_core)
|
||||
|
||||
target_link_libraries(kernel_executor
|
||||
mindspore-lite
|
||||
mindspore_core
|
||||
mindspore::json
|
||||
mindspore::protobuf
|
||||
mindspore::flatbuffers)
|
||||
|
||||
if(USE_GLOG)
|
||||
target_link_libraries(kernel_executor mindspore::glog)
|
||||
endif()
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/cxx_api/kernel_executor/kernel_executor.h"
|
||||
#include "src/cxx_api/kernel_executor/kernel_executor_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
Status KernelExecutor::Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
|
||||
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context) {
|
||||
if (impl_ == nullptr) {
|
||||
impl_ = std::make_shared<KernelExecutorImpl>();
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "implement is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Status ret = impl_->Build(op, inputs, outputs, ms_context);
|
||||
if (ret != kSuccess) {
|
||||
return ret;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status KernelExecutor::ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "implement is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
return impl_->ReSize(inputs, outputs);
|
||||
}
|
||||
|
||||
Status KernelExecutor::Infer(std::vector<MSTensor> *outputs) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "implement is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
return impl_->Infer(outputs);
|
||||
}
|
||||
|
||||
Status KernelExecutor::Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
|
||||
if (impl_ == nullptr) {
|
||||
MS_LOG(ERROR) << "implement is null.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
return impl_->Execute(inputs, outputs);
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H
|
||||
#define MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/context.h"
|
||||
#include "ops/base_operator.h"
|
||||
|
||||
namespace mindspore {
|
||||
class KernelExecutorImpl;
|
||||
|
||||
class MS_API KernelExecutor {
|
||||
public:
|
||||
KernelExecutor() = default;
|
||||
~KernelExecutor() = default;
|
||||
|
||||
/// \brief Build a single operator so that it can run on a device.
|
||||
///
|
||||
/// \param[in] op Define an operator pointer.
|
||||
/// \param[in] ms_context Define the context used to store options during execution.
|
||||
/// \param[in] inputs A vector where single operator inputs are arranged in sequence.
|
||||
/// \param[in] outputs A vector where single operator outputs are arranged in sequence.
|
||||
///
|
||||
/// \return Status.
|
||||
Status Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
|
||||
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context);
|
||||
|
||||
/// \brief ReSize KernelExecutor.
|
||||
///
|
||||
/// \param[in] inputs A vector where single operator inputs are arranged in sequence.
|
||||
/// \param[in] outputs A vector where single operator outputs are arranged in sequence.
|
||||
///
|
||||
/// \return Status.
|
||||
Status ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
|
||||
|
||||
/// \brief set outputs infer shape info.
|
||||
///
|
||||
/// \param[in] outputs A vector where single operator outputs are arranged in sequence.
|
||||
///
|
||||
/// \return Status.
|
||||
Status Infer(std::vector<MSTensor> *outputs);
|
||||
|
||||
/// \brief ReSize KernelExecutor.
|
||||
///
|
||||
/// \param[in] inputs A vector where single operator inputs are arranged in sequence.
|
||||
/// \param[in] outputs A vector where single operator outputs are arranged in sequence.
|
||||
///
|
||||
/// \return Status.
|
||||
Status Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
|
||||
|
||||
private:
|
||||
std::shared_ptr<KernelExecutorImpl> impl_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H
|
|
@ -0,0 +1,253 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "src/cxx_api/converters.h"
|
||||
#include "src/common/prim_util.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "src/common/primitive_t_utils.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/runtime/infer_manager.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/cxx_api/kernel_executor/kernel_executor_impl.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr size_t INITIAL_SIZE = 1024;
|
||||
|
||||
KernelExecutorImpl::~KernelExecutorImpl() {
|
||||
if (context_ != nullptr) {
|
||||
delete context_;
|
||||
context_ = nullptr;
|
||||
}
|
||||
|
||||
if (kernel_ != nullptr) {
|
||||
delete kernel_;
|
||||
kernel_ = nullptr;
|
||||
}
|
||||
FreeInOutTensor();
|
||||
}
|
||||
|
||||
Status KernelExecutorImpl::Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
|
||||
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context) {
|
||||
data_type_ = static_cast<enum TypeId>(inputs[FIRST_INPUT].DataType());
|
||||
std::unique_ptr<mindspore::schema::PrimitiveT> prim_t = lite::GetPrimitiveT(op);
|
||||
flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
|
||||
primitive_ = lite::ConvertToPrimitive(prim_t.get(), &fbb);
|
||||
fbb.Clear();
|
||||
if (primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "convert to primitive nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
prim_type_ = lite::GetPrimitiveType(primitive_, schema_version_);
|
||||
|
||||
context_ = ContextUtils::Convert(ms_context.get());
|
||||
if (context_ == nullptr) {
|
||||
MS_LOG(ERROR) << "failed to convert Context to LiteContext.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
int ret = context_->Init();
|
||||
if (ret != RET_OK) {
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
Status status = InitInOutTensor(inputs, outputs);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "InitInOutTensor error.";
|
||||
return status;
|
||||
}
|
||||
|
||||
if (prim_type_ == schema::PrimitiveType_Custom) {
|
||||
status = GetCustomKernel(ms_context);
|
||||
} else {
|
||||
status = GetCpuKernel(ms_context);
|
||||
}
|
||||
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "get kernel error.";
|
||||
return status;
|
||||
}
|
||||
ret = kernel_->Prepare();
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
Status KernelExecutorImpl::ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
|
||||
Status status = InitInOutTensor(inputs, outputs);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "InitInOutTensor error.";
|
||||
return status;
|
||||
}
|
||||
kernel_->set_in_tensors(inputs_);
|
||||
kernel_->set_out_tensors(outputs_);
|
||||
int ret;
|
||||
if (kernel_->type() == schema::PrimitiveType_Custom) {
|
||||
ret = KernelInferShape(inputs_, outputs_, primitive_, context_->GetProviders(), schema_version_);
|
||||
} else {
|
||||
ret = KernelInferShape(inputs_, outputs_, parameter_);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "do infer shape error.";
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
ret = kernel_->ReSize();
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
Status KernelExecutorImpl::Infer(std::vector<MSTensor> *outputs) {
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
auto user_output = outputs->at(i);
|
||||
auto output = outputs_[i];
|
||||
user_output.SetFormat(output->format());
|
||||
auto output_shape = output->shape();
|
||||
std::vector<int64_t> shape;
|
||||
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(shape),
|
||||
[](auto s) { return static_cast<int64_t>(s); });
|
||||
user_output.SetShape(shape);
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status KernelExecutorImpl::Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
auto user_input = inputs[i];
|
||||
auto input = inputs_[i];
|
||||
input->set_data(user_input.MutableData());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
auto user_output = outputs[i];
|
||||
auto output = outputs_[i];
|
||||
output->set_data(user_output.MutableData());
|
||||
}
|
||||
int ret = kernel_->Execute();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "execute error.";
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status KernelExecutorImpl::GetOpParameter() {
|
||||
auto parame_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim_type_, schema_version_);
|
||||
if (parame_gen == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter generator is nullptr.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
parameter_ = parame_gen(primitive_);
|
||||
if (parameter_ == nullptr) {
|
||||
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
|
||||
<< lite::GetPrimitiveTypeName(primitive_, schema_version_);
|
||||
return kLiteNullptr;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status KernelExecutorImpl::GetCustomKernel(const std::shared_ptr<Context> &ms_context) {
|
||||
int get_kernel = lite::RET_ERROR;
|
||||
|
||||
// find kernel match arch, data_type, kernel_arch and provider
|
||||
for (auto &&device : context_->device_list_) {
|
||||
if (!device.provider_.empty() && !device.provider_device_.empty()) {
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_, device.provider_device_,
|
||||
device.provider_};
|
||||
get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
|
||||
nullptr, &kernel_, primitive_);
|
||||
}
|
||||
}
|
||||
|
||||
// find kernel only match arch and data_type
|
||||
if (get_kernel != RET_OK) {
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_, "", ""};
|
||||
get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
|
||||
nullptr, &kernel_, primitive_);
|
||||
}
|
||||
|
||||
// if found kernel, do infershape
|
||||
if (get_kernel == RET_OK) {
|
||||
int ret = KernelInferShape(inputs_, outputs_, primitive_, context_->GetProviders(), schema_version_);
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
return static_cast<StatusCode>(get_kernel);
|
||||
}
|
||||
|
||||
Status KernelExecutorImpl::GetCpuKernel(const std::shared_ptr<Context> &ms_context) {
|
||||
Status status = GetOpParameter();
|
||||
if (status != kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_};
|
||||
int get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
|
||||
parameter_, &kernel_);
|
||||
if (get_kernel == RET_OK) {
|
||||
int ret = KernelInferShape(inputs_, outputs_, parameter_);
|
||||
return static_cast<StatusCode>(ret);
|
||||
}
|
||||
|
||||
return static_cast<StatusCode>(get_kernel);
|
||||
}
|
||||
|
||||
void KernelExecutorImpl::FreeInOutTensor() {
|
||||
for (auto &input : inputs_) {
|
||||
if (input != nullptr) {
|
||||
delete input;
|
||||
input = nullptr;
|
||||
}
|
||||
}
|
||||
inputs_.clear();
|
||||
for (auto &output : outputs_) {
|
||||
if (output != nullptr) {
|
||||
delete output;
|
||||
output = nullptr;
|
||||
}
|
||||
}
|
||||
outputs_.clear();
|
||||
}
|
||||
|
||||
Status KernelExecutorImpl::InitInOutTensor(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
|
||||
FreeInOutTensor();
|
||||
for (auto input : inputs) {
|
||||
auto input_shape = input.Shape();
|
||||
std::vector<int> shape;
|
||||
std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shape),
|
||||
[](auto s) { return static_cast<int>(s); });
|
||||
lite::Tensor *input_tensor = new (std::nothrow)
|
||||
lite::Tensor(static_cast<enum TypeId>(input.DataType()), shape, input.format(), lite::Category::GRAPH_INPUT);
|
||||
if (input_tensor == nullptr) {
|
||||
delete input_tensor;
|
||||
return kLiteNullptr;
|
||||
}
|
||||
input_tensor->set_data(input.MutableData());
|
||||
inputs_.emplace_back(input_tensor);
|
||||
}
|
||||
|
||||
for (auto output : outputs) {
|
||||
auto output_shape = output.Shape();
|
||||
std::vector<int> shape;
|
||||
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(shape),
|
||||
[](auto s) { return static_cast<int>(s); });
|
||||
lite::Tensor *output_tensor =
|
||||
new (std::nothrow) lite::Tensor(static_cast<enum TypeId>(output.DataType()), shape, output.format());
|
||||
if (output_tensor == nullptr) {
|
||||
delete output_tensor;
|
||||
return kLiteNullptr;
|
||||
}
|
||||
outputs_.emplace_back(output_tensor);
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H
|
||||
#define MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "src/cxx_api/kernel_executor/kernel_executor.h"
|
||||
#include "src/kernel_exec.h"
|
||||
#include "common/version_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
class KernelExecutorImpl {
|
||||
public:
|
||||
KernelExecutorImpl() = default;
|
||||
~KernelExecutorImpl();
|
||||
Status Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
|
||||
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context);
|
||||
Status ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
|
||||
Status Infer(std::vector<MSTensor> *outputs);
|
||||
Status Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
|
||||
|
||||
protected:
|
||||
Status GetCustomKernel(const std::shared_ptr<Context> &ms_context);
|
||||
Status GetCpuKernel(const std::shared_ptr<Context> &ms_context);
|
||||
Status GetOpParameter();
|
||||
Status InitInOutTensor(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
|
||||
void FreeInOutTensor();
|
||||
|
||||
private:
|
||||
const schema::Primitive *primitive_ = nullptr;
|
||||
int prim_type_;
|
||||
OpParameter *parameter_ = nullptr;
|
||||
lite::InnerContext *context_ = nullptr;
|
||||
TypeId data_type_;
|
||||
kernel::KernelExec *kernel_ = nullptr;
|
||||
std::vector<lite::Tensor *> inputs_;
|
||||
std::vector<lite::Tensor *> outputs_;
|
||||
int schema_version_ = lite::SCHEMA_VERSION::SCHEMA_CUR;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/anf_utils.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const AnfNodePtr &node) {
|
||||
auto prim = GetValueNode<std::shared_ptr<Primitive>>(node);
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(DEBUG) << "primitive is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (prim->name().empty()) {
|
||||
MS_LOG(ERROR) << "the name of primitive is null";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "export prim: " << prim->name();
|
||||
auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(prim->name());
|
||||
if (creator != nullptr) {
|
||||
return creator(prim);
|
||||
} else {
|
||||
MS_LOG(WARNING) << "can not find MSOpsRegistry for op: " << prim->name();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include "src/ops/ops_utils.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const mindspore::AnfNodePtr &node);
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_
|
File diff suppressed because it is too large
Load Diff
|
@ -20,14 +20,14 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "src/ops/ops_func_declare.h"
|
||||
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
typedef std::unique_ptr<schema::PrimitiveT> (*PrimitiveTCreator)(const AnfNodePtr &node);
|
||||
typedef std::unique_ptr<schema::PrimitiveT> (*PrimitiveTCreator)(const PrimitivePtr &primitive);
|
||||
|
||||
class MSOpsRegistry {
|
||||
public:
|
||||
|
@ -35,12 +35,19 @@ class MSOpsRegistry {
|
|||
static MSOpsRegistry registry;
|
||||
return ®istry;
|
||||
}
|
||||
void InsertPrimitiveTMap(const std::string &name, PrimitiveTCreator creator) { primitive_creators[name] = creator; }
|
||||
void InsertPrimitiveTMap(const std::string &name, PrimitiveTCreator creator) {
|
||||
std::string lower_name = name;
|
||||
std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower);
|
||||
primitive_creators[lower_name] = creator;
|
||||
}
|
||||
PrimitiveTCreator GetPrimitiveCreator(const std::string &name) {
|
||||
if (primitive_creators.find(name) != primitive_creators.end()) {
|
||||
return primitive_creators[name];
|
||||
std::string lower_name = name;
|
||||
std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower);
|
||||
lower_name.erase(std::remove(lower_name.begin(), lower_name.end(), '_'), lower_name.end());
|
||||
if (primitive_creators.find(lower_name) != primitive_creators.end()) {
|
||||
return primitive_creators[lower_name];
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Unsupported primitive type in Create: " << name;
|
||||
MS_LOG(ERROR) << "Unsupported primitive type in Create: " << name;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -57,7 +64,8 @@ class RegistryMSOps {
|
|||
~RegistryMSOps() = default;
|
||||
};
|
||||
|
||||
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const mindspore::AnfNodePtr &node);
|
||||
#define REG_MINDSPORE_OPERATOR(OP) \
|
||||
static RegistryMSOps g_##OP##PrimitiveCreatorRegistry(#OP, PrimitiveCreator<mindspore::ops::OP>);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif
|
||||
|
|
|
@ -41,7 +41,7 @@
|
|||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/meta_graph_utils.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "src/ops/anf_utils.h"
|
||||
#include "src/weight_decoder.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "src/common/log_util.h"
|
||||
|
|
|
@ -29,8 +29,9 @@
|
|||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "src/ops/anf_utils.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "src/common/primitive_t_utils.h"
|
||||
#include "mindapi/base/format.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
|
|
|
@ -46,17 +46,6 @@ std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode) {
|
|||
return inputs;
|
||||
}
|
||||
|
||||
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
if (primitive_t == nullptr || fbb == nullptr) {
|
||||
MS_LOG(ERROR) << "primitiveT or fbb is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t);
|
||||
fbb->Finish(prim_offset);
|
||||
auto prim_buf = fbb->GetBufferPointer();
|
||||
return flatbuffers::GetRoot<schema::Primitive>(prim_buf);
|
||||
}
|
||||
|
||||
STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
|
||||
mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
|
||||
MS_ASSERT(dst_dims != nullptr);
|
||||
|
|
|
@ -81,8 +81,6 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList();
|
|||
|
||||
std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
|
||||
|
||||
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
|
||||
|
||||
size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode);
|
||||
|
||||
class NodeUtils {
|
||||
|
|
|
@ -67,6 +67,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${SRC_DIR}/common/dynamic_library_loader.cc
|
||||
${SRC_DIR}/train/train_populate_parameter.cc
|
||||
${SRC_DIR}/common/config_file.cc
|
||||
${SRC_DIR}/common/primitive_t_utils.cc
|
||||
../optimizer/*.cc
|
||||
)
|
||||
|
||||
|
@ -117,6 +118,7 @@ set(LITE_SRC ${API_SRC}
|
|||
${RUNTIME_PASS_SRCS}
|
||||
${SRC_DIR}/ops/ops_def.cc
|
||||
${SRC_DIR}/ops/ops_utils.cc
|
||||
${SRC_DIR}/ops/anf_utils.cc
|
||||
${SRC_DIR}/common/utils.cc
|
||||
${SRC_DIR}/common/file_utils.cc
|
||||
${SRC_DIR}/common/context_util.cc
|
||||
|
|
|
@ -41,6 +41,7 @@
|
|||
#include "tools/common/tensor_util.h"
|
||||
#include "include/api/model.h"
|
||||
#include "tools/mindir_serializer/mindir_serializer.h"
|
||||
#include "src/common/primitive_t_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "src/common/prim_util.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "src/runtime/infer_manager.h"
|
||||
#include "src/common/primitive_t_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "src/common/string_utils.h"
|
||||
|
|
|
@ -31,7 +31,7 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "src/inner_context.h"
|
||||
#include "src/tensor.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "src/ops/anf_utils.h"
|
||||
#include "src/runtime/infer_manager.h"
|
||||
#include "tools/optimizer/graph/lite_tensor_extractor.h"
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "src/ops/anf_utils.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -19,11 +19,12 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "src/common/primitive_t_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "src/ops/ops_utils.h"
|
||||
#include "src/ops/anf_utils.h"
|
||||
#include "src/runtime/infer_manager.h"
|
||||
#include "src/tensorlist.h"
|
||||
#include "src/registry/kernel_interface_registry.h"
|
||||
|
|
Loading…
Reference in New Issue