From 5440f9fbc1c80fbf85df9714511989735dd7d14f Mon Sep 17 00:00:00 2001 From: zhengyuanhua Date: Sat, 18 Sep 2021 10:46:46 +0800 Subject: [PATCH] external interface modify --- include/api/callback/ckpt_saver.h | 10 ++++- include/api/model.h | 33 +++++++++++--- include/api/serialization.h | 17 ++++++-- include/api/types.h | 7 ++- mindspore/ccsrc/cxx_api/model/model.cc | 6 +-- mindspore/ccsrc/cxx_api/serialization.cc | 4 +- mindspore/ccsrc/cxx_api/types.cc | 2 +- .../include/registry/node_parser_registry.h | 21 +++++++-- .../include/registry/opencl_runtime_wrapper.h | 25 +++++++++-- .../lite/include/registry/pass_registry.h | 28 ++++++++++-- .../lite/include/registry/register_kernel.h | 43 ++++++++++++++++--- .../registry/register_kernel_interface.h | 28 ++++++++++-- .../lite/src/cxx_api/callback/ckpt_saver.cc | 5 ++- mindspore/lite/src/cxx_api/model/model.cc | 14 +++--- mindspore/lite/src/cxx_api/serialization.cc | 9 ++-- mindspore/lite/src/cxx_api/types.cc | 4 +- .../lite/src/registry/register_kernel.cc | 22 ++++++---- .../src/registry/register_kernel_interface.cc | 12 +++--- .../gpu/opencl/opencl_runtime_wrapper.cc | 17 ++++---- .../registry/registry_gpu_custom_op_test.cc | 2 +- .../registry/node_parser_registry.cc | 10 +++-- .../tools/converter/registry/pass_registry.cc | 15 ++++--- 22 files changed, 251 insertions(+), 83 deletions(-) diff --git a/include/api/callback/ckpt_saver.h b/include/api/callback/ckpt_saver.h index 2c67d3a44e6..e673c624224 100644 --- a/include/api/callback/ckpt_saver.h +++ b/include/api/callback/ckpt_saver.h @@ -21,13 +21,21 @@ #include #include #include "include/api/callback/callback.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { class CkptSaver: public TrainCallBack { public: - explicit CkptSaver(int save_every_n, const std::string &filename_prefix); + inline CkptSaver(int save_every_n, const std::string &filename_prefix); virtual ~CkptSaver(); + + private: + CkptSaver(int save_every_n, const std::vector &filename_prefix); }; + +CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix) + : CkptSaver(save_every_n, StringToChar(filename_prefix)) {} + } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H diff --git a/include/api/model.h b/include/api/model.h index ac1f9e2428c..cecda6499f7 100644 --- a/include/api/model.h +++ b/include/api/model.h @@ -104,7 +104,7 @@ class MS_API Model { /// \param[in] config_path config file path. /// /// \return Status. - Status LoadConfig(const std::string &config_path); + inline Status LoadConfig(const std::string &config_path); /// \brief Obtains all input tensors of the model. /// @@ -189,9 +189,9 @@ class MS_API Model { /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC. /// /// \return Status. - Status Build(const void *model_data, size_t data_size, ModelType model_type, - const std::shared_ptr &model_context = nullptr, const Key &dec_key = {}, - const std::string &dec_mode = kDecModeAesGcm); + inline Status Build(const void *model_data, size_t data_size, ModelType model_type, + const std::shared_ptr &model_context = nullptr, const Key &dec_key = {}, + const std::string &dec_mode = kDecModeAesGcm); /// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite. /// @@ -203,9 +203,9 @@ class MS_API Model { /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC. /// /// \return Status. - Status Build(const std::string &model_path, ModelType model_type, - const std::shared_ptr &model_context = nullptr, const Key &dec_key = {}, - const std::string &dec_mode = kDecModeAesGcm); + inline Status Build(const std::string &model_path, ModelType model_type, + const std::shared_ptr &model_context = nullptr, const Key &dec_key = {}, + const std::string &dec_mode = kDecModeAesGcm); private: friend class Serialization; @@ -214,6 +214,11 @@ class MS_API Model { std::vector> GetOutputTensorNamesChar(); MSTensor GetOutputByTensorName(const std::vector &tensor_name); std::vector GetOutputsByNodeName(const std::vector &node_name); + Status LoadConfig(const std::vector &config_path); + Status Build(const void *model_data, size_t data_size, ModelType model_type, + const std::shared_ptr &model_context, const Key &dec_key, const std::vector &dec_mode); + Status Build(const std::vector &model_path, ModelType model_type, const std::shared_ptr &model_context, + const Key &dec_key, const std::vector &dec_mode); std::shared_ptr impl_; }; @@ -231,5 +236,19 @@ MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) { std::vector Model::GetOutputsByNodeName(const std::string &node_name) { return GetOutputsByNodeName(StringToChar(node_name)); } + +Status Model::LoadConfig(const std::string &config_path) { + return LoadConfig(StringToChar(config_path)); +} + +Status Model::Build(const void *model_data, size_t data_size, ModelType model_type, + const std::shared_ptr &model_context, const Key &dec_key, const std::string &dec_mode) { + return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode)); +} + +Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context, + const Key &dec_key, const std::string &dec_mode) { + return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode)); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_MODEL_H diff --git a/include/api/serialization.h b/include/api/serialization.h index 613355df152..31476d82f6e 100644 --- a/include/api/serialization.h +++ b/include/api/serialization.h @@ -68,9 +68,9 @@ class MS_API Serialization { const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm); static Status SetParameters(const std::map ¶meters, Model *model); static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); - static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file, - QuantizationType quantization_type = kNoQuant, bool export_inference_only = true, - std::vector output_tensor_name = {}); + inline static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file, + QuantizationType quantization_type = kNoQuant, bool export_inference_only = true, + std::vector output_tensor_name = {}); private: static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key, @@ -80,6 +80,9 @@ class MS_API Serialization { const std::vector &dec_mode); static Status Load(const std::vector> &files, ModelType model_type, std::vector *graphs, const Key &dec_key, const std::vector &dec_mode); + static Status ExportModel(const Model &model, ModelType model_type, const std::vector &model_file, + QuantizationType quantization_type, bool export_inference_only, + const std::vector> &output_tensor_name); }; Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, @@ -96,5 +99,13 @@ Status Serialization::Load(const std::vector &files, ModelType mode const Key &dec_key, const std::string &dec_mode) { return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode)); } + +Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file, + QuantizationType quantization_type, bool export_inference_only, + std::vector output_tensor_name) { + return ExportModel(model, model_type, StringToChar(model_file), quantization_type, export_inference_only, + VectorStringToChar(output_tensor_name)); +} + } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H diff --git a/include/api/types.h b/include/api/types.h index 26e2e8f49dd..3acd8e38c98 100644 --- a/include/api/types.h +++ b/include/api/types.h @@ -221,7 +221,7 @@ class MS_API MSTensor { /// \brief Set the name for the MSTensor. Only valid for Lite. /// /// \param[in] The name of the MSTensor. - void SetTensorName(const std::string &name); + inline void SetTensorName(const std::string &name); /// \brief Set the Allocator for the MSTensor. Only valid for Lite. /// @@ -275,6 +275,7 @@ class MS_API MSTensor { MSTensor(const std::vector &name, enum DataType type, const std::vector &shape, const void *data, size_t data_len); std::vector CharName() const; + void SetTensorName(const std::vector &name); friend class ModelImpl; std::shared_ptr impl_; @@ -333,6 +334,10 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto std::string MSTensor::Name() const { return CharToString(CharName()); } +void MSTensor::SetTensorName(const std::string &name) { + return SetTensorName(StringToChar(name)); +} + using Key = struct Key { const size_t max_key_len = 32; size_t len; diff --git a/mindspore/ccsrc/cxx_api/model/model.cc b/mindspore/ccsrc/cxx_api/model/model.cc index 84905a7c63f..14d386cef7b 100644 --- a/mindspore/ccsrc/cxx_api/model/model.cc +++ b/mindspore/ccsrc/cxx_api/model/model.cc @@ -66,13 +66,13 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr &model_ } Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr &, const Key &, - const std::string &) { + const std::vector &) { MS_LOG(ERROR) << "Unsupported Feature."; return kMCFailed; } -Status Model::Build(const std::string &, ModelType, const std::shared_ptr &, const Key &, - const std::string &) { +Status Model::Build(const std::vector &, ModelType, const std::shared_ptr &, const Key &, + const std::vector &) { MS_LOG(ERROR) << "Unsupported Feature."; return kMCFailed; } diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc index 31fd54df0f3..58553b1f05f 100644 --- a/mindspore/ccsrc/cxx_api/serialization.cc +++ b/mindspore/ccsrc/cxx_api/serialization.cc @@ -340,8 +340,8 @@ Status Serialization::ExportModel(const Model &, ModelType, Buffer *) { return kMEFailed; } -Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool, - std::vector output_tensor_name) { +Status Serialization::ExportModel(const Model &, ModelType, const std::vector &, QuantizationType, bool, + const std::vector> &output_tensor_name) { MS_LOG(ERROR) << "Unsupported feature."; return kMEFailed; } diff --git a/mindspore/ccsrc/cxx_api/types.cc b/mindspore/ccsrc/cxx_api/types.cc index cc3a81f6675..d3dadceedb8 100644 --- a/mindspore/ccsrc/cxx_api/types.cc +++ b/mindspore/ccsrc/cxx_api/types.cc @@ -429,7 +429,7 @@ void MSTensor::SetShape(const std::vector &) { MS_LOG_EXCEPTION << "Inv void MSTensor::SetDataType(enum DataType) { MS_LOG_EXCEPTION << "Invalid implement."; } -void MSTensor::SetTensorName(const std::string &) { MS_LOG_EXCEPTION << "Invalid implement."; } +void MSTensor::SetTensorName(const std::vector &) { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetAllocator(std::shared_ptr) { MS_LOG_EXCEPTION << "Invalid implement."; } diff --git a/mindspore/lite/include/registry/node_parser_registry.h b/mindspore/lite/include/registry/node_parser_registry.h index e838fd4a100..a23fae1ea38 100644 --- a/mindspore/lite/include/registry/node_parser_registry.h +++ b/mindspore/lite/include/registry/node_parser_registry.h @@ -18,7 +18,9 @@ #define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_ #include +#include #include "include/registry/node_parser.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { namespace registry { @@ -30,8 +32,8 @@ class MS_API NodeParserRegistry { /// \param[in] fmk_type Define the framework. /// \param[in] node_type Define the type of the node to be resolved. /// \param[in] node_parser Define the NodeParser instance to parse the node. - NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type, - const converter::NodeParserPtr &node_parser); + inline NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type, + const converter::NodeParserPtr &node_parser); /// \brief Destructor ~NodeParserRegistry() = default; @@ -42,9 +44,22 @@ class MS_API NodeParserRegistry { /// \param[in] node_type Define the type of the node to be resolved. /// /// \return NodeParser instance. - static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type); + inline static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type); + + private: + NodeParserRegistry(converter::FmkType fmk_type, const std::vector &node_type, + const converter::NodeParserPtr &node_parser); + static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::vector &node_type); }; +NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type, + const converter::NodeParserPtr &node_parser) + : NodeParserRegistry(fmk_type, StringToChar(node_type), node_parser) {} + +converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) { + return GetNodeParser(fmk_type, StringToChar(node_type)); +} + /// \brief Defined registering macro to register NodeParser instance. /// /// \param[in] fmk_type Define the framework. diff --git a/mindspore/lite/include/registry/opencl_runtime_wrapper.h b/mindspore/lite/include/registry/opencl_runtime_wrapper.h index fdb00060e37..608e4274d0f 100644 --- a/mindspore/lite/include/registry/opencl_runtime_wrapper.h +++ b/mindspore/lite/include/registry/opencl_runtime_wrapper.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H #define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H + #include #include #include @@ -26,6 +27,7 @@ #include "CL/cl2.hpp" #include "include/api/allocator.h" #include "include/api/status.h" +#include "include/api/dual_abi_helper.h" namespace mindspore::registry::opencl { class OpenCLRuntimeWrapper { @@ -39,7 +41,7 @@ class OpenCLRuntimeWrapper { /// \param[in] source Define OpenCl source. /// /// \return Status as a status identification of loading code. - Status LoadSource(const std::string &program_name, const std::string &source); + inline Status LoadSource(const std::string &program_name, const std::string &source); /// \brief Building OpenCL code. /// @@ -49,8 +51,8 @@ class OpenCLRuntimeWrapper { /// \param[in] build_options_ext Define OpenCl kernel build options. /// /// \return Status as a status identification of build Kernel - Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name, - const std::vector &build_options_ext = {}); + inline Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name, + const std::vector &build_options_ext = {}); /// \brief Set kernel argument /// @@ -114,6 +116,23 @@ class OpenCLRuntimeWrapper { uint64_t GetMaxImage2DHeight(); uint64_t GetImagePitchAlignment(); + + private: + Status LoadSource(const std::vector &program_name, const std::vector &source); + + Status BuildKernel(cl::Kernel *kernel, const std::vector &program_name, const std::vector &kernel_name, + const std::vector> &build_options_ext); }; + +Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) { + return LoadSource(StringToChar(program_name), StringToChar(source)); +} + +Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name, + const std::string &kernel_name, + const std::vector &build_options_ext) { + return BuildKernel(kernel, StringToChar(program_name), StringToChar(kernel_name), + VectorStringToChar(build_options_ext)); +} } // namespace mindspore::registry::opencl #endif // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H diff --git a/mindspore/lite/include/registry/pass_registry.h b/mindspore/lite/include/registry/pass_registry.h index 9664361bbd7..df693fd3225 100644 --- a/mindspore/lite/include/registry/pass_registry.h +++ b/mindspore/lite/include/registry/pass_registry.h @@ -21,6 +21,7 @@ #include #include #include "include/lite_utils.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { namespace registry { @@ -36,13 +37,13 @@ class MS_API PassRegistry { /// /// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness. /// \param[in] pass Define pass instance. - PassRegistry(const std::string &pass_name, const PassBasePtr &pass); + inline PassRegistry(const std::string &pass_name, const PassBasePtr &pass); /// \brief Constructor of PassRegistry to assign which passes are required for external extension. /// /// \param[in] position Define the place where assigned passes will run. /// \param[in] names Define the names of the passes. - PassRegistry(PassPosition position, const std::vector &names); + inline PassRegistry(PassPosition position, const std::vector &names); /// \brief Destructor of PassRegistrar. ~PassRegistry() = default; @@ -52,16 +53,35 @@ class MS_API PassRegistry { /// \param[in] position Define the place where assigned passes will run. /// /// \return Passes' Name Vector. - static std::vector GetOuterScheduleTask(PassPosition position); + inline static std::vector GetOuterScheduleTask(PassPosition position); /// \brief Static method to obtain pass instance according to passes' name. /// /// \param[in] pass_names Define the name of pass. /// /// \return Pass Instance Vector. - static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name); + inline static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name); + + private: + PassRegistry(const std::vector &pass_name, const PassBasePtr &pass); + PassRegistry(PassPosition position, const std::vector> &names); + static std::vector> GetOuterScheduleTaskInner(PassPosition position); + static PassBasePtr GetPassFromStoreRoom(const std::vector &pass_name_char); }; +PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass) + : PassRegistry(StringToChar(pass_name), pass) {} + +PassRegistry::PassRegistry(PassPosition position, const std::vector &names) + : PassRegistry(position, VectorStringToChar(names)) {} + +std::vector PassRegistry::GetOuterScheduleTask(PassPosition position) { + return VectorCharToString(GetOuterScheduleTaskInner(position)); +} + +PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) { + return GetPassFromStoreRoom(StringToChar(pass_name)); +} /// \brief Defined registering macro to register Pass, which called by user directly. /// /// \param[in] name Define the name of the pass, a string which should guarantee uniqueness. diff --git a/mindspore/lite/include/registry/register_kernel.h b/mindspore/lite/include/registry/register_kernel.h index 753d0381590..ef2fec58def 100644 --- a/mindspore/lite/include/registry/register_kernel.h +++ b/mindspore/lite/include/registry/register_kernel.h @@ -38,6 +38,14 @@ struct KernelDesc { std::string provider; /**< user identification argument */ }; +/// \brief KernelDesc defined kernel's basic attribute. +struct KernelDescHelper { + DataType data_type; /**< kernel data type argument */ + int type; /**< op type argument */ + std::vector arch; /**< deviceType argument */ + std::vector provider; /**< user identification argument */ +}; + /// \brief CreateKernel Defined a functor to create a kernel. /// /// \param[in] inputs Define input tensors of kernel. @@ -62,8 +70,8 @@ class MS_API RegisterKernel { /// \param[in] creator Define a function pointer to create a kernel. /// /// \return Status as a status identification of registering. - static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, - CreateKernel creator); + inline static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, + CreateKernel creator); /// \brief Static method to register kernel which is corresponding to custom op. /// @@ -74,8 +82,8 @@ class MS_API RegisterKernel { /// \param[in] creator Define a function pointer to create a kernel. /// /// \return Status as a status identification of registering. - static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, - const std::string &type, CreateKernel creator); + inline static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, + const std::string &type, CreateKernel creator); /// \brief Static methon to get a kernel's create function. /// @@ -83,7 +91,14 @@ class MS_API RegisterKernel { /// \param[in] primitive Define the primitive of kernel generated by flatbuffers. /// /// \return Function pointer to create a kernel. - static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc); + inline static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc); + + private: + static Status RegKernel(const std::vector &arch, const std::vector &provider, DataType data_type, + int type, CreateKernel creator); + static Status RegCustomKernel(const std::vector &arch, const std::vector &provider, DataType data_type, + const std::vector &type, CreateKernel creator); + static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc); }; /// \brief KernelReg Defined registration class of kernel. @@ -117,6 +132,24 @@ class MS_API KernelReg { } }; +Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, + CreateKernel creator) { + return RegKernel(StringToChar(arch), StringToChar(provider), data_type, type, creator); +} + +Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, + const std::string &type, CreateKernel creator) { + return RegCustomKernel(StringToChar(arch), StringToChar(provider), data_type, StringToChar(type), creator); +} + +CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) { + if (desc == nullptr) { + return nullptr; + } + KernelDescHelper kernel_desc = {desc->data_type, desc->type, StringToChar(desc->arch), StringToChar(desc->provider)}; + return GetCreator(primitive, &kernel_desc); +} + /// \brief Defined registering macro to register ordinary op kernel, which called by user directly. /// /// \param[in] arch Define deviceType, such as CPU. diff --git a/mindspore/lite/include/registry/register_kernel_interface.h b/mindspore/lite/include/registry/register_kernel_interface.h index 4830a9c21e8..d8ddd260225 100644 --- a/mindspore/lite/include/registry/register_kernel_interface.h +++ b/mindspore/lite/include/registry/register_kernel_interface.h @@ -39,7 +39,8 @@ class MS_API RegisterKernelInterface { /// \param[in] creator Define the KernelInterface create function. /// /// \return Status as a status identification of registering. - static Status CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator); + inline static Status CustomReg(const std::string &provider, const std::string &op_type, + KernelInterfaceCreator creator); /// \brief Static method to register op whose primitive type is ordinary. /// @@ -48,7 +49,7 @@ class MS_API RegisterKernelInterface { /// \param[in] creator Define the KernelInterface create function. /// /// \return Status as a status identification of registering. - static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator); + inline static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator); /// \brief Static method to get registration of a certain op. /// @@ -56,7 +57,14 @@ class MS_API RegisterKernelInterface { /// \param[in] primitive Define the attributes of a certain op. /// /// \return Boolean value to represent registration of a certain op is existing or not. - static std::shared_ptr GetKernelInterface(const std::string &provider, + inline static std::shared_ptr GetKernelInterface(const std::string &provider, + const schema::Primitive *primitive); + + private: + static Status CustomReg(const std::vector &provider, const std::vector &op_type, + KernelInterfaceCreator creator); + static Status Reg(const std::vector &provider, int op_type, KernelInterfaceCreator creator); + static std::shared_ptr GetKernelInterface(const std::vector &provider, const schema::Primitive *primitive); }; @@ -82,6 +90,20 @@ class MS_API KernelInterfaceReg { } }; +Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type, + KernelInterfaceCreator creator) { + return CustomReg(StringToChar(provider), StringToChar(op_type), creator); +} + +Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { + return Reg(StringToChar(provider), op_type, creator); +} + +std::shared_ptr RegisterKernelInterface::GetKernelInterface( + const std::string &provider, const schema::Primitive *primitive) { + return GetKernelInterface(StringToChar(provider), primitive); +} + /// \brief Defined registering macro to register ordinary op, which called by user directly. /// /// \param[in] provider Define the identification of user. diff --git a/mindspore/lite/src/cxx_api/callback/ckpt_saver.cc b/mindspore/lite/src/cxx_api/callback/ckpt_saver.cc index 8b9743a25a4..e5c5885788a 100644 --- a/mindspore/lite/src/cxx_api/callback/ckpt_saver.cc +++ b/mindspore/lite/src/cxx_api/callback/ckpt_saver.cc @@ -23,8 +23,9 @@ #include "src/common/log_adapter.h" namespace mindspore { -CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix) { - callback_impl_ = new (std::nothrow) CallbackImpl(new (std::nothrow) lite::CkptSaver(save_every_n, filename_prefix)); +CkptSaver::CkptSaver(int save_every_n, const std::vector &filename_prefix) { + callback_impl_ = + new (std::nothrow) CallbackImpl(new (std::nothrow) lite::CkptSaver(save_every_n, CharToString(filename_prefix))); if (callback_impl_ == nullptr) { MS_LOG(ERROR) << "Callback implement new failed"; } diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc index 6dbd9e99860..c8bcdb7bbf0 100644 --- a/mindspore/lite/src/cxx_api/model/model.cc +++ b/mindspore/lite/src/cxx_api/model/model.cc @@ -29,7 +29,8 @@ namespace mindspore { std::mutex g_impl_init_lock; Status Model::Build(const void *model_data, size_t data_size, ModelType model_type, - const std::shared_ptr &model_context, const Key &dec_key, const std::string &dec_mode) { + const std::shared_ptr &model_context, const Key &dec_key, + const std::vector &dec_mode) { if (impl_ == nullptr) { std::unique_lock impl_lock(g_impl_init_lock); impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); @@ -46,8 +47,9 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty return kSuccess; } -Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context, - const Key &dec_key, const std::string &dec_mode) { +Status Model::Build(const std::vector &model_path, ModelType model_type, + const std::shared_ptr &model_context, const Key &dec_key, + const std::vector &dec_mode) { if (impl_ == nullptr) { std::unique_lock impl_lock(g_impl_init_lock); impl_ = std::shared_ptr(new (std::nothrow) ModelImpl()); @@ -57,7 +59,7 @@ Status Model::Build(const std::string &model_path, ModelType model_type, const s } } - Status ret = impl_->Build(model_path, model_type, model_context); + Status ret = impl_->Build(CharToString(model_path), model_type, model_context); if (ret != kSuccess) { return ret; } @@ -186,7 +188,7 @@ std::vector Model::GetOutputsByNodeName(const std::vector &node_ return impl_->GetOutputsByNodeName(CharToString(node_name)); } -Status Model::LoadConfig(const std::string &config_path) { +Status Model::LoadConfig(const std::vector &config_path) { std::unique_lock impl_lock(g_impl_init_lock); if (impl_ != nullptr) { MS_LOG(ERROR) << "impl_ illegal in LoadConfig."; @@ -199,7 +201,7 @@ Status Model::LoadConfig(const std::string &config_path) { return Status(kLiteFileError, "Fail to load config file."); } - auto ret = impl_->LoadConfig(config_path); + auto ret = impl_->LoadConfig(CharToString(config_path)); if (ret != kSuccess) { MS_LOG(ERROR) << "impl_ LoadConfig failed,"; return Status(kLiteFileError, "Invalid config file."); diff --git a/mindspore/lite/src/cxx_api/serialization.cc b/mindspore/lite/src/cxx_api/serialization.cc index cc90a464842..9f58356a60f 100644 --- a/mindspore/lite/src/cxx_api/serialization.cc +++ b/mindspore/lite/src/cxx_api/serialization.cc @@ -122,9 +122,9 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff return kMEFailed; } -Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file, +Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector &model_file, QuantizationType quantization_type, bool export_inference_only, - std::vector output_tensor_name) { + const std::vector> &output_tensor_name) { if (model.impl_ == nullptr) { MS_LOG(ERROR) << "Model implement is null."; return kLiteUninitializedObj; @@ -141,8 +141,9 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons MS_LOG(ERROR) << "Model session is nullptr."; return kLiteError; } - auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN, - A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, output_tensor_name); + auto ret = model.impl_->session_->Export( + CharToString(model_file), export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN, + A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, VectorCharToString(output_tensor_name)); return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError; } diff --git a/mindspore/lite/src/cxx_api/types.cc b/mindspore/lite/src/cxx_api/types.cc index 4344700de81..e6d678b7959 100644 --- a/mindspore/lite/src/cxx_api/types.cc +++ b/mindspore/lite/src/cxx_api/types.cc @@ -302,12 +302,12 @@ void MSTensor::SetDataType(enum DataType data_type) { impl_->SetDataType(data_type); } -void MSTensor::SetTensorName(const std::string &name) { +void MSTensor::SetTensorName(const std::vector &name) { if (impl_ == nullptr) { MS_LOG(ERROR) << "Invalid tensor implement."; return; } - impl_->SetName(name); + impl_->SetName(CharToString(name)); } void MSTensor::SetAllocator(std::shared_ptr allocator) { diff --git a/mindspore/lite/src/registry/register_kernel.cc b/mindspore/lite/src/registry/register_kernel.cc index a7d0528a69b..c7f8f28b4d7 100644 --- a/mindspore/lite/src/registry/register_kernel.cc +++ b/mindspore/lite/src/registry/register_kernel.cc @@ -22,29 +22,35 @@ namespace mindspore { namespace registry { -Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, - const std::string &type, CreateKernel creator) { +Status RegisterKernel::RegCustomKernel(const std::vector &arch, const std::vector &provider, + DataType data_type, const std::vector &type, CreateKernel creator) { #ifndef CUSTOM_KERNEL_REGISTRY_CLIP - return RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator); + return RegistryKernelImpl::GetInstance()->RegCustomKernel(CharToString(arch), CharToString(provider), data_type, + CharToString(type), creator); #else MS_LOG(ERROR) << unsupport_custom_kernel_register_log; return kLiteNotSupport; #endif } -Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type, - CreateKernel creator) { +Status RegisterKernel::RegKernel(const std::vector &arch, const std::vector &provider, DataType data_type, + int op_type, CreateKernel creator) { #ifndef CUSTOM_KERNEL_REGISTRY_CLIP - return RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator); + return RegistryKernelImpl::GetInstance()->RegKernel(CharToString(arch), CharToString(provider), data_type, op_type, + creator); #else MS_LOG(ERROR) << unsupport_custom_kernel_register_log; return kLiteNotSupport; #endif } -CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) { +CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc) { #ifndef CUSTOM_KERNEL_REGISTRY_CLIP - return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc); + if (desc == nullptr) { + return nullptr; + } + KernelDesc kernel_desc = {desc->data_type, desc->type, CharToString(desc->arch), CharToString(desc->provider)}; + return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, &kernel_desc); #else MS_LOG(ERROR) << unsupport_custom_kernel_register_log; return nullptr; diff --git a/mindspore/lite/src/registry/register_kernel_interface.cc b/mindspore/lite/src/registry/register_kernel_interface.cc index ef04adc6480..d24fb5b391a 100644 --- a/mindspore/lite/src/registry/register_kernel_interface.cc +++ b/mindspore/lite/src/registry/register_kernel_interface.cc @@ -22,19 +22,19 @@ namespace mindspore { namespace registry { -Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { +Status RegisterKernelInterface::Reg(const std::vector &provider, int op_type, KernelInterfaceCreator creator) { #ifndef CUSTOM_KERNEL_REGISTRY_CLIP - return KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator); + return KernelInterfaceRegistry::Instance()->Reg(CharToString(provider), op_type, creator); #else MS_LOG(ERROR) << unsupport_custom_kernel_register_log; return kLiteNotSupport; #endif } -Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type, +Status RegisterKernelInterface::CustomReg(const std::vector &provider, const std::vector &op_type, KernelInterfaceCreator creator) { #ifndef CUSTOM_KERNEL_REGISTRY_CLIP - return KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator); + return KernelInterfaceRegistry::Instance()->CustomReg(CharToString(provider), CharToString(op_type), creator); #else MS_LOG(ERROR) << unsupport_custom_kernel_register_log; return kLiteNotSupport; @@ -42,9 +42,9 @@ Status RegisterKernelInterface::CustomReg(const std::string &provider, const std } std::shared_ptr RegisterKernelInterface::GetKernelInterface( - const std::string &provider, const schema::Primitive *primitive) { + const std::vector &provider, const schema::Primitive *primitive) { #ifndef CUSTOM_KERNEL_REGISTRY_CLIP - return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive); + return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive); #else MS_LOG(ERROR) << unsupport_custom_kernel_register_log; return nullptr; diff --git a/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime_wrapper.cc b/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime_wrapper.cc index e9e370826bb..428fb870305 100644 --- a/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime_wrapper.cc +++ b/mindspore/lite/src/runtime/gpu/opencl/opencl_runtime_wrapper.cc @@ -31,24 +31,25 @@ using mindspore::kernel::CLErrorCode; namespace mindspore::registry::opencl { -Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) { +Status OpenCLRuntimeWrapper::LoadSource(const std::vector &program_name, const std::vector &source) { lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); - const std::string program_name_ext = "provider_" + program_name; - if (ocl_runtime->LoadSource(program_name_ext, source)) { + const std::string program_name_ext = "provider_" + CharToString(program_name); + if (ocl_runtime->LoadSource(program_name_ext, CharToString(source))) { return kSuccess; } else { return kLiteError; } } -Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name, - const std::string &kernel_name, - const std::vector &build_options_ext) { +Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::vector &program_name, + const std::vector &kernel_name, + const std::vector> &build_options_ext) { lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); - const std::string program_name_ext = "provider_" + program_name; - if (ocl_runtime->BuildKernel(*kernel, program_name_ext, kernel_name, build_options_ext, false) == RET_OK) { + const std::string program_name_ext = "provider_" + CharToString(program_name); + if (ocl_runtime->BuildKernel(*kernel, program_name_ext, CharToString(kernel_name), + VectorCharToString(build_options_ext), false) == RET_OK) { return kSuccess; } else { return kLiteError; diff --git a/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc b/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc index 7205fab789b..b248b7dd24b 100644 --- a/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc +++ b/mindspore/lite/test/ut/src/registry/registry_gpu_custom_op_test.cc @@ -285,7 +285,7 @@ class CustomAddKernel : public kernel::Kernel { return lite::RET_OK; } auto status = - registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_); + registry::RegisterKernelInterface::GetKernelInterface("", primitive_)->Infer(&inputs_, &outputs_, primitive_); if (status != kSuccess) { std::cerr << "infer failed." << std::endl; return lite::RET_ERROR; diff --git a/mindspore/lite/tools/converter/registry/node_parser_registry.cc b/mindspore/lite/tools/converter/registry/node_parser_registry.cc index b7d1db2f946..fb7b1077866 100644 --- a/mindspore/lite/tools/converter/registry/node_parser_registry.cc +++ b/mindspore/lite/tools/converter/registry/node_parser_registry.cc @@ -25,13 +25,15 @@ namespace { std::map> node_parser_room; std::mutex node_mutex; } // namespace -NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type, +NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::vector &node_type, const converter::NodeParserPtr &node_parser) { std::unique_lock lock(node_mutex); - node_parser_room[fmk_type][node_type] = node_parser; + std::string node_type_str = CharToString(node_type); + node_parser_room[fmk_type][node_type_str] = node_parser; } -converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) { +converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, + const std::vector &node_type) { auto iter_level1 = node_parser_room.find(fmk_type); if (iter_level1 == node_parser_room.end()) { return nullptr; @@ -39,7 +41,7 @@ converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fm if (node_type.empty()) { return nullptr; } - auto iter_level2 = iter_level1->second.find(node_type); + auto iter_level2 = iter_level1->second.find(CharToString(node_type)); if (iter_level2 == iter_level1->second.end()) { return nullptr; } diff --git a/mindspore/lite/tools/converter/registry/pass_registry.cc b/mindspore/lite/tools/converter/registry/pass_registry.cc index 6b2e90c79a5..1c0018b57f5 100644 --- a/mindspore/lite/tools/converter/registry/pass_registry.cc +++ b/mindspore/lite/tools/converter/registry/pass_registry.cc @@ -37,18 +37,21 @@ void RegPass(const std::string &pass_name, const PassBasePtr &pass) { } } // namespace -PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass) { RegPass(pass_name, pass); } +PassRegistry::PassRegistry(const std::vector &pass_name, const PassBasePtr &pass) { + RegPass(CharToString(pass_name), pass); +} -PassRegistry::PassRegistry(PassPosition position, const std::vector &names) { +PassRegistry::PassRegistry(PassPosition position, const std::vector> &names) { std::unique_lock lock(pass_mutex); - external_assigned_passes[position] = names; + external_assigned_passes[position] = VectorCharToString(names); } -std::vector PassRegistry::GetOuterScheduleTask(PassPosition position) { - return external_assigned_passes[position]; +std::vector> PassRegistry::GetOuterScheduleTaskInner(PassPosition position) { + return VectorStringToChar(external_assigned_passes[position]); } -PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) { +PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::vector &pass_name_char) { + std::string pass_name = CharToString(pass_name_char); return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name]; } } // namespace registry