[MS][LITE] add KernelExecutor to run a single operator.

wangpingan2 2022-03-28 15:35:58 +08:00
parent c525beaaff
commit 5c6a70d8ad
28 changed files with 1003 additions and 1123 deletions
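The headline addition is a small public API (kernel_executor.h, below) for building and running one operator outside a full model. A minimal sketch of how it is meant to be driven, assuming a registered single operator such as ops::Abs and caller-allocated buffers; the tensor names and shapes here are illustrative, not part of the commit:

#include <memory>
#include <vector>
#include "kernel_executor.h"
#include "ops/abs.h"

int RunSingleOp() {
  // CPU-only context; the executor converts this into a lite::InnerContext.
  auto ctx = std::make_shared<mindspore::Context>();
  ctx->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());

  // Caller owns all tensor memory; the executor only borrows these buffers.
  std::vector<float> in_data{-1.0f, 2.0f, -3.0f, 4.0f}, out_data(4, 0.0f);
  mindspore::MSTensor in("in", mindspore::DataType::kNumberTypeFloat32, {2, 2},
                         in_data.data(), in_data.size() * sizeof(float));
  mindspore::MSTensor out("out", mindspore::DataType::kNumberTypeFloat32, {2, 2},
                          out_data.data(), out_data.size() * sizeof(float));
  std::vector<mindspore::MSTensor> inputs{in}, outputs{out};

  mindspore::KernelExecutor executor;
  auto op = std::make_shared<mindspore::ops::Abs>();
  if (executor.Build(op, inputs, outputs, ctx) != mindspore::kSuccess) return -1;
  executor.Infer(&outputs);  // propagates inferred shape/format to the outputs
  return executor.Execute(inputs, outputs) == mindspore::kSuccess ? 0 : -1;
}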

View File

@@ -642,6 +642,20 @@ if(PLATFORM_ARM64)
endif()
endif()
endif()
if(MSLITE_ENABLE_KERNEL_EXECUTOR)
install(DIRECTORY ${TOP_DIR}/mindspore/core/ops/ DESTINATION ${RUNTIME_INC_DIR}/core/ops
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/mindapi/ DESTINATION ${RUNTIME_INC_DIR}/core/mindapi
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${TOP_DIR}/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h DESTINATION
${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS kernel_executor DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_CONVERTER)
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${GLOG_DIR} RENAME libglog.so.0
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
elseif(PLATFORM_ARM32)
if(SUPPORT_NPU)
install(FILES ${DDK_LIB_PATH}/libhiai.so DESTINATION ${RUNTIME_DIR}/third_party/hiai_ddk/lib
@@ -998,6 +1012,20 @@ else()
DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
if(MSLITE_ENABLE_KERNEL_EXECUTOR)
install(DIRECTORY ${TOP_DIR}/mindspore/core/ops/ DESTINATION ${RUNTIME_INC_DIR}/core/ops
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/mindapi/ DESTINATION ${RUNTIME_INC_DIR}/core/mindapi
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${TOP_DIR}/mindspore/lite/src/cxx_api/kernel_executor/kernel_executor.h DESTINATION
${RUNTIME_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS kernel_executor DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_CONVERTER)
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${GLOG_DIR} RENAME libglog.so.0
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
endif()
if(CMAKE_SYSTEM_NAME MATCHES "Windows")

View File

@@ -35,9 +35,9 @@ struct alignas(sizeof(T) * 2) ComplexStorage {
ComplexStorage &operator=(ComplexStorage<T> &&other) noexcept = default;
inline constexpr ComplexStorage(const T &real, const T &imag = T()) : real_(real), imag_(imag) {}
#ifndef ENABLE_ARM
inline explicit constexpr ComplexStorage(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {}
#endif
template <typename U = T>
explicit ComplexStorage(const std::enable_if_t<std::is_same<U, float>::value, ComplexStorage<double>> &other)
: real_(other.real_), imag_(other.imag_) {}

View File

@@ -25,10 +25,18 @@
#include <thread>
#include <vector>
#include "utils/convert_utils_base.h"
#ifdef ENABLE_ARM
#if defined(__ANDROID__) || defined(ANDROID)
#include <android/log.h>
#endif
#endif
// namespace to support utils module definition
namespace mindspore {
constexpr int kNameMaxLength = 18;
#if defined(__ANDROID__) || defined(ANDROID)
constexpr const char *ANDROID_LOG_TAG = "MS_LITE";
#endif
std::map<void **, std::thread *> acl_handle_map;
// set default log level to WARNING for all sub modules
int g_ms_submodule_log_levels[NUM_SUBMODUES] = {WARNING};
@@ -100,6 +108,38 @@ static int GetThresholdLevel(const std::string &threshold) {
}
}
#undef google
#elif defined(BUILD_CORE_RUNTIME)
const char *EnumStrForMsLogLevel(MsLogLevel level) {
if (level == MsLogLevel::DEBUG) {
return "DEBUG";
} else if (level == MsLogLevel::INFO) {
return "INFO";
} else if (level == MsLogLevel::WARNING) {
return "WARNING";
} else if (level == MsLogLevel::ERROR) {
return "ERROR";
} else {
return "NO_LEVEL";
}
}
#ifdef ENABLE_ARM
#if defined(__ANDROID__) || defined(ANDROID)
static int GetAndroidLogLevel(MsLogLevel level) {
switch (level) {
case MsLogLevel::DEBUG:
return ANDROID_LOG_DEBUG;
case MsLogLevel::INFO:
return ANDROID_LOG_INFO;
case MsLogLevel::WARNING:
return ANDROID_LOG_WARN;
case MsLogLevel::ERROR:
default:
return ANDROID_LOG_ERROR;
}
}
#endif
#endif
#else
#undef Dlog
@@ -153,6 +193,14 @@ void LogWriter::OutputLog(const std::ostringstream &msg) const {
<< std::this_thread::get_id() << std::dec << "," << GetProcName() << "):" << GetTimeString() << " "
<< "[" << location_.file_ << ":" << location_.line_ << "] " << location_.func_ << "] " << msg.str() << std::endl;
#undef google
#elif defined(BUILD_CORE_RUNTIME)
#if defined(ENABLE_ARM) && (defined(__ANDROID__) || defined(ANDROID))
__android_log_print(GetAndroidLogLevel(log_level_), ANDROID_LOG_TAG, "[%s:%d] %s] %s", location_.file_,
location_.line_, location_.func_, msg.str().c_str());
#else
printf("%s [%s:%d] %s] %s\n", EnumStrForMsLogLevel(log_level_), location_.file_, location_.line_, location_.func_,
msg.str().c_str());
#endif
#else
auto str_msg = msg.str();
auto slog_module_id = (submodule_ == SM_MD ? MD : ME);
@@ -166,7 +214,7 @@ void LogWriter::operator<(const LogStream &stream) const noexcept {
msg << stream.sstream_->rdbuf();
OutputLog(msg);
}
#ifndef BUILD_LITE_INFERENCE
#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME)
void LogWriter::operator^(const LogStream &stream) const {
std::ostringstream msg;
msg << stream.sstream_->rdbuf();

View File

@@ -34,6 +34,7 @@
#define google mindspore_private
#include "glog/logging.h"
#undef google
#elif defined(BUILD_CORE_RUNTIME)
#else
#include "toolchain/slog.h"
#endif
@@ -238,7 +239,7 @@ class MS_CORE_API LogWriter {
/// \param[in] stream The input log stream.
void operator<(const LogStream &stream) const noexcept;
#ifndef BUILD_LITE_INFERENCE
#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME)
/// \brief Output log message from the input log stream and then throw exception.
///
/// \param[in] stream The input log stream.
@@ -266,7 +267,7 @@ class MS_CORE_API LogWriter {
: mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), level, SUBMODULE_ID, \
excp_type) < mindspore::LogStream()
#ifndef BUILD_LITE_INFERENCE
#if !defined(BUILD_LITE_INFERENCE) || defined(BUILD_CORE_RUNTIME)
#define MSLOG_THROW(excp_type) \
mindspore::LogWriter(mindspore::LocationInfo(FILE_NAME, __LINE__, __FUNCTION__), mindspore::EXCEPTION, SUBMODULE_ID, \
excp_type) ^ \

View File

@@ -56,6 +56,7 @@ option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off)
option(MSLITE_ENABLE_EXPERIMENTAL_KERNEL "enable experimental kernel" on)
option(MSLITE_ENABLE_GRAPH_KERNEL "enable graph kernel" off)
option(MSLITE_ENABLE_CONVERT_PYTORCH_MODEL "enable to convert pytorch model" off)
option(MSLITE_ENABLE_KERNEL_EXECUTOR "enable kernel executor" off)
#Option that can be configured manually
option(ENABLE_VERBOSE "" off)
@@ -175,6 +176,9 @@ endif()
if(DEFINED ENV{MSLITE_ENABLE_SERVING})
set(MSLITE_ENABLE_SERVING $ENV{MSLITE_ENABLE_SERVING})
endif()
if(DEFINED ENV{MSLITE_ENABLE_KERNEL_EXECUTOR})
set(MSLITE_ENABLE_KERNEL_EXECUTOR $ENV{MSLITE_ENABLE_KERNEL_EXECUTOR})
endif()
if(DEFINED ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL} AND DEFINED ENV{LIB_TORCH_PATH})
set(ENABLE_CONVERT_PYTORCH_MODEL $ENV{MSLITE_ENABLE_CONVERT_PYTORCH_MODEL})
@@ -404,6 +408,7 @@ message(STATUS "\tMSLITE_ENABLE_PARALLEL_INFERENCE = \t${MSLITE_ENABLE
message(STATUS "\tMSLITE_ENABLE_SHARING_MODEL_WEIGHT = \t${MSLITE_ENABLE_SHARING_MODEL_WEIGHT}")
message(STATUS "\tMSLITE_ENABLE_EXPERIMENTAL_KERNEL = \t${MSLITE_ENABLE_EXPERIMENTAL_KERNEL}")
message(STATUS "\tMSLITE_ENABLE_GRAPH_KERNEL = \t${MSLITE_ENABLE_GRAPH_KERNEL}")
message(STATUS "\tMSLITE_ENABLE_KERNEL_EXECUTOR = \t${MSLITE_ENABLE_KERNEL_EXECUTOR}")
if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TESTCASES) AND (
NOT MSLITE_ENABLE_MINDRT
@@ -531,7 +536,7 @@ if(MSLITE_GPU_BACKEND STREQUAL opencl)
endif()
if(MSLITE_ENABLE_CONVERTER OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper"
OR MSLITE_ENABLE_TOOLS)
OR MSLITE_ENABLE_TOOLS OR MSLITE_ENABLE_KERNEL_EXECUTOR)
if(NOT ENABLE_CLOUD_AND_LITE)
include(${TOP_DIR}/cmake/external_libs/json.cmake)
endif()
@@ -635,6 +640,13 @@ function(find_required_package pkg_name)
endif()
endfunction()
if(MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_KERNEL_EXECUTOR)
find_required_package(Patch)
if(NOT ENABLE_CLOUD_AND_LITE)
include(${TOP_DIR}/cmake/external_libs/protobuf.cmake)
endif()
endif()
if(MSLITE_ENABLE_CONVERTER)
if(ENABLE_FAST_HASH_TABLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_FAST_HASH_TABLE=1")
@@ -645,7 +657,6 @@ if(MSLITE_ENABLE_CONVERTER)
if(NOT ENABLE_CLOUD_AND_LITE)
include(${TOP_DIR}/cmake/external_libs/opencv.cmake)
include(${TOP_DIR}/cmake/external_libs/eigen.cmake)
include(${TOP_DIR}/cmake/external_libs/protobuf.cmake)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter)
endif()

View File

@@ -187,7 +187,7 @@ if(MSLITE_ENABLE_RUNTIME_GLOG)
add_definitions(-DUSE_GLOG)
string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
if(NOT MSLITE_ENABLE_RUNTIME_CONVERT)
if(NOT MSLITE_ENABLE_RUNTIME_CONVERT AND NOT MSLITE_ENABLE_KERNEL_EXECUTOR)
set(LITE_SRC ${LITE_SRC}
${CORE_DIR}/utils/log_adapter.cc)
endif()
@@ -199,6 +199,7 @@ if(MSLITE_ENABLE_RUNTIME_CONVERT)
file(GLOB RUNTIME_CONVERT_SRC
${CMAKE_CURRENT_SOURCE_DIR}/ops/ops_def.cc
${CMAKE_CURRENT_SOURCE_DIR}/ops/ops_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/ops/anf_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_convert.cc)
set(LITE_SRC ${LITE_SRC} ${RUNTIME_CONVERT_SRC})
@@ -466,6 +467,10 @@ if(SUPPORT_TRAIN)
endif()
endif()
if(MSLITE_ENABLE_KERNEL_EXECUTOR)
add_subdirectory(cxx_api/kernel_executor)
endif()
########################## build optimize and float16 library #################################
if(PLATFORM_ARM)
if(PLATFORM_ARM64 AND NOT TARGET_HIMIX AND NOT MACHINE_LINUX_ARM64)

View File

@@ -16,7 +16,7 @@
#ifndef MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_
#define MINDSPORE_LITE_SRC_COMMON_LOG_ADAPTER_H_
#ifdef USE_GLOG
#if defined(USE_GLOG) || defined(BUILD_CORE_RUNTIME)
#include "utils/log_adapter.h"
#else
#include "src/common/log.h"

View File

@@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/primitive_t_utils.h"
#include "src/ops/ops_utils.h"
#include "ops/primitive_c.h"
namespace mindspore {
namespace lite {
constexpr size_t INITIAL_SIZE = 1024;
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
if (primitive_t == nullptr || fbb == nullptr) {
MS_LOG(ERROR) << "primitiveT or fbb is nullptr.";
return nullptr;
}
auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t);
fbb->Finish(prim_offset);
auto prim_buf = fbb->GetBufferPointer();
return flatbuffers::GetRoot<schema::Primitive>(prim_buf);
}
OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t) {
flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
auto primitive = ConvertToPrimitive(primitive_t, &fbb);
fbb.Clear();
auto prim_type = GetPrimitiveType(primitive, SCHEMA_VERSION::SCHEMA_CUR);
auto parame_gen = PopulateRegistry::GetInstance()->GetParameterCreator(prim_type, SCHEMA_VERSION::SCHEMA_CUR);
if (parame_gen == nullptr) {
MS_LOG(ERROR) << "parameter generator is nullptr.";
return nullptr;
}
auto parameter = parame_gen(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
<< GetPrimitiveTypeName(primitive, SCHEMA_VERSION::SCHEMA_CUR);
}
return parameter;
}
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const std::shared_ptr<mindspore::ops::BaseOperator> &op) {
if (op == nullptr) {
MS_LOG(DEBUG) << "base operator is nullptr";
return nullptr;
}
if (op->name().empty()) {
MS_LOG(ERROR) << "the name of operator is null";
return nullptr;
}
MS_LOG(DEBUG) << "export operator: " << op->name();
auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(op->name());
if (creator != nullptr) {
return creator(op->GetPrim());
} else {
MS_LOG(WARNING) << "can not find SingleOpRegistry for operator: " << op->name();
return nullptr;
}
}
} // namespace lite
} // namespace mindspore
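Taken together, these helpers form the conversion chain the executor relies on: BaseOperator → PrimitiveT → flatbuffer Primitive → OpParameter. A hedged sketch (ops::AddFusion and the trailing free() are assumptions about the surrounding lite API, where populate creators allocate from the C heap):

// Sketch of the conversion chain provided by this file.
auto op = std::make_shared<mindspore::ops::AddFusion>();             // assumed operator
auto prim_t = mindspore::lite::GetPrimitiveT(op);                    // BaseOperator -> PrimitiveT
OpParameter *param = mindspore::lite::GetOpParameter(prim_t.get());  // PrimitiveT -> OpParameter
if (param != nullptr) {
  free(param);  // populate creators use malloc in lite, so free() is expected
}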

View File

@@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_
#define MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_
#include <memory>
#include "schema/inner/model_generated.h"
#include "src/ops/populate/populate_register.h"
#include "ops/base_operator.h"
namespace mindspore {
namespace lite {
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t);
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const std::shared_ptr<mindspore::ops::BaseOperator> &op);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_COMMON_PRIMITIVE_T_UTILS_H_

View File

@@ -0,0 +1,33 @@
add_compile_definitions(BUILD_CORE_RUNTIME)
add_definitions(-DPRIMITIVE_WRITEABLE)
if(MSLITE_ENABLE_RUNTIME_GLOG)
set(USE_GLOG on)
add_definitions(-DUSE_GLOG)
endif()
string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
string(REPLACE "-fno-exceptions" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS})
string(REPLACE "-fno-exceptions" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
if(NOT MSLITE_ENABLE_CONVERTER)
add_subdirectory(${CORE_DIR} mindspore_core)
endif()
add_library(kernel_executor SHARED
${CMAKE_CURRENT_SOURCE_DIR}/kernel_executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/kernel_executor_impl.cc
${TOP_DIR}/mindspore/lite/src/ops/ops_utils.cc
${TOP_DIR}/mindspore/lite/src/common/primitive_t_utils.cc
${TOP_DIR}/mindspore/lite/src/ops/ops_def.cc)
add_dependencies(kernel_executor fbs_inner_src fbs_src mindspore_core)
target_link_libraries(kernel_executor
mindspore-lite
mindspore_core
mindspore::json
mindspore::protobuf
mindspore::flatbuffers)
if(USE_GLOG)
target_link_libraries(kernel_executor mindspore::glog)
endif()

View File

@@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/cxx_api/kernel_executor/kernel_executor.h"
#include "src/cxx_api/kernel_executor/kernel_executor_impl.h"
namespace mindspore {
Status KernelExecutor::Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context) {
if (impl_ == nullptr) {
impl_ = std::make_shared<KernelExecutorImpl>();
if (impl_ == nullptr) {
MS_LOG(ERROR) << "implement is null.";
return kLiteNullptr;
}
}
Status ret = impl_->Build(op, inputs, outputs, ms_context);
if (ret != kSuccess) {
return ret;
}
return kSuccess;
}
Status KernelExecutor::ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "implement is null.";
return kLiteNullptr;
}
return impl_->ReSize(inputs, outputs);
}
Status KernelExecutor::Infer(std::vector<MSTensor> *outputs) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "implement is null.";
return kLiteNullptr;
}
return impl_->Infer(outputs);
}
Status KernelExecutor::Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "implement is null.";
return kLiteNullptr;
}
return impl_->Execute(inputs, outputs);
}
} // namespace mindspore

View File

@@ -0,0 +1,73 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H
#define MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H
#include <vector>
#include <memory>
#include "include/api/types.h"
#include "include/api/status.h"
#include "include/api/context.h"
#include "ops/base_operator.h"
namespace mindspore {
class KernelExecutorImpl;
class MS_API KernelExecutor {
public:
KernelExecutor() = default;
~KernelExecutor() = default;
/// \brief Build a single operator so that it can run on a device.
///
/// \param[in] op Define an operator pointer.
/// \param[in] inputs A vector where single operator inputs are arranged in sequence.
/// \param[in] outputs A vector where single operator outputs are arranged in sequence.
/// \param[in] ms_context Define the context used to store options during execution.
///
/// \return Status.
Status Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context);
/// \brief ReSize KernelExecutor.
///
/// \param[in] inputs A vector where single operator inputs are arranged in sequence.
/// \param[in] outputs A vector where single operator outputs are arranged in sequence.
///
/// \return Status.
Status ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
/// \brief Obtain the inferred shape and format of the output tensors.
///
/// \param[in] outputs A vector where single operator outputs are arranged in sequence.
///
/// \return Status.
Status Infer(std::vector<MSTensor> *outputs);
/// \brief Execute the built single operator.
///
/// \param[in] inputs A vector where single operator inputs are arranged in sequence.
/// \param[in] outputs A vector where single operator outputs are arranged in sequence.
///
/// \return Status.
Status Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
private:
std::shared_ptr<KernelExecutorImpl> impl_ = nullptr;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_H

View File

@@ -0,0 +1,253 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "src/ops/ops_utils.h"
#include "src/cxx_api/converters.h"
#include "src/common/prim_util.h"
#include "src/ops/populate/populate_register.h"
#include "src/common/primitive_t_utils.h"
#include "schema/inner/model_generated.h"
#include "src/runtime/infer_manager.h"
#include "src/kernel_registry.h"
#include "src/cxx_api/kernel_executor/kernel_executor_impl.h"
namespace mindspore {
constexpr size_t INITIAL_SIZE = 1024;
KernelExecutorImpl::~KernelExecutorImpl() {
if (context_ != nullptr) {
delete context_;
context_ = nullptr;
}
if (kernel_ != nullptr) {
delete kernel_;
kernel_ = nullptr;
}
FreeInOutTensor();
}
Status KernelExecutorImpl::Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context) {
data_type_ = static_cast<enum TypeId>(inputs[FIRST_INPUT].DataType());
std::unique_ptr<mindspore::schema::PrimitiveT> prim_t = lite::GetPrimitiveT(op);
flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
primitive_ = lite::ConvertToPrimitive(prim_t.get(), &fbb);
fbb.Clear();
if (primitive_ == nullptr) {
MS_LOG(ERROR) << "convert to primitive nullptr.";
return kLiteNullptr;
}
prim_type_ = lite::GetPrimitiveType(primitive_, schema_version_);
context_ = ContextUtils::Convert(ms_context.get());
if (context_ == nullptr) {
MS_LOG(ERROR) << "failed to convert Context to LiteContext.";
return kLiteNullptr;
}
int ret = context_->Init();
if (ret != RET_OK) {
return static_cast<StatusCode>(ret);
}
Status status = InitInOutTensor(inputs, outputs);
if (status != kSuccess) {
MS_LOG(ERROR) << "InitInOutTensor error.";
return status;
}
if (prim_type_ == schema::PrimitiveType_Custom) {
status = GetCustomKernel(ms_context);
} else {
status = GetCpuKernel(ms_context);
}
if (status != kSuccess) {
MS_LOG(ERROR) << "get kernel error.";
return status;
}
ret = kernel_->Prepare();
return static_cast<StatusCode>(ret);
}
Status KernelExecutorImpl::ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
Status status = InitInOutTensor(inputs, outputs);
if (status != kSuccess) {
MS_LOG(ERROR) << "InitInOutTensor error.";
return status;
}
kernel_->set_in_tensors(inputs_);
kernel_->set_out_tensors(outputs_);
int ret;
if (kernel_->type() == schema::PrimitiveType_Custom) {
ret = KernelInferShape(inputs_, outputs_, primitive_, context_->GetProviders(), schema_version_);
} else {
ret = KernelInferShape(inputs_, outputs_, parameter_);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "do infer shape error.";
return static_cast<StatusCode>(ret);
}
ret = kernel_->ReSize();
return static_cast<StatusCode>(ret);
}
Status KernelExecutorImpl::Infer(std::vector<MSTensor> *outputs) {
for (size_t i = 0; i < outputs->size(); ++i) {
auto user_output = outputs->at(i);
auto output = outputs_[i];
user_output.SetFormat(output->format());
auto output_shape = output->shape();
std::vector<int64_t> shape;
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(shape),
[](auto s) { return static_cast<int64_t>(s); });
user_output.SetShape(shape);
}
return kSuccess;
}
Status KernelExecutorImpl::Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
for (size_t i = 0; i < inputs.size(); ++i) {
auto user_input = inputs[i];
auto input = inputs_[i];
input->set_data(user_input.MutableData());
}
for (size_t i = 0; i < outputs.size(); ++i) {
auto user_output = outputs[i];
auto output = outputs_[i];
output->set_data(user_output.MutableData());
}
int ret = kernel_->Execute();
if (ret != RET_OK) {
MS_LOG(ERROR) << "execute error.";
return static_cast<StatusCode>(ret);
}
return kSuccess;
}
Status KernelExecutorImpl::GetOpParameter() {
auto parame_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim_type_, schema_version_);
if (parame_gen == nullptr) {
MS_LOG(ERROR) << "parameter generator is nullptr.";
return kLiteNullptr;
}
parameter_ = parame_gen(primitive_);
if (parameter_ == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
<< lite::GetPrimitiveTypeName(primitive_, schema_version_);
return kLiteNullptr;
}
return kSuccess;
}
Status KernelExecutorImpl::GetCustomKernel(const std::shared_ptr<Context> &ms_context) {
int get_kernel = lite::RET_ERROR;
// find a kernel matching arch, data_type, kernel_arch and provider
for (auto &&device : context_->device_list_) {
if (!device.provider_.empty() && !device.provider_device_.empty()) {
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_, device.provider_device_,
device.provider_};
get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
nullptr, &kernel_, primitive_);
}
}
// fall back to a kernel matching only arch and data_type
if (get_kernel != RET_OK) {
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_, "", ""};
get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
nullptr, &kernel_, primitive_);
}
// if a kernel was found, run infer shape
if (get_kernel == RET_OK) {
int ret = KernelInferShape(inputs_, outputs_, primitive_, context_->GetProviders(), schema_version_);
return static_cast<StatusCode>(ret);
}
return static_cast<StatusCode>(get_kernel);
}
Status KernelExecutorImpl::GetCpuKernel(const std::shared_ptr<Context> &ms_context) {
Status status = GetOpParameter();
if (status != kSuccess) {
return status;
}
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type_, prim_type_};
int get_kernel = lite::KernelRegistry::GetInstance()->GetKernel(inputs_, outputs_, context_, ms_context.get(), desc,
parameter_, &kernel_);
if (get_kernel == RET_OK) {
int ret = KernelInferShape(inputs_, outputs_, parameter_);
return static_cast<StatusCode>(ret);
}
return static_cast<StatusCode>(get_kernel);
}
void KernelExecutorImpl::FreeInOutTensor() {
for (auto &input : inputs_) {
if (input != nullptr) {
delete input;
input = nullptr;
}
}
inputs_.clear();
for (auto &output : outputs_) {
if (output != nullptr) {
delete output;
output = nullptr;
}
}
outputs_.clear();
}
Status KernelExecutorImpl::InitInOutTensor(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs) {
FreeInOutTensor();
for (auto input : inputs) {
auto input_shape = input.Shape();
std::vector<int> shape;
std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(shape),
[](auto s) { return static_cast<int>(s); });
lite::Tensor *input_tensor = new (std::nothrow)
lite::Tensor(static_cast<enum TypeId>(input.DataType()), shape, input.format(), lite::Category::GRAPH_INPUT);
if (input_tensor == nullptr) {
delete input_tensor;
return kLiteNullptr;
}
input_tensor->set_data(input.MutableData());
inputs_.emplace_back(input_tensor);
}
for (auto output : outputs) {
auto output_shape = output.Shape();
std::vector<int> shape;
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(shape),
[](auto s) { return static_cast<int>(s); });
lite::Tensor *output_tensor =
new (std::nothrow) lite::Tensor(static_cast<enum TypeId>(output.DataType()), shape, output.format());
if (output_tensor == nullptr) {
delete output_tensor;
return kLiteNullptr;
}
outputs_.emplace_back(output_tensor);
}
return kSuccess;
}
} // namespace mindspore
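Note the ownership model here: InitInOutTensor() wraps the user's MSTensors in fresh lite::Tensors, and Execute() merely re-attaches the caller's buffers via set_data(), so the executor never owns tensor memory and the caller must keep the data alive across Execute(). Rebinding new data for a second run is therefore just passing new MSTensors; a sketch continuing the illustrative usage example above:

// The built kernel is reused; only the borrowed data pointers change.
std::vector<float> next_batch{5.0f, -6.0f, 7.0f, -8.0f};
mindspore::MSTensor next_in("in", mindspore::DataType::kNumberTypeFloat32, {2, 2},
                            next_batch.data(), next_batch.size() * sizeof(float));
executor.Execute({next_in}, outputs);  // same shapes, so no ReSize() needed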

View File

@@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H
#define MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H
#include <vector>
#include <memory>
#include "src/cxx_api/kernel_executor/kernel_executor.h"
#include "src/kernel_exec.h"
#include "common/version_manager.h"
namespace mindspore {
class KernelExecutorImpl {
public:
KernelExecutorImpl() = default;
~KernelExecutorImpl();
Status Build(const std::shared_ptr<ops::BaseOperator> &op, const std::vector<MSTensor> &inputs,
const std::vector<MSTensor> &outputs, const std::shared_ptr<Context> &ms_context);
Status ReSize(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
Status Infer(std::vector<MSTensor> *outputs);
Status Execute(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
protected:
Status GetCustomKernel(const std::shared_ptr<Context> &ms_context);
Status GetCpuKernel(const std::shared_ptr<Context> &ms_context);
Status GetOpParameter();
Status InitInOutTensor(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs);
void FreeInOutTensor();
private:
const schema::Primitive *primitive_ = nullptr;
int prim_type_;
OpParameter *parameter_ = nullptr;
lite::InnerContext *context_ = nullptr;
TypeId data_type_;
kernel::KernelExec *kernel_ = nullptr;
std::vector<lite::Tensor *> inputs_;
std::vector<lite::Tensor *> outputs_;
int schema_version_ = lite::SCHEMA_VERSION::SCHEMA_CUR;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_KERNEL_EXECUTOR_KERNEL_EXECUTOR_IMPL_H

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/anf_utils.h"
#ifdef PRIMITIVE_WRITEABLE
namespace mindspore {
namespace lite {
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const AnfNodePtr &node) {
auto prim = GetValueNode<std::shared_ptr<Primitive>>(node);
if (prim == nullptr) {
MS_LOG(DEBUG) << "primitive is nullptr";
return nullptr;
}
if (prim->name().empty()) {
MS_LOG(ERROR) << "the name of primitive is null";
return nullptr;
}
MS_LOG(DEBUG) << "export prim: " << prim->name();
auto creator = MSOpsRegistry::GetInstance()->GetPrimitiveCreator(prim->name());
if (creator != nullptr) {
return creator(prim);
} else {
MS_LOG(WARNING) << "can not find MSOpsRegistry for op: " << prim->name();
return nullptr;
}
}
} // namespace lite
} // namespace mindspore
#endif

View File

@@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_
#define MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_
#include <memory>
#include "src/ops/ops_utils.h"
#ifdef PRIMITIVE_WRITEABLE
#include "abstract/ops/primitive_infer_map.h"
namespace mindspore {
namespace lite {
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const mindspore::AnfNodePtr &node);
}
} // namespace mindspore
#endif
#endif // MINDSPORE_LITE_SRC_OPS_ANF_UTILS_H_

File diff suppressed because it is too large.

View File

@@ -20,14 +20,14 @@
#include <map>
#include <string>
#include <memory>
#include <algorithm>
#include "src/ops/ops_func_declare.h"
#ifdef PRIMITIVE_WRITEABLE
#include "abstract/ops/primitive_infer_map.h"
#include "src/common/log_adapter.h"
namespace mindspore {
namespace lite {
typedef std::unique_ptr<schema::PrimitiveT> (*PrimitiveTCreator)(const AnfNodePtr &node);
typedef std::unique_ptr<schema::PrimitiveT> (*PrimitiveTCreator)(const PrimitivePtr &primitive);
class MSOpsRegistry {
public:
@@ -35,12 +35,19 @@ class MSOpsRegistry {
static MSOpsRegistry registry;
return &registry;
}
void InsertPrimitiveTMap(const std::string &name, PrimitiveTCreator creator) { primitive_creators[name] = creator; }
void InsertPrimitiveTMap(const std::string &name, PrimitiveTCreator creator) {
std::string lower_name = name;
std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower);
primitive_creators[lower_name] = creator;
}
PrimitiveTCreator GetPrimitiveCreator(const std::string &name) {
if (primitive_creators.find(name) != primitive_creators.end()) {
return primitive_creators[name];
std::string lower_name = name;
std::transform(name.begin(), name.end(), lower_name.begin(), ::tolower);
lower_name.erase(std::remove(lower_name.begin(), lower_name.end(), '_'), lower_name.end());
if (primitive_creators.find(lower_name) != primitive_creators.end()) {
return primitive_creators[lower_name];
} else {
MS_LOG(WARNING) << "Unsupported primitive type in Create: " << name;
MS_LOG(ERROR) << "Unsupported primitive type in Create: " << name;
return nullptr;
}
}
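Registration lower-cases the key, and lookup additionally strips underscores, so differently spelled aliases resolve to the same creator. A hypothetical illustration:

// Both spellings normalize to "addfusion" and hit the creator registered
// via REG_MINDSPORE_OPERATOR(AddFusion).
auto c1 = mindspore::lite::MSOpsRegistry::GetInstance()->GetPrimitiveCreator("AddFusion");
auto c2 = mindspore::lite::MSOpsRegistry::GetInstance()->GetPrimitiveCreator("add_fusion");
// c1 == c2, and both are non-null when AddFusion is registered.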
@@ -57,7 +64,8 @@ class RegistryMSOps {
~RegistryMSOps() = default;
};
std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const mindspore::AnfNodePtr &node);
#define REG_MINDSPORE_OPERATOR(OP) \
static RegistryMSOps g_##OP##PrimitiveCreatorRegistry(#OP, PrimitiveCreator<mindspore::ops::OP>);
} // namespace lite
} // namespace mindspore
#endif

View File

@@ -41,7 +41,7 @@
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "src/ops/ops_utils.h"
#include "src/ops/anf_utils.h"
#include "src/weight_decoder.h"
#include "tools/common/node_util.h"
#include "src/common/log_util.h"

View File

@@ -29,8 +29,9 @@
#include "tools/optimizer/common/format_utils.h"
#include "nnacl/op_base.h"
#include "tools/common/node_util.h"
#include "src/ops/ops_utils.h"
#include "src/ops/anf_utils.h"
#include "src/ops/populate/populate_register.h"
#include "src/common/primitive_t_utils.h"
#include "mindapi/base/format.h"
#include "ops/op_utils.h"

View File

@@ -46,17 +46,6 @@ std::vector<CNodePtr> GetInputCNode(const CNodePtr &cnode) {
return inputs;
}
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
if (primitive_t == nullptr || fbb == nullptr) {
MS_LOG(ERROR) << "primitiveT or fbb is nullptr.";
return nullptr;
}
auto prim_offset = schema::CreatePrimitive(*fbb, primitive_t);
fbb->Finish(prim_offset);
auto prim_buf = fbb->GetBufferPointer();
return flatbuffers::GetRoot<schema::Primitive>(prim_buf);
}
STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector<int32_t> &src_dims,
mindspore::schema::Format dst_format, std::vector<int32_t> *dst_dims) {
MS_ASSERT(dst_dims != nullptr);

View File

@@ -81,8 +81,6 @@ std::vector<schema::PrimitiveType> Getfp32FullOpList();
std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode);
class NodeUtils {

View File

@@ -67,6 +67,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${SRC_DIR}/common/dynamic_library_loader.cc
${SRC_DIR}/train/train_populate_parameter.cc
${SRC_DIR}/common/config_file.cc
${SRC_DIR}/common/primitive_t_utils.cc
../optimizer/*.cc
)
@@ -117,6 +118,7 @@ set(LITE_SRC ${API_SRC}
${RUNTIME_PASS_SRCS}
${SRC_DIR}/ops/ops_def.cc
${SRC_DIR}/ops/ops_utils.cc
${SRC_DIR}/ops/anf_utils.cc
${SRC_DIR}/common/utils.cc
${SRC_DIR}/common/file_utils.cc
${SRC_DIR}/common/context_util.cc

View File

@@ -41,6 +41,7 @@
#include "tools/common/tensor_util.h"
#include "include/api/model.h"
#include "tools/mindir_serializer/mindir_serializer.h"
#include "src/common/primitive_t_utils.h"
namespace mindspore {
namespace lite {

View File

@@ -26,6 +26,7 @@
#include "src/common/prim_util.h"
#include "src/ops/populate/populate_register.h"
#include "src/runtime/infer_manager.h"
#include "src/common/primitive_t_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/converter_flags.h"
#include "src/common/string_utils.h"

View File

@@ -31,7 +31,7 @@
#include "src/kernel_registry.h"
#include "src/inner_context.h"
#include "src/tensor.h"
#include "src/ops/ops_utils.h"
#include "src/ops/anf_utils.h"
#include "src/runtime/infer_manager.h"
#include "tools/optimizer/graph/lite_tensor_extractor.h"

View File

@@ -25,7 +25,7 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"
#include "src/ops/ops_utils.h"
#include "src/ops/anf_utils.h"
#include "ops/op_utils.h"
namespace mindspore {

View File

@@ -19,11 +19,12 @@
#include <memory>
#include <vector>
#include <algorithm>
#include "src/common/primitive_t_utils.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "src/common/utils.h"
#include "src/ops/populate/populate_register.h"
#include "src/ops/ops_utils.h"
#include "src/ops/anf_utils.h"
#include "src/runtime/infer_manager.h"
#include "src/tensorlist.h"
#include "src/registry/kernel_interface_registry.h"