external interface modify

This commit is contained in:
zhengyuanhua 2021-09-18 10:46:46 +08:00
parent 41788a88d2
commit 5440f9fbc1
22 changed files with 251 additions and 83 deletions

View File

@ -21,13 +21,21 @@
#include <vector>
#include <memory>
#include "include/api/callback/callback.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
class CkptSaver: public TrainCallBack {
public:
explicit CkptSaver(int save_every_n, const std::string &filename_prefix);
inline CkptSaver(int save_every_n, const std::string &filename_prefix);
virtual ~CkptSaver();
private:
CkptSaver(int save_every_n, const std::vector<char> &filename_prefix);
};
CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix)
: CkptSaver(save_every_n, StringToChar(filename_prefix)) {}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H

View File

@ -104,7 +104,7 @@ class MS_API Model {
/// \param[in] config_path config file path.
///
/// \return Status.
Status LoadConfig(const std::string &config_path);
inline Status LoadConfig(const std::string &config_path);
/// \brief Obtains all input tensors of the model.
///
@ -189,9 +189,9 @@ class MS_API Model {
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
///
/// \return Status.
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
inline Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
///
@ -203,9 +203,9 @@ class MS_API Model {
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
///
/// \return Status.
Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
inline Status Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
private:
friend class Serialization;
@ -214,6 +214,11 @@ class MS_API Model {
std::vector<std::vector<char>> GetOutputTensorNamesChar();
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
Status LoadConfig(const std::vector<char> &config_path);
Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::vector<char> &dec_mode);
std::shared_ptr<ModelImpl> impl_;
};
@ -231,5 +236,19 @@ MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
return GetOutputsByNodeName(StringToChar(node_name));
}
Status Model::LoadConfig(const std::string &config_path) {
return LoadConfig(StringToChar(config_path));
}
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));
}
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H

View File

@ -68,9 +68,9 @@ class MS_API Serialization {
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
std::vector<std::string> output_tensor_name = {});
inline static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
std::vector<std::string> output_tensor_name = {});
private:
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
@ -80,6 +80,9 @@ class MS_API Serialization {
const std::vector<char> &dec_mode);
static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs,
const Key &dec_key, const std::vector<char> &dec_mode);
static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
QuantizationType quantization_type, bool export_inference_only,
const std::vector<std::vector<char>> &output_tensor_name);
};
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
@ -96,5 +99,13 @@ Status Serialization::Load(const std::vector<std::string> &files, ModelType mode
const Key &dec_key, const std::string &dec_mode) {
return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode));
}
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
QuantizationType quantization_type, bool export_inference_only,
std::vector<std::string> output_tensor_name) {
return ExportModel(model, model_type, StringToChar(model_file), quantization_type, export_inference_only,
VectorStringToChar(output_tensor_name));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H

View File

@ -221,7 +221,7 @@ class MS_API MSTensor {
/// \brief Set the name for the MSTensor. Only valid for Lite.
///
/// \param[in] The name of the MSTensor.
void SetTensorName(const std::string &name);
inline void SetTensorName(const std::string &name);
/// \brief Set the Allocator for the MSTensor. Only valid for Lite.
///
@ -275,6 +275,7 @@ class MS_API MSTensor {
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
std::vector<char> CharName() const;
void SetTensorName(const std::vector<char> &name);
friend class ModelImpl;
std::shared_ptr<Impl> impl_;
@ -333,6 +334,10 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto
std::string MSTensor::Name() const { return CharToString(CharName()); }
void MSTensor::SetTensorName(const std::string &name) {
return SetTensorName(StringToChar(name));
}
using Key = struct Key {
const size_t max_key_len = 32;
size_t len;

View File

@ -66,13 +66,13 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
}
Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
const std::string &) {
const std::vector<char> &) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kMCFailed;
}
Status Model::Build(const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &,
const std::string &) {
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &,
const std::vector<char> &) {
MS_LOG(ERROR) << "Unsupported Feature.";
return kMCFailed;
}

View File

@ -340,8 +340,8 @@ Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
return kMEFailed;
}
Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool,
std::vector<std::string> output_tensor_name) {
Status Serialization::ExportModel(const Model &, ModelType, const std::vector<char> &, QuantizationType, bool,
const std::vector<std::vector<char>> &output_tensor_name) {
MS_LOG(ERROR) << "Unsupported feature.";
return kMEFailed;
}

View File

@ -429,7 +429,7 @@ void MSTensor::SetShape(const std::vector<int64_t> &) { MS_LOG_EXCEPTION << "Inv
void MSTensor::SetDataType(enum DataType) { MS_LOG_EXCEPTION << "Invalid implement."; }
void MSTensor::SetTensorName(const std::string &) { MS_LOG_EXCEPTION << "Invalid implement."; }
void MSTensor::SetTensorName(const std::vector<char> &) { MS_LOG_EXCEPTION << "Invalid implement."; }
void MSTensor::SetAllocator(std::shared_ptr<Allocator>) { MS_LOG_EXCEPTION << "Invalid implement."; }

View File

@ -18,7 +18,9 @@
#define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
#include <string>
#include <vector>
#include "include/registry/node_parser.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
namespace registry {
@ -30,8 +32,8 @@ class MS_API NodeParserRegistry {
/// \param[in] fmk_type Define the framework.
/// \param[in] node_type Define the type of the node to be resolved.
/// \param[in] node_parser Define the NodeParser instance to parse the node.
NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
const converter::NodeParserPtr &node_parser);
inline NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
const converter::NodeParserPtr &node_parser);
/// \brief Destructor
~NodeParserRegistry() = default;
@ -42,9 +44,22 @@ class MS_API NodeParserRegistry {
/// \param[in] node_type Define the type of the node to be resolved.
///
/// \return NodeParser instance.
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type);
inline static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type);
private:
NodeParserRegistry(converter::FmkType fmk_type, const std::vector<char> &node_type,
const converter::NodeParserPtr &node_parser);
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::vector<char> &node_type);
};
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
const converter::NodeParserPtr &node_parser)
: NodeParserRegistry(fmk_type, StringToChar(node_type), node_parser) {}
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) {
return GetNodeParser(fmk_type, StringToChar(node_type));
}
/// \brief Defined registering macro to register NodeParser instance.
///
/// \param[in] fmk_type Define the framework.

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
#define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
#include <vector>
#include <map>
#include <memory>
@ -26,6 +27,7 @@
#include "CL/cl2.hpp"
#include "include/api/allocator.h"
#include "include/api/status.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore::registry::opencl {
class OpenCLRuntimeWrapper {
@ -39,7 +41,7 @@ class OpenCLRuntimeWrapper {
/// \param[in] source Define OpenCl source.
///
/// \return Status as a status identification of loading code.
Status LoadSource(const std::string &program_name, const std::string &source);
inline Status LoadSource(const std::string &program_name, const std::string &source);
/// \brief Building OpenCL code.
///
@ -49,8 +51,8 @@ class OpenCLRuntimeWrapper {
/// \param[in] build_options_ext Define OpenCl kernel build options.
///
/// \return Status as a status identification of build Kernel
Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name,
const std::vector<std::string> &build_options_ext = {});
inline Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name,
const std::vector<std::string> &build_options_ext = {});
/// \brief Set kernel argument
///
@ -114,6 +116,23 @@ class OpenCLRuntimeWrapper {
uint64_t GetMaxImage2DHeight();
uint64_t GetImagePitchAlignment();
private:
Status LoadSource(const std::vector<char> &program_name, const std::vector<char> &source);
Status BuildKernel(cl::Kernel *kernel, const std::vector<char> &program_name, const std::vector<char> &kernel_name,
const std::vector<std::vector<char>> &build_options_ext);
};
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
return LoadSource(StringToChar(program_name), StringToChar(source));
}
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name,
const std::string &kernel_name,
const std::vector<std::string> &build_options_ext) {
return BuildKernel(kernel, StringToChar(program_name), StringToChar(kernel_name),
VectorStringToChar(build_options_ext));
}
} // namespace mindspore::registry::opencl
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H

View File

@ -21,6 +21,7 @@
#include <string>
#include <memory>
#include "include/lite_utils.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
namespace registry {
@ -36,13 +37,13 @@ class MS_API PassRegistry {
///
/// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
/// \param[in] pass Define pass instance.
PassRegistry(const std::string &pass_name, const PassBasePtr &pass);
inline PassRegistry(const std::string &pass_name, const PassBasePtr &pass);
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
///
/// \param[in] position Define the place where assigned passes will run.
/// \param[in] names Define the names of the passes.
PassRegistry(PassPosition position, const std::vector<std::string> &names);
inline PassRegistry(PassPosition position, const std::vector<std::string> &names);
/// \brief Destructor of PassRegistrar.
~PassRegistry() = default;
@ -52,16 +53,35 @@ class MS_API PassRegistry {
/// \param[in] position Define the place where assigned passes will run.
///
/// \return Passes' Name Vector.
static std::vector<std::string> GetOuterScheduleTask(PassPosition position);
inline static std::vector<std::string> GetOuterScheduleTask(PassPosition position);
/// \brief Static method to obtain pass instance according to passes' name.
///
/// \param[in] pass_names Define the name of pass.
///
/// \return Pass Instance Vector.
static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name);
inline static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name);
private:
PassRegistry(const std::vector<char> &pass_name, const PassBasePtr &pass);
PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names);
static std::vector<std::vector<char>> GetOuterScheduleTaskInner(PassPosition position);
static PassBasePtr GetPassFromStoreRoom(const std::vector<char> &pass_name_char);
};
PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass)
: PassRegistry(StringToChar(pass_name), pass) {}
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names)
: PassRegistry(position, VectorStringToChar(names)) {}
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
return VectorCharToString(GetOuterScheduleTaskInner(position));
}
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) {
return GetPassFromStoreRoom(StringToChar(pass_name));
}
/// \brief Defined registering macro to register Pass, which called by user directly.
///
/// \param[in] name Define the name of the pass, a string which should guarantee uniqueness.

View File

@ -38,6 +38,14 @@ struct KernelDesc {
std::string provider; /**< user identification argument */
};
/// \brief KernelDesc defined kernel's basic attribute.
struct KernelDescHelper {
DataType data_type; /**< kernel data type argument */
int type; /**< op type argument */
std::vector<char> arch; /**< deviceType argument */
std::vector<char> provider; /**< user identification argument */
};
/// \brief CreateKernel Defined a functor to create a kernel.
///
/// \param[in] inputs Define input tensors of kernel.
@ -62,8 +70,8 @@ class MS_API RegisterKernel {
/// \param[in] creator Define a function pointer to create a kernel.
///
/// \return Status as a status identification of registering.
static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator);
inline static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator);
/// \brief Static method to register kernel which is corresponding to custom op.
///
@ -74,8 +82,8 @@ class MS_API RegisterKernel {
/// \param[in] creator Define a function pointer to create a kernel.
///
/// \return Status as a status identification of registering.
static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator);
inline static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator);
/// \brief Static methon to get a kernel's create function.
///
@ -83,7 +91,14 @@ class MS_API RegisterKernel {
/// \param[in] primitive Define the primitive of kernel generated by flatbuffers.
///
/// \return Function pointer to create a kernel.
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);
inline static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);
private:
static Status RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
int type, CreateKernel creator);
static Status RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
const std::vector<char> &type, CreateKernel creator);
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc);
};
/// \brief KernelReg Defined registration class of kernel.
@ -117,6 +132,24 @@ class MS_API KernelReg {
}
};
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
CreateKernel creator) {
return RegKernel(StringToChar(arch), StringToChar(provider), data_type, type, creator);
}
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
return RegCustomKernel(StringToChar(arch), StringToChar(provider), data_type, StringToChar(type), creator);
}
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
if (desc == nullptr) {
return nullptr;
}
KernelDescHelper kernel_desc = {desc->data_type, desc->type, StringToChar(desc->arch), StringToChar(desc->provider)};
return GetCreator(primitive, &kernel_desc);
}
/// \brief Defined registering macro to register ordinary op kernel, which called by user directly.
///
/// \param[in] arch Define deviceType, such as CPU.

View File

@ -39,7 +39,8 @@ class MS_API RegisterKernelInterface {
/// \param[in] creator Define the KernelInterface create function.
///
/// \return Status as a status identification of registering.
static Status CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
inline static Status CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator);
/// \brief Static method to register op whose primitive type is ordinary.
///
@ -48,7 +49,7 @@ class MS_API RegisterKernelInterface {
/// \param[in] creator Define the KernelInterface create function.
///
/// \return Status as a status identification of registering.
static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
inline static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
/// \brief Static method to get registration of a certain op.
///
@ -56,7 +57,14 @@ class MS_API RegisterKernelInterface {
/// \param[in] primitive Define the attributes of a certain op.
///
/// \return Boolean value to represent registration of a certain op is existing or not.
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
const schema::Primitive *primitive);
private:
static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
KernelInterfaceCreator creator);
static Status Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator);
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider,
const schema::Primitive *primitive);
};
@ -82,6 +90,20 @@ class MS_API KernelInterfaceReg {
}
};
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
KernelInterfaceCreator creator) {
return CustomReg(StringToChar(provider), StringToChar(op_type), creator);
}
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
return Reg(StringToChar(provider), op_type, creator);
}
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
return GetKernelInterface(StringToChar(provider), primitive);
}
/// \brief Defined registering macro to register ordinary op, which called by user directly.
///
/// \param[in] provider Define the identification of user.

View File

@ -23,8 +23,9 @@
#include "src/common/log_adapter.h"
namespace mindspore {
CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix) {
callback_impl_ = new (std::nothrow) CallbackImpl(new (std::nothrow) lite::CkptSaver(save_every_n, filename_prefix));
CkptSaver::CkptSaver(int save_every_n, const std::vector<char> &filename_prefix) {
callback_impl_ =
new (std::nothrow) CallbackImpl(new (std::nothrow) lite::CkptSaver(save_every_n, CharToString(filename_prefix)));
if (callback_impl_ == nullptr) {
MS_LOG(ERROR) << "Callback implement new failed";
}

View File

@ -29,7 +29,8 @@ namespace mindspore {
std::mutex g_impl_init_lock;
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
const std::shared_ptr<Context> &model_context, const Key &dec_key,
const std::vector<char> &dec_mode) {
if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
@ -46,8 +47,9 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
return kSuccess;
}
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context, const Key &dec_key,
const std::vector<char> &dec_mode) {
if (impl_ == nullptr) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
@ -57,7 +59,7 @@ Status Model::Build(const std::string &model_path, ModelType model_type, const s
}
}
Status ret = impl_->Build(model_path, model_type, model_context);
Status ret = impl_->Build(CharToString(model_path), model_type, model_context);
if (ret != kSuccess) {
return ret;
}
@ -186,7 +188,7 @@ std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_
return impl_->GetOutputsByNodeName(CharToString(node_name));
}
Status Model::LoadConfig(const std::string &config_path) {
Status Model::LoadConfig(const std::vector<char> &config_path) {
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
if (impl_ != nullptr) {
MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
@ -199,7 +201,7 @@ Status Model::LoadConfig(const std::string &config_path) {
return Status(kLiteFileError, "Fail to load config file.");
}
auto ret = impl_->LoadConfig(config_path);
auto ret = impl_->LoadConfig(CharToString(config_path));
if (ret != kSuccess) {
MS_LOG(ERROR) << "impl_ LoadConfig failed,";
return Status(kLiteFileError, "Invalid config file.");

View File

@ -122,9 +122,9 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
return kMEFailed;
}
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
QuantizationType quantization_type, bool export_inference_only,
std::vector<std::string> output_tensor_name) {
const std::vector<std::vector<char>> &output_tensor_name) {
if (model.impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteUninitializedObj;
@ -141,8 +141,9 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
MS_LOG(ERROR) << "Model session is nullptr.";
return kLiteError;
}
auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, output_tensor_name);
auto ret = model.impl_->session_->Export(
CharToString(model_file), export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, VectorCharToString(output_tensor_name));
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
}

View File

@ -302,12 +302,12 @@ void MSTensor::SetDataType(enum DataType data_type) {
impl_->SetDataType(data_type);
}
void MSTensor::SetTensorName(const std::string &name) {
void MSTensor::SetTensorName(const std::vector<char> &name) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor implement.";
return;
}
impl_->SetName(name);
impl_->SetName(CharToString(name));
}
void MSTensor::SetAllocator(std::shared_ptr<Allocator> allocator) {

View File

@ -22,29 +22,35 @@
namespace mindspore {
namespace registry {
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
const std::string &type, CreateKernel creator) {
Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider,
DataType data_type, const std::vector<char> &type, CreateKernel creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
return RegistryKernelImpl::GetInstance()->RegCustomKernel(CharToString(arch), CharToString(provider), data_type,
CharToString(type), creator);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return kLiteNotSupport;
#endif
}
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
CreateKernel creator) {
Status RegisterKernel::RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
int op_type, CreateKernel creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
return RegistryKernelImpl::GetInstance()->RegKernel(CharToString(arch), CharToString(provider), data_type, op_type,
creator);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return kLiteNotSupport;
#endif
}
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
if (desc == nullptr) {
return nullptr;
}
KernelDesc kernel_desc = {desc->data_type, desc->type, CharToString(desc->arch), CharToString(desc->provider)};
return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, &kernel_desc);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return nullptr;

View File

@ -22,19 +22,19 @@
namespace mindspore {
namespace registry {
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
return KernelInterfaceRegistry::Instance()->Reg(CharToString(provider), op_type, creator);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return kLiteNotSupport;
#endif
}
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
KernelInterfaceCreator creator) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
return KernelInterfaceRegistry::Instance()->CustomReg(CharToString(provider), CharToString(op_type), creator);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return kLiteNotSupport;
@ -42,9 +42,9 @@ Status RegisterKernelInterface::CustomReg(const std::string &provider, const std
}
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
const std::string &provider, const schema::Primitive *primitive) {
const std::vector<char> &provider, const schema::Primitive *primitive) {
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive);
#else
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
return nullptr;

View File

@ -31,24 +31,25 @@
using mindspore::kernel::CLErrorCode;
namespace mindspore::registry::opencl {
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
Status OpenCLRuntimeWrapper::LoadSource(const std::vector<char> &program_name, const std::vector<char> &source) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
const std::string program_name_ext = "provider_" + program_name;
if (ocl_runtime->LoadSource(program_name_ext, source)) {
const std::string program_name_ext = "provider_" + CharToString(program_name);
if (ocl_runtime->LoadSource(program_name_ext, CharToString(source))) {
return kSuccess;
} else {
return kLiteError;
}
}
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name,
const std::string &kernel_name,
const std::vector<std::string> &build_options_ext) {
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::vector<char> &program_name,
const std::vector<char> &kernel_name,
const std::vector<std::vector<char>> &build_options_ext) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
const std::string program_name_ext = "provider_" + program_name;
if (ocl_runtime->BuildKernel(*kernel, program_name_ext, kernel_name, build_options_ext, false) == RET_OK) {
const std::string program_name_ext = "provider_" + CharToString(program_name);
if (ocl_runtime->BuildKernel(*kernel, program_name_ext, CharToString(kernel_name),
VectorCharToString(build_options_ext), false) == RET_OK) {
return kSuccess;
} else {
return kLiteError;

View File

@ -285,7 +285,7 @@ class CustomAddKernel : public kernel::Kernel {
return lite::RET_OK;
}
auto status =
registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_);
registry::RegisterKernelInterface::GetKernelInterface("", primitive_)->Infer(&inputs_, &outputs_, primitive_);
if (status != kSuccess) {
std::cerr << "infer failed." << std::endl;
return lite::RET_ERROR;

View File

@ -25,13 +25,15 @@ namespace {
std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room;
std::mutex node_mutex;
} // namespace
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::vector<char> &node_type,
const converter::NodeParserPtr &node_parser) {
std::unique_lock<std::mutex> lock(node_mutex);
node_parser_room[fmk_type][node_type] = node_parser;
std::string node_type_str = CharToString(node_type);
node_parser_room[fmk_type][node_type_str] = node_parser;
}
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) {
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type,
const std::vector<char> &node_type) {
auto iter_level1 = node_parser_room.find(fmk_type);
if (iter_level1 == node_parser_room.end()) {
return nullptr;
@ -39,7 +41,7 @@ converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fm
if (node_type.empty()) {
return nullptr;
}
auto iter_level2 = iter_level1->second.find(node_type);
auto iter_level2 = iter_level1->second.find(CharToString(node_type));
if (iter_level2 == iter_level1->second.end()) {
return nullptr;
}

View File

@ -37,18 +37,21 @@ void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
}
} // namespace
PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass) { RegPass(pass_name, pass); }
PassRegistry::PassRegistry(const std::vector<char> &pass_name, const PassBasePtr &pass) {
RegPass(CharToString(pass_name), pass);
}
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names) {
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names) {
std::unique_lock<std::mutex> lock(pass_mutex);
external_assigned_passes[position] = names;
external_assigned_passes[position] = VectorCharToString(names);
}
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
return external_assigned_passes[position];
std::vector<std::vector<char>> PassRegistry::GetOuterScheduleTaskInner(PassPosition position) {
return VectorStringToChar(external_assigned_passes[position]);
}
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) {
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::vector<char> &pass_name_char) {
std::string pass_name = CharToString(pass_name_char);
return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name];
}
} // namespace registry