external interface modify

This commit is contained in:
zhengyuanhua 2021-09-18 10:46:46 +08:00
parent 41788a88d2
commit 5440f9fbc1
22 changed files with 251 additions and 83 deletions

View File

@ -21,13 +21,21 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "include/api/callback/callback.h" #include "include/api/callback/callback.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore { namespace mindspore {
class CkptSaver: public TrainCallBack { class CkptSaver: public TrainCallBack {
public: public:
explicit CkptSaver(int save_every_n, const std::string &filename_prefix); inline CkptSaver(int save_every_n, const std::string &filename_prefix);
virtual ~CkptSaver(); virtual ~CkptSaver();
private:
CkptSaver(int save_every_n, const std::vector<char> &filename_prefix);
}; };
CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix)
: CkptSaver(save_every_n, StringToChar(filename_prefix)) {}
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H #endif // MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H

View File

@ -104,7 +104,7 @@ class MS_API Model {
/// \param[in] config_path config file path. /// \param[in] config_path config file path.
/// ///
/// \return Status. /// \return Status.
Status LoadConfig(const std::string &config_path); inline Status LoadConfig(const std::string &config_path);
/// \brief Obtains all input tensors of the model. /// \brief Obtains all input tensors of the model.
/// ///
@ -189,9 +189,9 @@ class MS_API Model {
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC. /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
/// ///
/// \return Status. /// \return Status.
Status Build(const void *model_data, size_t data_size, ModelType model_type, inline Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {}, const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm); const std::string &dec_mode = kDecModeAesGcm);
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite. /// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
/// ///
@ -203,9 +203,9 @@ class MS_API Model {
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC. /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
/// ///
/// \return Status. /// \return Status.
Status Build(const std::string &model_path, ModelType model_type, inline Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {}, const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm); const std::string &dec_mode = kDecModeAesGcm);
private: private:
friend class Serialization; friend class Serialization;
@ -214,6 +214,11 @@ class MS_API Model {
std::vector<std::vector<char>> GetOutputTensorNamesChar(); std::vector<std::vector<char>> GetOutputTensorNamesChar();
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name); MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name); std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
Status LoadConfig(const std::vector<char> &config_path);
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::vector<char> &dec_mode);
std::shared_ptr<ModelImpl> impl_; std::shared_ptr<ModelImpl> impl_;
}; };
@ -231,5 +236,19 @@ MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) { std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
return GetOutputsByNodeName(StringToChar(node_name)); return GetOutputsByNodeName(StringToChar(node_name));
} }
Status Model::LoadConfig(const std::string &config_path) {
return LoadConfig(StringToChar(config_path));
}
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));
}
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode));
}
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H #endif // MINDSPORE_INCLUDE_API_MODEL_H

View File

@ -68,9 +68,9 @@ class MS_API Serialization {
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm); const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model); static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data); static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file, inline static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true, QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
std::vector<std::string> output_tensor_name = {}); std::vector<std::string> output_tensor_name = {});
private: private:
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key, static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
@ -80,6 +80,9 @@ class MS_API Serialization {
const std::vector<char> &dec_mode); const std::vector<char> &dec_mode);
static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs, static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs,
const Key &dec_key, const std::vector<char> &dec_mode); const Key &dec_key, const std::vector<char> &dec_mode);
static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
QuantizationType quantization_type, bool export_inference_only,
const std::vector<std::vector<char>> &output_tensor_name);
}; };
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
@ -96,5 +99,13 @@ Status Serialization::Load(const std::vector<std::string> &files, ModelType mode
const Key &dec_key, const std::string &dec_mode) { const Key &dec_key, const std::string &dec_mode) {
return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode)); return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode));
} }
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
QuantizationType quantization_type, bool export_inference_only,
std::vector<std::string> output_tensor_name) {
return ExportModel(model, model_type, StringToChar(model_file), quantization_type, export_inference_only,
VectorStringToChar(output_tensor_name));
}
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H

View File

@ -221,7 +221,7 @@ class MS_API MSTensor {
/// \brief Set the name for the MSTensor. Only valid for Lite. /// \brief Set the name for the MSTensor. Only valid for Lite.
/// ///
/// \param[in] The name of the MSTensor. /// \param[in] The name of the MSTensor.
void SetTensorName(const std::string &name); inline void SetTensorName(const std::string &name);
/// \brief Set the Allocator for the MSTensor. Only valid for Lite. /// \brief Set the Allocator for the MSTensor. Only valid for Lite.
/// ///
@ -275,6 +275,7 @@ class MS_API MSTensor {
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data, MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len); size_t data_len);
std::vector<char> CharName() const; std::vector<char> CharName() const;
void SetTensorName(const std::vector<char> &name);
friend class ModelImpl; friend class ModelImpl;
std::shared_ptr<Impl> impl_; std::shared_ptr<Impl> impl_;
@ -333,6 +334,10 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto
std::string MSTensor::Name() const { return CharToString(CharName()); } std::string MSTensor::Name() const { return CharToString(CharName()); }
void MSTensor::SetTensorName(const std::string &name) {
return SetTensorName(StringToChar(name));
}
using Key = struct Key { using Key = struct Key {
const size_t max_key_len = 32; const size_t max_key_len = 32;
size_t len; size_t len;

View File

@ -66,13 +66,13 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
} }
Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &, Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
const std::string &) { const std::vector<char> &) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Unsupported Feature.";
return kMCFailed; return kMCFailed;
} }
Status Model::Build(const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &, Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &,
const std::string &) { const std::vector<char> &) {
MS_LOG(ERROR) << "Unsupported Feature."; MS_LOG(ERROR) << "Unsupported Feature.";
return kMCFailed; return kMCFailed;
} }

View File

@ -340,8 +340,8 @@ Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
return kMEFailed; return kMEFailed;
} }
Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool, Status Serialization::ExportModel(const Model &, ModelType, const std::vector<char> &, QuantizationType, bool,
std::vector<std::string> output_tensor_name) { const std::vector<std::vector<char>> &output_tensor_name) {
MS_LOG(ERROR) << "Unsupported feature."; MS_LOG(ERROR) << "Unsupported feature.";
return kMEFailed; return kMEFailed;
} }

View File

@ -429,7 +429,7 @@ void MSTensor::SetShape(const std::vector<int64_t> &) { MS_LOG_EXCEPTION << "Inv
void MSTensor::SetDataType(enum DataType) { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetDataType(enum DataType) { MS_LOG_EXCEPTION << "Invalid implement."; }
void MSTensor::SetTensorName(const std::string &) { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetTensorName(const std::vector<char> &) { MS_LOG_EXCEPTION << "Invalid implement."; }
void MSTensor::SetAllocator(std::shared_ptr<Allocator>) { MS_LOG_EXCEPTION << "Invalid implement."; } void MSTensor::SetAllocator(std::shared_ptr<Allocator>) { MS_LOG_EXCEPTION << "Invalid implement."; }

View File

@ -18,7 +18,9 @@
#define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_ #define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
#include <string> #include <string>
#include <vector>
#include "include/registry/node_parser.h" #include "include/registry/node_parser.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore { namespace mindspore {
namespace registry { namespace registry {
@ -30,8 +32,8 @@ class MS_API NodeParserRegistry {
/// \param[in] fmk_type Define the framework. /// \param[in] fmk_type Define the framework.
/// \param[in] node_type Define the type of the node to be resolved. /// \param[in] node_type Define the type of the node to be resolved.
/// \param[in] node_parser Define the NodeParser instance to parse the node. /// \param[in] node_parser Define the NodeParser instance to parse the node.
NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type, inline NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
const converter::NodeParserPtr &node_parser); const converter::NodeParserPtr &node_parser);
/// \brief Destructor /// \brief Destructor
~NodeParserRegistry() = default; ~NodeParserRegistry() = default;
@ -42,9 +44,22 @@ class MS_API NodeParserRegistry {
/// \param[in] node_type Define the type of the node to be resolved. /// \param[in] node_type Define the type of the node to be resolved.
/// ///
/// \return NodeParser instance. /// \return NodeParser instance.
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type); inline static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type);
private:
NodeParserRegistry(converter::FmkType fmk_type, const std::vector<char> &node_type,
const converter::NodeParserPtr &node_parser);
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::vector<char> &node_type);
}; };
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
const converter::NodeParserPtr &node_parser)
: NodeParserRegistry(fmk_type, StringToChar(node_type), node_parser) {}
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) {
return GetNodeParser(fmk_type, StringToChar(node_type));
}
/// \brief Defined registering macro to register NodeParser instance. /// \brief Defined registering macro to register NodeParser instance.
/// ///
/// \param[in] fmk_type Define the framework. /// \param[in] fmk_type Define the framework.

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H #ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
#define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H #define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
#include <vector> #include <vector>
#include <map> #include <map>
#include <memory> #include <memory>
@ -26,6 +27,7 @@
#include "CL/cl2.hpp" #include "CL/cl2.hpp"
#include "include/api/allocator.h" #include "include/api/allocator.h"
#include "include/api/status.h" #include "include/api/status.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore::registry::opencl { namespace mindspore::registry::opencl {
class OpenCLRuntimeWrapper { class OpenCLRuntimeWrapper {
@ -39,7 +41,7 @@ class OpenCLRuntimeWrapper {
/// \param[in] source Define OpenCl source. /// \param[in] source Define OpenCl source.
/// ///
/// \return Status as a status identification of loading code. /// \return Status as a status identification of loading code.
Status LoadSource(const std::string &program_name, const std::string &source); inline Status LoadSource(const std::string &program_name, const std::string &source);
/// \brief Building OpenCL code. /// \brief Building OpenCL code.
/// ///
@ -49,8 +51,8 @@ class OpenCLRuntimeWrapper {
/// \param[in] build_options_ext Define OpenCl kernel build options. /// \param[in] build_options_ext Define OpenCl kernel build options.
/// ///
/// \return Status as a status identification of build Kernel /// \return Status as a status identification of build Kernel
Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name, inline Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name,
const std::vector<std::string> &build_options_ext = {}); const std::vector<std::string> &build_options_ext = {});
/// \brief Set kernel argument /// \brief Set kernel argument
/// ///
@ -114,6 +116,23 @@ class OpenCLRuntimeWrapper {
uint64_t GetMaxImage2DHeight(); uint64_t GetMaxImage2DHeight();
uint64_t GetImagePitchAlignment(); uint64_t GetImagePitchAlignment();
private:
Status LoadSource(const std::vector<char> &program_name, const std::vector<char> &source);
Status BuildKernel(cl::Kernel *kernel, const std::vector<char> &program_name, const std::vector<char> &kernel_name,
const std::vector<std::vector<char>> &build_options_ext);
}; };
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
return LoadSource(StringToChar(program_name), StringToChar(source));
}
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name,
const std::string &kernel_name,
const std::vector<std::string> &build_options_ext) {
return BuildKernel(kernel, StringToChar(program_name), StringToChar(kernel_name),
VectorStringToChar(build_options_ext));
}
} // namespace mindspore::registry::opencl } // namespace mindspore::registry::opencl
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H #endif // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H

View File

@ -21,6 +21,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "include/lite_utils.h" #include "include/lite_utils.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore { namespace mindspore {
namespace registry { namespace registry {
@ -36,13 +37,13 @@ class MS_API PassRegistry {
/// ///
/// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness. /// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
/// \param[in] pass Define pass instance. /// \param[in] pass Define pass instance.
PassRegistry(const std::string &pass_name, const PassBasePtr &pass); inline PassRegistry(const std::string &pass_name, const PassBasePtr &pass);
/// \brief Constructor of PassRegistry to assign which passes are required for external extension. /// \brief Constructor of PassRegistry to assign which passes are required for external extension.
/// ///
/// \param[in] position Define the place where assigned passes will run. /// \param[in] position Define the place where assigned passes will run.
/// \param[in] names Define the names of the passes. /// \param[in] names Define the names of the passes.
PassRegistry(PassPosition position, const std::vector<std::string> &names); inline PassRegistry(PassPosition position, const std::vector<std::string> &names);
/// \brief Destructor of PassRegistrar. /// \brief Destructor of PassRegistrar.
~PassRegistry() = default; ~PassRegistry() = default;
@ -52,16 +53,35 @@ class MS_API PassRegistry {
/// \param[in] position Define the place where assigned passes will run. /// \param[in] position Define the place where assigned passes will run.
/// ///
/// \return Passes' Name Vector. /// \return Passes' Name Vector.
static std::vector<std::string> GetOuterScheduleTask(PassPosition position); inline static std::vector<std::string> GetOuterScheduleTask(PassPosition position);
/// \brief Static method to obtain pass instance according to passes' name. /// \brief Static method to obtain pass instance according to passes' name.
/// ///
/// \param[in] pass_names Define the name of pass. /// \param[in] pass_names Define the name of pass.
/// ///
/// \return Pass Instance Vector. /// \return Pass Instance Vector.
static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name); inline static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name);
private:
PassRegistry(const std::vector<char> &pass_name, const PassBasePtr &pass);
PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names);
static std::vector<std::vector<char>> GetOuterScheduleTaskInner(PassPosition position);
static PassBasePtr GetPassFromStoreRoom(const std::vector<char> &pass_name_char);
}; };
PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass)
: PassRegistry(StringToChar(pass_name), pass) {}
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names)
: PassRegistry(position, VectorStringToChar(names)) {}
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
return VectorCharToString(GetOuterScheduleTaskInner(position));
}
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) {
return GetPassFromStoreRoom(StringToChar(pass_name));
}
/// \brief Defined registering macro to register Pass, which called by user directly. /// \brief Defined registering macro to register Pass, which called by user directly.
/// ///
/// \param[in] name Define the name of the pass, a string which should guarantee uniqueness. /// \param[in] name Define the name of the pass, a string which should guarantee uniqueness.

View File

@ -38,6 +38,14 @@ struct KernelDesc {
std::string provider; /**< user identification argument */ std::string provider; /**< user identification argument */
}; };
/// \brief KernelDesc defined kernel's basic attribute.
struct KernelDescHelper {
DataType data_type; /**< kernel data type argument */
int type; /**< op type argument */
std::vector<char> arch; /**< deviceType argument */
std::vector<char> provider; /**< user identification argument */
};
/// \brief CreateKernel Defined a functor to create a kernel. /// \brief CreateKernel Defined a functor to create a kernel.
/// ///
/// \param[in] inputs Define input tensors of kernel. /// \param[in] inputs Define input tensors of kernel.
@ -62,8 +70,8 @@ class MS_API RegisterKernel {
/// \param[in] creator Define a function pointer to create a kernel. /// \param[in] creator Define a function pointer to create a kernel.
/// ///
/// \return Status as a status identification of registering. /// \return Status as a status identification of registering.
static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type, inline static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator); CreateKernel creator);
/// \brief Static method to register kernel which is corresponding to custom op. /// \brief Static method to register kernel which is corresponding to custom op.
/// ///
@ -74,8 +82,8 @@ class MS_API RegisterKernel {
/// \param[in] creator Define a function pointer to create a kernel. /// \param[in] creator Define a function pointer to create a kernel.
/// ///
/// \return Status as a status identification of registering. /// \return Status as a status identification of registering.
static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, inline static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator); const std::string &type, CreateKernel creator);
/// \brief Static methon to get a kernel's create function. /// \brief Static methon to get a kernel's create function.
/// ///
@ -83,7 +91,14 @@ class MS_API RegisterKernel {
/// \param[in] primitive Define the primitive of kernel generated by flatbuffers. /// \param[in] primitive Define the primitive of kernel generated by flatbuffers.
/// ///
/// \return Function pointer to create a kernel. /// \return Function pointer to create a kernel.
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc); inline static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);
private:
static Status RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
int type, CreateKernel creator);
static Status RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
const std::vector<char> &type, CreateKernel creator);
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc);
}; };
/// \brief KernelReg Defined registration class of kernel. /// \brief KernelReg Defined registration class of kernel.
@ -117,6 +132,24 @@ class MS_API KernelReg {
} }
}; };
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator) {
return RegKernel(StringToChar(arch), StringToChar(provider), data_type, type, creator);
}
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
return RegCustomKernel(StringToChar(arch), StringToChar(provider), data_type, StringToChar(type), creator);
}
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
if (desc == nullptr) {
return nullptr;
}
KernelDescHelper kernel_desc = {desc->data_type, desc->type, StringToChar(desc->arch), StringToChar(desc->provider)};
return GetCreator(primitive, &kernel_desc);
}
/// \brief Defined registering macro to register ordinary op kernel, which called by user directly. /// \brief Defined registering macro to register ordinary op kernel, which called by user directly.
/// ///
/// \param[in] arch Define deviceType, such as CPU. /// \param[in] arch Define deviceType, such as CPU.

View File

@ -39,7 +39,8 @@ class MS_API RegisterKernelInterface {
/// \param[in] creator Define the KernelInterface create function. /// \param[in] creator Define the KernelInterface create function.
/// ///
/// \return Status as a status identification of registering. /// \return Status as a status identification of registering.
static Status CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator); inline static Status CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator);
/// \brief Static method to register op whose primitive type is ordinary. /// \brief Static method to register op whose primitive type is ordinary.
/// ///
@ -48,7 +49,7 @@ class MS_API RegisterKernelInterface {
/// \param[in] creator Define the KernelInterface create function. /// \param[in] creator Define the KernelInterface create function.
/// ///
/// \return Status as a status identification of registering. /// \return Status as a status identification of registering.
static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator); inline static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
/// \brief Static method to get registration of a certain op. /// \brief Static method to get registration of a certain op.
/// ///
@ -56,7 +57,14 @@ class MS_API RegisterKernelInterface {
/// \param[in] primitive Define the attributes of a certain op. /// \param[in] primitive Define the attributes of a certain op.
/// ///
/// \return Boolean value to represent registration of a certain op is existing or not. /// \return Boolean value to represent registration of a certain op is existing or not.
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider, inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
private:
static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
KernelInterfaceCreator creator);
static Status Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator);
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider,
const schema::Primitive *primitive); const schema::Primitive *primitive);
}; };
@ -82,6 +90,20 @@ class MS_API KernelInterfaceReg {
} }
}; };
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator) {
return CustomReg(StringToChar(provider), StringToChar(op_type), creator);
}
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
return Reg(StringToChar(provider), op_type, creator);
}
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
return GetKernelInterface(StringToChar(provider), primitive);
}
/// \brief Defined registering macro to register ordinary op, which called by user directly. /// \brief Defined registering macro to register ordinary op, which called by user directly.
/// ///
/// \param[in] provider Define the identification of user. /// \param[in] provider Define the identification of user.

View File

@ -23,8 +23,9 @@
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
namespace mindspore { namespace mindspore {
CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix) { CkptSaver::CkptSaver(int save_every_n, const std::vector<char> &filename_prefix) {
callback_impl_ = new (std::nothrow) CallbackImpl(new (std::nothrow) lite::CkptSaver(save_every_n, filename_prefix)); callback_impl_ =
new (std::nothrow) CallbackImpl(new (std::nothrow) lite::CkptSaver(save_every_n, CharToString(filename_prefix)));
if (callback_impl_ == nullptr) { if (callback_impl_ == nullptr) {
MS_LOG(ERROR) << "Callback implement new failed"; MS_LOG(ERROR) << "Callback implement new failed";
} }

View File

@ -29,7 +29,8 @@ namespace mindspore {
std::mutex g_impl_init_lock; std::mutex g_impl_init_lock;
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type, Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) { const std::shared_ptr<Context> &model_context, const Key &dec_key,
const std::vector<char> &dec_mode) {
if (impl_ == nullptr) { if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock); std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
@ -46,8 +47,9 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
return kSuccess; return kSuccess;
} }
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context, Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
const Key &dec_key, const std::string &dec_mode) { const std::shared_ptr<Context> &model_context, const Key &dec_key,
const std::vector<char> &dec_mode) {
if (impl_ == nullptr) { if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock); std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl()); impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
@ -57,7 +59,7 @@ Status Model::Build(const std::string &model_path, ModelType model_type, const s
} }
} }
Status ret = impl_->Build(model_path, model_type, model_context); Status ret = impl_->Build(CharToString(model_path), model_type, model_context);
if (ret != kSuccess) { if (ret != kSuccess) {
return ret; return ret;
} }
@ -186,7 +188,7 @@ std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_
return impl_->GetOutputsByNodeName(CharToString(node_name)); return impl_->GetOutputsByNodeName(CharToString(node_name));
} }
Status Model::LoadConfig(const std::string &config_path) { Status Model::LoadConfig(const std::vector<char> &config_path) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock); std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
if (impl_ != nullptr) { if (impl_ != nullptr) {
MS_LOG(ERROR) << "impl_ illegal in LoadConfig."; MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
@ -199,7 +201,7 @@ Status Model::LoadConfig(const std::string &config_path) {
return Status(kLiteFileError, "Fail to load config file."); return Status(kLiteFileError, "Fail to load config file.");
} }
auto ret = impl_->LoadConfig(config_path); auto ret = impl_->LoadConfig(CharToString(config_path));
if (ret != kSuccess) { if (ret != kSuccess) {
MS_LOG(ERROR) << "impl_ LoadConfig failed,"; MS_LOG(ERROR) << "impl_ LoadConfig failed,";
return Status(kLiteFileError, "Invalid config file."); return Status(kLiteFileError, "Invalid config file.");

View File

@ -122,9 +122,9 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
return kMEFailed; return kMEFailed;
} }
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file, Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
QuantizationType quantization_type, bool export_inference_only, QuantizationType quantization_type, bool export_inference_only,
std::vector<std::string> output_tensor_name) { const std::vector<std::vector<char>> &output_tensor_name) {
if (model.impl_ == nullptr) { if (model.impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null."; MS_LOG(ERROR) << "Model implement is null.";
return kLiteUninitializedObj; return kLiteUninitializedObj;
@ -141,8 +141,9 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
MS_LOG(ERROR) << "Model session is nullptr."; MS_LOG(ERROR) << "Model session is nullptr.";
return kLiteError; return kLiteError;
} }
auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN, auto ret = model.impl_->session_->Export(
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, output_tensor_name); CharToString(model_file), export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, VectorCharToString(output_tensor_name));
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError; return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
} }

View File

@ -302,12 +302,12 @@ void MSTensor::SetDataType(enum DataType data_type) {
impl_->SetDataType(data_type); impl_->SetDataType(data_type);
} }
void MSTensor::SetTensorName(const std::string &name) { void MSTensor::SetTensorName(const std::vector<char> &name) {
if (impl_ == nullptr) { if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor implement."; MS_LOG(ERROR) << "Invalid tensor implement.";
return; return;
} }
impl_->SetName(name); impl_->SetName(CharToString(name));
} }
void MSTensor::SetAllocator(std::shared_ptr<Allocator> allocator) { void MSTensor::SetAllocator(std::shared_ptr<Allocator> allocator) {

View File

@ -22,29 +22,35 @@
namespace mindspore { namespace mindspore {
namespace registry { namespace registry {
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type, Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider,
const std::string &type, CreateKernel creator) { DataType data_type, const std::vector<char> &type, CreateKernel creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator); return RegistryKernelImpl::GetInstance()->RegCustomKernel(CharToString(arch), CharToString(provider), data_type,
CharToString(type), creator);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return kLiteNotSupport; return kLiteNotSupport;
#endif #endif
} }
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type, Status RegisterKernel::RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
CreateKernel creator) { int op_type, CreateKernel creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator); return RegistryKernelImpl::GetInstance()->RegKernel(CharToString(arch), CharToString(provider), data_type, op_type,
creator);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return kLiteNotSupport; return kLiteNotSupport;
#endif #endif
} }
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) { CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc); if (desc == nullptr) {
return nullptr;
}
KernelDesc kernel_desc = {desc->data_type, desc->type, CharToString(desc->arch), CharToString(desc->provider)};
return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, &kernel_desc);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return nullptr; return nullptr;

View File

@ -22,19 +22,19 @@
namespace mindspore { namespace mindspore {
namespace registry { namespace registry {
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) { Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator); return KernelInterfaceRegistry::Instance()->Reg(CharToString(provider), op_type, creator);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return kLiteNotSupport; return kLiteNotSupport;
#endif #endif
} }
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type, Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
KernelInterfaceCreator creator) { KernelInterfaceCreator creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator); return KernelInterfaceRegistry::Instance()->CustomReg(CharToString(provider), CharToString(op_type), creator);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return kLiteNotSupport; return kLiteNotSupport;
@ -42,9 +42,9 @@ Status RegisterKernelInterface::CustomReg(const std::string &provider, const std
} }
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface( std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) { const std::vector<char> &provider, const schema::Primitive *primitive) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive); return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive);
#else #else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log; MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return nullptr; return nullptr;

View File

@ -31,24 +31,25 @@
using mindspore::kernel::CLErrorCode; using mindspore::kernel::CLErrorCode;
namespace mindspore::registry::opencl { namespace mindspore::registry::opencl {
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) { Status OpenCLRuntimeWrapper::LoadSource(const std::vector<char> &program_name, const std::vector<char> &source) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
const std::string program_name_ext = "provider_" + program_name; const std::string program_name_ext = "provider_" + CharToString(program_name);
if (ocl_runtime->LoadSource(program_name_ext, source)) { if (ocl_runtime->LoadSource(program_name_ext, CharToString(source))) {
return kSuccess; return kSuccess;
} else { } else {
return kLiteError; return kLiteError;
} }
} }
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name, Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::vector<char> &program_name,
const std::string &kernel_name, const std::vector<char> &kernel_name,
const std::vector<std::string> &build_options_ext) { const std::vector<std::vector<char>> &build_options_ext) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap; lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance(); lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
const std::string program_name_ext = "provider_" + program_name; const std::string program_name_ext = "provider_" + CharToString(program_name);
if (ocl_runtime->BuildKernel(*kernel, program_name_ext, kernel_name, build_options_ext, false) == RET_OK) { if (ocl_runtime->BuildKernel(*kernel, program_name_ext, CharToString(kernel_name),
VectorCharToString(build_options_ext), false) == RET_OK) {
return kSuccess; return kSuccess;
} else { } else {
return kLiteError; return kLiteError;

View File

@ -285,7 +285,7 @@ class CustomAddKernel : public kernel::Kernel {
return lite::RET_OK; return lite::RET_OK;
} }
auto status = auto status =
registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_); registry::RegisterKernelInterface::GetKernelInterface("", primitive_)->Infer(&inputs_, &outputs_, primitive_);
if (status != kSuccess) { if (status != kSuccess) {
std::cerr << "infer failed." << std::endl; std::cerr << "infer failed." << std::endl;
return lite::RET_ERROR; return lite::RET_ERROR;

View File

@ -25,13 +25,15 @@ namespace {
std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room; std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room;
std::mutex node_mutex; std::mutex node_mutex;
} // namespace } // namespace
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type, NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::vector<char> &node_type,
const converter::NodeParserPtr &node_parser) { const converter::NodeParserPtr &node_parser) {
std::unique_lock<std::mutex> lock(node_mutex); std::unique_lock<std::mutex> lock(node_mutex);
node_parser_room[fmk_type][node_type] = node_parser; std::string node_type_str = CharToString(node_type);
node_parser_room[fmk_type][node_type_str] = node_parser;
} }
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) { converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type,
const std::vector<char> &node_type) {
auto iter_level1 = node_parser_room.find(fmk_type); auto iter_level1 = node_parser_room.find(fmk_type);
if (iter_level1 == node_parser_room.end()) { if (iter_level1 == node_parser_room.end()) {
return nullptr; return nullptr;
@ -39,7 +41,7 @@ converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fm
if (node_type.empty()) { if (node_type.empty()) {
return nullptr; return nullptr;
} }
auto iter_level2 = iter_level1->second.find(node_type); auto iter_level2 = iter_level1->second.find(CharToString(node_type));
if (iter_level2 == iter_level1->second.end()) { if (iter_level2 == iter_level1->second.end()) {
return nullptr; return nullptr;
} }

View File

@ -37,18 +37,21 @@ void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
} }
} // namespace } // namespace
PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass) { RegPass(pass_name, pass); } PassRegistry::PassRegistry(const std::vector<char> &pass_name, const PassBasePtr &pass) {
RegPass(CharToString(pass_name), pass);
}
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names) { PassRegistry::PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names) {
std::unique_lock<std::mutex> lock(pass_mutex); std::unique_lock<std::mutex> lock(pass_mutex);
external_assigned_passes[position] = names; external_assigned_passes[position] = VectorCharToString(names);
} }
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) { std::vector<std::vector<char>> PassRegistry::GetOuterScheduleTaskInner(PassPosition position) {
return external_assigned_passes[position]; return VectorStringToChar(external_assigned_passes[position]);
} }
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) { PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::vector<char> &pass_name_char) {
std::string pass_name = CharToString(pass_name_char);
return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name]; return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name];
} }
} // namespace registry } // namespace registry