forked from mindspore-Ecosystem/mindspore
external interface modify
This commit is contained in:
parent
41788a88d2
commit
5440f9fbc1
|
@ -21,13 +21,21 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "include/api/callback/callback.h"
|
#include "include/api/callback/callback.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
||||||
class CkptSaver: public TrainCallBack {
|
class CkptSaver: public TrainCallBack {
|
||||||
public:
|
public:
|
||||||
explicit CkptSaver(int save_every_n, const std::string &filename_prefix);
|
inline CkptSaver(int save_every_n, const std::string &filename_prefix);
|
||||||
virtual ~CkptSaver();
|
virtual ~CkptSaver();
|
||||||
|
|
||||||
|
private:
|
||||||
|
CkptSaver(int save_every_n, const std::vector<char> &filename_prefix);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix)
|
||||||
|
: CkptSaver(save_every_n, StringToChar(filename_prefix)) {}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H
|
#endif // MINDSPORE_INCLUDE_API_CALLBACK_CKPT_SAVER_H
|
||||||
|
|
|
@ -104,7 +104,7 @@ class MS_API Model {
|
||||||
/// \param[in] config_path config file path.
|
/// \param[in] config_path config file path.
|
||||||
///
|
///
|
||||||
/// \return Status.
|
/// \return Status.
|
||||||
Status LoadConfig(const std::string &config_path);
|
inline Status LoadConfig(const std::string &config_path);
|
||||||
|
|
||||||
/// \brief Obtains all input tensors of the model.
|
/// \brief Obtains all input tensors of the model.
|
||||||
///
|
///
|
||||||
|
@ -189,9 +189,9 @@ class MS_API Model {
|
||||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||||
///
|
///
|
||||||
/// \return Status.
|
/// \return Status.
|
||||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
inline Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
||||||
const std::string &dec_mode = kDecModeAesGcm);
|
const std::string &dec_mode = kDecModeAesGcm);
|
||||||
|
|
||||||
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
|
/// \brief Load and build a model from model buffer so that it can run on a device. Only valid for Lite.
|
||||||
///
|
///
|
||||||
|
@ -203,9 +203,9 @@ class MS_API Model {
|
||||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||||
///
|
///
|
||||||
/// \return Status.
|
/// \return Status.
|
||||||
Status Build(const std::string &model_path, ModelType model_type,
|
inline Status Build(const std::string &model_path, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
const std::shared_ptr<Context> &model_context = nullptr, const Key &dec_key = {},
|
||||||
const std::string &dec_mode = kDecModeAesGcm);
|
const std::string &dec_mode = kDecModeAesGcm);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Serialization;
|
friend class Serialization;
|
||||||
|
@ -214,6 +214,11 @@ class MS_API Model {
|
||||||
std::vector<std::vector<char>> GetOutputTensorNamesChar();
|
std::vector<std::vector<char>> GetOutputTensorNamesChar();
|
||||||
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
|
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
|
||||||
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
|
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
|
||||||
|
Status LoadConfig(const std::vector<char> &config_path);
|
||||||
|
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
|
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::vector<char> &dec_mode);
|
||||||
|
Status Build(const std::vector<char> &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||||
|
const Key &dec_key, const std::vector<char> &dec_mode);
|
||||||
|
|
||||||
std::shared_ptr<ModelImpl> impl_;
|
std::shared_ptr<ModelImpl> impl_;
|
||||||
};
|
};
|
||||||
|
@ -231,5 +236,19 @@ MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
|
||||||
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
|
std::vector<MSTensor> Model::GetOutputsByNodeName(const std::string &node_name) {
|
||||||
return GetOutputsByNodeName(StringToChar(node_name));
|
return GetOutputsByNodeName(StringToChar(node_name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Model::LoadConfig(const std::string &config_path) {
|
||||||
|
return LoadConfig(StringToChar(config_path));
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
|
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
|
||||||
|
return Build(model_data, data_size, model_type, model_context, dec_key, StringToChar(dec_mode));
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
||||||
|
const Key &dec_key, const std::string &dec_mode) {
|
||||||
|
return Build(StringToChar(model_path), model_type, model_context, dec_key, StringToChar(dec_mode));
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
||||||
|
|
|
@ -68,9 +68,9 @@ class MS_API Serialization {
|
||||||
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
inline static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
||||||
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
|
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
|
||||||
std::vector<std::string> output_tensor_name = {});
|
std::vector<std::string> output_tensor_name = {});
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
|
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
|
||||||
|
@ -80,6 +80,9 @@ class MS_API Serialization {
|
||||||
const std::vector<char> &dec_mode);
|
const std::vector<char> &dec_mode);
|
||||||
static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs,
|
static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs,
|
||||||
const Key &dec_key, const std::vector<char> &dec_mode);
|
const Key &dec_key, const std::vector<char> &dec_mode);
|
||||||
|
static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
|
||||||
|
QuantizationType quantization_type, bool export_inference_only,
|
||||||
|
const std::vector<std::vector<char>> &output_tensor_name);
|
||||||
};
|
};
|
||||||
|
|
||||||
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
|
||||||
|
@ -96,5 +99,13 @@ Status Serialization::Load(const std::vector<std::string> &files, ModelType mode
|
||||||
const Key &dec_key, const std::string &dec_mode) {
|
const Key &dec_key, const std::string &dec_mode) {
|
||||||
return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode));
|
return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
||||||
|
QuantizationType quantization_type, bool export_inference_only,
|
||||||
|
std::vector<std::string> output_tensor_name) {
|
||||||
|
return ExportModel(model, model_type, StringToChar(model_file), quantization_type, export_inference_only,
|
||||||
|
VectorStringToChar(output_tensor_name));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||||
|
|
|
@ -221,7 +221,7 @@ class MS_API MSTensor {
|
||||||
/// \brief Set the name for the MSTensor. Only valid for Lite.
|
/// \brief Set the name for the MSTensor. Only valid for Lite.
|
||||||
///
|
///
|
||||||
/// \param[in] The name of the MSTensor.
|
/// \param[in] The name of the MSTensor.
|
||||||
void SetTensorName(const std::string &name);
|
inline void SetTensorName(const std::string &name);
|
||||||
|
|
||||||
/// \brief Set the Allocator for the MSTensor. Only valid for Lite.
|
/// \brief Set the Allocator for the MSTensor. Only valid for Lite.
|
||||||
///
|
///
|
||||||
|
@ -275,6 +275,7 @@ class MS_API MSTensor {
|
||||||
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||||
size_t data_len);
|
size_t data_len);
|
||||||
std::vector<char> CharName() const;
|
std::vector<char> CharName() const;
|
||||||
|
void SetTensorName(const std::vector<char> &name);
|
||||||
|
|
||||||
friend class ModelImpl;
|
friend class ModelImpl;
|
||||||
std::shared_ptr<Impl> impl_;
|
std::shared_ptr<Impl> impl_;
|
||||||
|
@ -333,6 +334,10 @@ MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vecto
|
||||||
|
|
||||||
std::string MSTensor::Name() const { return CharToString(CharName()); }
|
std::string MSTensor::Name() const { return CharToString(CharName()); }
|
||||||
|
|
||||||
|
void MSTensor::SetTensorName(const std::string &name) {
|
||||||
|
return SetTensorName(StringToChar(name));
|
||||||
|
}
|
||||||
|
|
||||||
using Key = struct Key {
|
using Key = struct Key {
|
||||||
const size_t max_key_len = 32;
|
const size_t max_key_len = 32;
|
||||||
size_t len;
|
size_t len;
|
||||||
|
|
|
@ -66,13 +66,13 @@ Status Model::Build(GraphCell graph_cell, const std::shared_ptr<Context> &model_
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
|
Status Model::Build(const void *, size_t, ModelType, const std::shared_ptr<Context> &, const Key &,
|
||||||
const std::string &) {
|
const std::vector<char> &) {
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
return kMCFailed;
|
return kMCFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Build(const std::string &, ModelType, const std::shared_ptr<Context> &, const Key &,
|
Status Model::Build(const std::vector<char> &, ModelType, const std::shared_ptr<Context> &, const Key &,
|
||||||
const std::string &) {
|
const std::vector<char> &) {
|
||||||
MS_LOG(ERROR) << "Unsupported Feature.";
|
MS_LOG(ERROR) << "Unsupported Feature.";
|
||||||
return kMCFailed;
|
return kMCFailed;
|
||||||
}
|
}
|
||||||
|
|
|
@ -340,8 +340,8 @@ Status Serialization::ExportModel(const Model &, ModelType, Buffer *) {
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::ExportModel(const Model &, ModelType, const std::string &, QuantizationType, bool,
|
Status Serialization::ExportModel(const Model &, ModelType, const std::vector<char> &, QuantizationType, bool,
|
||||||
std::vector<std::string> output_tensor_name) {
|
const std::vector<std::vector<char>> &output_tensor_name) {
|
||||||
MS_LOG(ERROR) << "Unsupported feature.";
|
MS_LOG(ERROR) << "Unsupported feature.";
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
}
|
}
|
||||||
|
|
|
@ -429,7 +429,7 @@ void MSTensor::SetShape(const std::vector<int64_t> &) { MS_LOG_EXCEPTION << "Inv
|
||||||
|
|
||||||
void MSTensor::SetDataType(enum DataType) { MS_LOG_EXCEPTION << "Invalid implement."; }
|
void MSTensor::SetDataType(enum DataType) { MS_LOG_EXCEPTION << "Invalid implement."; }
|
||||||
|
|
||||||
void MSTensor::SetTensorName(const std::string &) { MS_LOG_EXCEPTION << "Invalid implement."; }
|
void MSTensor::SetTensorName(const std::vector<char> &) { MS_LOG_EXCEPTION << "Invalid implement."; }
|
||||||
|
|
||||||
void MSTensor::SetAllocator(std::shared_ptr<Allocator>) { MS_LOG_EXCEPTION << "Invalid implement."; }
|
void MSTensor::SetAllocator(std::shared_ptr<Allocator>) { MS_LOG_EXCEPTION << "Invalid implement."; }
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,9 @@
|
||||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
|
#define MINDSPORE_LITE_INCLUDE_REGISTRY_NODE_PARSER_REGISTRY_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
#include "include/registry/node_parser.h"
|
#include "include/registry/node_parser.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace registry {
|
namespace registry {
|
||||||
|
@ -30,8 +32,8 @@ class MS_API NodeParserRegistry {
|
||||||
/// \param[in] fmk_type Define the framework.
|
/// \param[in] fmk_type Define the framework.
|
||||||
/// \param[in] node_type Define the type of the node to be resolved.
|
/// \param[in] node_type Define the type of the node to be resolved.
|
||||||
/// \param[in] node_parser Define the NodeParser instance to parse the node.
|
/// \param[in] node_parser Define the NodeParser instance to parse the node.
|
||||||
NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
|
inline NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
|
||||||
const converter::NodeParserPtr &node_parser);
|
const converter::NodeParserPtr &node_parser);
|
||||||
|
|
||||||
/// \brief Destructor
|
/// \brief Destructor
|
||||||
~NodeParserRegistry() = default;
|
~NodeParserRegistry() = default;
|
||||||
|
@ -42,9 +44,22 @@ class MS_API NodeParserRegistry {
|
||||||
/// \param[in] node_type Define the type of the node to be resolved.
|
/// \param[in] node_type Define the type of the node to be resolved.
|
||||||
///
|
///
|
||||||
/// \return NodeParser instance.
|
/// \return NodeParser instance.
|
||||||
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type);
|
inline static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::string &node_type);
|
||||||
|
|
||||||
|
private:
|
||||||
|
NodeParserRegistry(converter::FmkType fmk_type, const std::vector<char> &node_type,
|
||||||
|
const converter::NodeParserPtr &node_parser);
|
||||||
|
static converter::NodeParserPtr GetNodeParser(converter::FmkType fmk_type, const std::vector<char> &node_type);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
|
||||||
|
const converter::NodeParserPtr &node_parser)
|
||||||
|
: NodeParserRegistry(fmk_type, StringToChar(node_type), node_parser) {}
|
||||||
|
|
||||||
|
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) {
|
||||||
|
return GetNodeParser(fmk_type, StringToChar(node_type));
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Defined registering macro to register NodeParser instance.
|
/// \brief Defined registering macro to register NodeParser instance.
|
||||||
///
|
///
|
||||||
/// \param[in] fmk_type Define the framework.
|
/// \param[in] fmk_type Define the framework.
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
|
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
|
||||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
|
#define MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -26,6 +27,7 @@
|
||||||
#include "CL/cl2.hpp"
|
#include "CL/cl2.hpp"
|
||||||
#include "include/api/allocator.h"
|
#include "include/api/allocator.h"
|
||||||
#include "include/api/status.h"
|
#include "include/api/status.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore::registry::opencl {
|
namespace mindspore::registry::opencl {
|
||||||
class OpenCLRuntimeWrapper {
|
class OpenCLRuntimeWrapper {
|
||||||
|
@ -39,7 +41,7 @@ class OpenCLRuntimeWrapper {
|
||||||
/// \param[in] source Define OpenCl source.
|
/// \param[in] source Define OpenCl source.
|
||||||
///
|
///
|
||||||
/// \return Status as a status identification of loading code.
|
/// \return Status as a status identification of loading code.
|
||||||
Status LoadSource(const std::string &program_name, const std::string &source);
|
inline Status LoadSource(const std::string &program_name, const std::string &source);
|
||||||
|
|
||||||
/// \brief Building OpenCL code.
|
/// \brief Building OpenCL code.
|
||||||
///
|
///
|
||||||
|
@ -49,8 +51,8 @@ class OpenCLRuntimeWrapper {
|
||||||
/// \param[in] build_options_ext Define OpenCl kernel build options.
|
/// \param[in] build_options_ext Define OpenCl kernel build options.
|
||||||
///
|
///
|
||||||
/// \return Status as a status identification of build Kernel
|
/// \return Status as a status identification of build Kernel
|
||||||
Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name,
|
inline Status BuildKernel(cl::Kernel *kernel, const std::string &program_name, const std::string &kernel_name,
|
||||||
const std::vector<std::string> &build_options_ext = {});
|
const std::vector<std::string> &build_options_ext = {});
|
||||||
|
|
||||||
/// \brief Set kernel argument
|
/// \brief Set kernel argument
|
||||||
///
|
///
|
||||||
|
@ -114,6 +116,23 @@ class OpenCLRuntimeWrapper {
|
||||||
uint64_t GetMaxImage2DHeight();
|
uint64_t GetMaxImage2DHeight();
|
||||||
|
|
||||||
uint64_t GetImagePitchAlignment();
|
uint64_t GetImagePitchAlignment();
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status LoadSource(const std::vector<char> &program_name, const std::vector<char> &source);
|
||||||
|
|
||||||
|
Status BuildKernel(cl::Kernel *kernel, const std::vector<char> &program_name, const std::vector<char> &kernel_name,
|
||||||
|
const std::vector<std::vector<char>> &build_options_ext);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
|
||||||
|
return LoadSource(StringToChar(program_name), StringToChar(source));
|
||||||
|
}
|
||||||
|
|
||||||
|
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name,
|
||||||
|
const std::string &kernel_name,
|
||||||
|
const std::vector<std::string> &build_options_ext) {
|
||||||
|
return BuildKernel(kernel, StringToChar(program_name), StringToChar(kernel_name),
|
||||||
|
VectorStringToChar(build_options_ext));
|
||||||
|
}
|
||||||
} // namespace mindspore::registry::opencl
|
} // namespace mindspore::registry::opencl
|
||||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
|
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_OPENCL_RUNTIME_WRAPPER_H
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "include/lite_utils.h"
|
#include "include/lite_utils.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace registry {
|
namespace registry {
|
||||||
|
@ -36,13 +37,13 @@ class MS_API PassRegistry {
|
||||||
///
|
///
|
||||||
/// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
|
/// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
|
||||||
/// \param[in] pass Define pass instance.
|
/// \param[in] pass Define pass instance.
|
||||||
PassRegistry(const std::string &pass_name, const PassBasePtr &pass);
|
inline PassRegistry(const std::string &pass_name, const PassBasePtr &pass);
|
||||||
|
|
||||||
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
|
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
|
||||||
///
|
///
|
||||||
/// \param[in] position Define the place where assigned passes will run.
|
/// \param[in] position Define the place where assigned passes will run.
|
||||||
/// \param[in] names Define the names of the passes.
|
/// \param[in] names Define the names of the passes.
|
||||||
PassRegistry(PassPosition position, const std::vector<std::string> &names);
|
inline PassRegistry(PassPosition position, const std::vector<std::string> &names);
|
||||||
|
|
||||||
/// \brief Destructor of PassRegistrar.
|
/// \brief Destructor of PassRegistrar.
|
||||||
~PassRegistry() = default;
|
~PassRegistry() = default;
|
||||||
|
@ -52,16 +53,35 @@ class MS_API PassRegistry {
|
||||||
/// \param[in] position Define the place where assigned passes will run.
|
/// \param[in] position Define the place where assigned passes will run.
|
||||||
///
|
///
|
||||||
/// \return Passes' Name Vector.
|
/// \return Passes' Name Vector.
|
||||||
static std::vector<std::string> GetOuterScheduleTask(PassPosition position);
|
inline static std::vector<std::string> GetOuterScheduleTask(PassPosition position);
|
||||||
|
|
||||||
/// \brief Static method to obtain pass instance according to passes' name.
|
/// \brief Static method to obtain pass instance according to passes' name.
|
||||||
///
|
///
|
||||||
/// \param[in] pass_names Define the name of pass.
|
/// \param[in] pass_names Define the name of pass.
|
||||||
///
|
///
|
||||||
/// \return Pass Instance Vector.
|
/// \return Pass Instance Vector.
|
||||||
static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name);
|
inline static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name);
|
||||||
|
|
||||||
|
private:
|
||||||
|
PassRegistry(const std::vector<char> &pass_name, const PassBasePtr &pass);
|
||||||
|
PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names);
|
||||||
|
static std::vector<std::vector<char>> GetOuterScheduleTaskInner(PassPosition position);
|
||||||
|
static PassBasePtr GetPassFromStoreRoom(const std::vector<char> &pass_name_char);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass)
|
||||||
|
: PassRegistry(StringToChar(pass_name), pass) {}
|
||||||
|
|
||||||
|
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names)
|
||||||
|
: PassRegistry(position, VectorStringToChar(names)) {}
|
||||||
|
|
||||||
|
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
|
||||||
|
return VectorCharToString(GetOuterScheduleTaskInner(position));
|
||||||
|
}
|
||||||
|
|
||||||
|
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) {
|
||||||
|
return GetPassFromStoreRoom(StringToChar(pass_name));
|
||||||
|
}
|
||||||
/// \brief Defined registering macro to register Pass, which called by user directly.
|
/// \brief Defined registering macro to register Pass, which called by user directly.
|
||||||
///
|
///
|
||||||
/// \param[in] name Define the name of the pass, a string which should guarantee uniqueness.
|
/// \param[in] name Define the name of the pass, a string which should guarantee uniqueness.
|
||||||
|
|
|
@ -38,6 +38,14 @@ struct KernelDesc {
|
||||||
std::string provider; /**< user identification argument */
|
std::string provider; /**< user identification argument */
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// \brief KernelDesc defined kernel's basic attribute.
|
||||||
|
struct KernelDescHelper {
|
||||||
|
DataType data_type; /**< kernel data type argument */
|
||||||
|
int type; /**< op type argument */
|
||||||
|
std::vector<char> arch; /**< deviceType argument */
|
||||||
|
std::vector<char> provider; /**< user identification argument */
|
||||||
|
};
|
||||||
|
|
||||||
/// \brief CreateKernel Defined a functor to create a kernel.
|
/// \brief CreateKernel Defined a functor to create a kernel.
|
||||||
///
|
///
|
||||||
/// \param[in] inputs Define input tensors of kernel.
|
/// \param[in] inputs Define input tensors of kernel.
|
||||||
|
@ -62,8 +70,8 @@ class MS_API RegisterKernel {
|
||||||
/// \param[in] creator Define a function pointer to create a kernel.
|
/// \param[in] creator Define a function pointer to create a kernel.
|
||||||
///
|
///
|
||||||
/// \return Status as a status identification of registering.
|
/// \return Status as a status identification of registering.
|
||||||
static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
inline static Status RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||||
CreateKernel creator);
|
CreateKernel creator);
|
||||||
|
|
||||||
/// \brief Static method to register kernel which is corresponding to custom op.
|
/// \brief Static method to register kernel which is corresponding to custom op.
|
||||||
///
|
///
|
||||||
|
@ -74,8 +82,8 @@ class MS_API RegisterKernel {
|
||||||
/// \param[in] creator Define a function pointer to create a kernel.
|
/// \param[in] creator Define a function pointer to create a kernel.
|
||||||
///
|
///
|
||||||
/// \return Status as a status identification of registering.
|
/// \return Status as a status identification of registering.
|
||||||
static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
inline static Status RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||||
const std::string &type, CreateKernel creator);
|
const std::string &type, CreateKernel creator);
|
||||||
|
|
||||||
/// \brief Static methon to get a kernel's create function.
|
/// \brief Static methon to get a kernel's create function.
|
||||||
///
|
///
|
||||||
|
@ -83,7 +91,14 @@ class MS_API RegisterKernel {
|
||||||
/// \param[in] primitive Define the primitive of kernel generated by flatbuffers.
|
/// \param[in] primitive Define the primitive of kernel generated by flatbuffers.
|
||||||
///
|
///
|
||||||
/// \return Function pointer to create a kernel.
|
/// \return Function pointer to create a kernel.
|
||||||
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);
|
inline static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDesc *desc);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static Status RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
|
||||||
|
int type, CreateKernel creator);
|
||||||
|
static Status RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
|
||||||
|
const std::vector<char> &type, CreateKernel creator);
|
||||||
|
static CreateKernel GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc);
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief KernelReg Defined registration class of kernel.
|
/// \brief KernelReg Defined registration class of kernel.
|
||||||
|
@ -117,6 +132,24 @@ class MS_API KernelReg {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int type,
|
||||||
|
CreateKernel creator) {
|
||||||
|
return RegKernel(StringToChar(arch), StringToChar(provider), data_type, type, creator);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
||||||
|
const std::string &type, CreateKernel creator) {
|
||||||
|
return RegCustomKernel(StringToChar(arch), StringToChar(provider), data_type, StringToChar(type), creator);
|
||||||
|
}
|
||||||
|
|
||||||
|
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
|
||||||
|
if (desc == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
KernelDescHelper kernel_desc = {desc->data_type, desc->type, StringToChar(desc->arch), StringToChar(desc->provider)};
|
||||||
|
return GetCreator(primitive, &kernel_desc);
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Defined registering macro to register ordinary op kernel, which called by user directly.
|
/// \brief Defined registering macro to register ordinary op kernel, which called by user directly.
|
||||||
///
|
///
|
||||||
/// \param[in] arch Define deviceType, such as CPU.
|
/// \param[in] arch Define deviceType, such as CPU.
|
||||||
|
|
|
@ -39,7 +39,8 @@ class MS_API RegisterKernelInterface {
|
||||||
/// \param[in] creator Define the KernelInterface create function.
|
/// \param[in] creator Define the KernelInterface create function.
|
||||||
///
|
///
|
||||||
/// \return Status as a status identification of registering.
|
/// \return Status as a status identification of registering.
|
||||||
static Status CustomReg(const std::string &provider, const std::string &op_type, KernelInterfaceCreator creator);
|
inline static Status CustomReg(const std::string &provider, const std::string &op_type,
|
||||||
|
KernelInterfaceCreator creator);
|
||||||
|
|
||||||
/// \brief Static method to register op whose primitive type is ordinary.
|
/// \brief Static method to register op whose primitive type is ordinary.
|
||||||
///
|
///
|
||||||
|
@ -48,7 +49,7 @@ class MS_API RegisterKernelInterface {
|
||||||
/// \param[in] creator Define the KernelInterface create function.
|
/// \param[in] creator Define the KernelInterface create function.
|
||||||
///
|
///
|
||||||
/// \return Status as a status identification of registering.
|
/// \return Status as a status identification of registering.
|
||||||
static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
|
inline static Status Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator);
|
||||||
|
|
||||||
/// \brief Static method to get registration of a certain op.
|
/// \brief Static method to get registration of a certain op.
|
||||||
///
|
///
|
||||||
|
@ -56,7 +57,14 @@ class MS_API RegisterKernelInterface {
|
||||||
/// \param[in] primitive Define the attributes of a certain op.
|
/// \param[in] primitive Define the attributes of a certain op.
|
||||||
///
|
///
|
||||||
/// \return Boolean value to represent registration of a certain op is existing or not.
|
/// \return Boolean value to represent registration of a certain op is existing or not.
|
||||||
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
|
inline static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::string &provider,
|
||||||
|
const schema::Primitive *primitive);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static Status CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
|
||||||
|
KernelInterfaceCreator creator);
|
||||||
|
static Status Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator);
|
||||||
|
static std::shared_ptr<kernel::KernelInterface> GetKernelInterface(const std::vector<char> &provider,
|
||||||
const schema::Primitive *primitive);
|
const schema::Primitive *primitive);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -82,6 +90,20 @@ class MS_API KernelInterfaceReg {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
|
||||||
|
KernelInterfaceCreator creator) {
|
||||||
|
return CustomReg(StringToChar(provider), StringToChar(op_type), creator);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
||||||
|
return Reg(StringToChar(provider), op_type, creator);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
|
||||||
|
const std::string &provider, const schema::Primitive *primitive) {
|
||||||
|
return GetKernelInterface(StringToChar(provider), primitive);
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Defined registering macro to register ordinary op, which called by user directly.
|
/// \brief Defined registering macro to register ordinary op, which called by user directly.
|
||||||
///
|
///
|
||||||
/// \param[in] provider Define the identification of user.
|
/// \param[in] provider Define the identification of user.
|
||||||
|
|
|
@ -23,8 +23,9 @@
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
CkptSaver::CkptSaver(int save_every_n, const std::string &filename_prefix) {
|
CkptSaver::CkptSaver(int save_every_n, const std::vector<char> &filename_prefix) {
|
||||||
callback_impl_ = new (std::nothrow) CallbackImpl(new (std::nothrow) lite::CkptSaver(save_every_n, filename_prefix));
|
callback_impl_ =
|
||||||
|
new (std::nothrow) CallbackImpl(new (std::nothrow) lite::CkptSaver(save_every_n, CharToString(filename_prefix)));
|
||||||
if (callback_impl_ == nullptr) {
|
if (callback_impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Callback implement new failed";
|
MS_LOG(ERROR) << "Callback implement new failed";
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,8 @@ namespace mindspore {
|
||||||
std::mutex g_impl_init_lock;
|
std::mutex g_impl_init_lock;
|
||||||
|
|
||||||
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status Model::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context, const Key &dec_key, const std::string &dec_mode) {
|
const std::shared_ptr<Context> &model_context, const Key &dec_key,
|
||||||
|
const std::vector<char> &dec_mode) {
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||||
|
@ -46,8 +47,9 @@ Status Model::Build(const void *model_data, size_t data_size, ModelType model_ty
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context,
|
Status Model::Build(const std::vector<char> &model_path, ModelType model_type,
|
||||||
const Key &dec_key, const std::string &dec_mode) {
|
const std::shared_ptr<Context> &model_context, const Key &dec_key,
|
||||||
|
const std::vector<char> &dec_mode) {
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
|
||||||
|
@ -57,7 +59,7 @@ Status Model::Build(const std::string &model_path, ModelType model_type, const s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ret = impl_->Build(model_path, model_type, model_context);
|
Status ret = impl_->Build(CharToString(model_path), model_type, model_context);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -186,7 +188,7 @@ std::vector<MSTensor> Model::GetOutputsByNodeName(const std::vector<char> &node_
|
||||||
return impl_->GetOutputsByNodeName(CharToString(node_name));
|
return impl_->GetOutputsByNodeName(CharToString(node_name));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Model::LoadConfig(const std::string &config_path) {
|
Status Model::LoadConfig(const std::vector<char> &config_path) {
|
||||||
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
std::unique_lock<std::mutex> impl_lock(g_impl_init_lock);
|
||||||
if (impl_ != nullptr) {
|
if (impl_ != nullptr) {
|
||||||
MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
|
MS_LOG(ERROR) << "impl_ illegal in LoadConfig.";
|
||||||
|
@ -199,7 +201,7 @@ Status Model::LoadConfig(const std::string &config_path) {
|
||||||
return Status(kLiteFileError, "Fail to load config file.");
|
return Status(kLiteFileError, "Fail to load config file.");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ret = impl_->LoadConfig(config_path);
|
auto ret = impl_->LoadConfig(CharToString(config_path));
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "impl_ LoadConfig failed,";
|
MS_LOG(ERROR) << "impl_ LoadConfig failed,";
|
||||||
return Status(kLiteFileError, "Invalid config file.");
|
return Status(kLiteFileError, "Invalid config file.");
|
||||||
|
|
|
@ -122,9 +122,9 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
|
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
|
||||||
QuantizationType quantization_type, bool export_inference_only,
|
QuantizationType quantization_type, bool export_inference_only,
|
||||||
std::vector<std::string> output_tensor_name) {
|
const std::vector<std::vector<char>> &output_tensor_name) {
|
||||||
if (model.impl_ == nullptr) {
|
if (model.impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Model implement is null.";
|
MS_LOG(ERROR) << "Model implement is null.";
|
||||||
return kLiteUninitializedObj;
|
return kLiteUninitializedObj;
|
||||||
|
@ -141,8 +141,9 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
|
||||||
MS_LOG(ERROR) << "Model session is nullptr.";
|
MS_LOG(ERROR) << "Model session is nullptr.";
|
||||||
return kLiteError;
|
return kLiteError;
|
||||||
}
|
}
|
||||||
auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
|
auto ret = model.impl_->session_->Export(
|
||||||
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, output_tensor_name);
|
CharToString(model_file), export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
|
||||||
|
A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, VectorCharToString(output_tensor_name));
|
||||||
|
|
||||||
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
|
||||||
}
|
}
|
||||||
|
|
|
@ -302,12 +302,12 @@ void MSTensor::SetDataType(enum DataType data_type) {
|
||||||
impl_->SetDataType(data_type);
|
impl_->SetDataType(data_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MSTensor::SetTensorName(const std::string &name) {
|
void MSTensor::SetTensorName(const std::vector<char> &name) {
|
||||||
if (impl_ == nullptr) {
|
if (impl_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "Invalid tensor implement.";
|
MS_LOG(ERROR) << "Invalid tensor implement.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
impl_->SetName(name);
|
impl_->SetName(CharToString(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
void MSTensor::SetAllocator(std::shared_ptr<Allocator> allocator) {
|
void MSTensor::SetAllocator(std::shared_ptr<Allocator> allocator) {
|
||||||
|
|
|
@ -22,29 +22,35 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace registry {
|
namespace registry {
|
||||||
Status RegisterKernel::RegCustomKernel(const std::string &arch, const std::string &provider, DataType data_type,
|
Status RegisterKernel::RegCustomKernel(const std::vector<char> &arch, const std::vector<char> &provider,
|
||||||
const std::string &type, CreateKernel creator) {
|
DataType data_type, const std::vector<char> &type, CreateKernel creator) {
|
||||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||||
return RegistryKernelImpl::GetInstance()->RegCustomKernel(arch, provider, data_type, type, creator);
|
return RegistryKernelImpl::GetInstance()->RegCustomKernel(CharToString(arch), CharToString(provider), data_type,
|
||||||
|
CharToString(type), creator);
|
||||||
#else
|
#else
|
||||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||||
return kLiteNotSupport;
|
return kLiteNotSupport;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RegisterKernel::RegKernel(const std::string &arch, const std::string &provider, DataType data_type, int op_type,
|
Status RegisterKernel::RegKernel(const std::vector<char> &arch, const std::vector<char> &provider, DataType data_type,
|
||||||
CreateKernel creator) {
|
int op_type, CreateKernel creator) {
|
||||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||||
return RegistryKernelImpl::GetInstance()->RegKernel(arch, provider, data_type, op_type, creator);
|
return RegistryKernelImpl::GetInstance()->RegKernel(CharToString(arch), CharToString(provider), data_type, op_type,
|
||||||
|
creator);
|
||||||
#else
|
#else
|
||||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||||
return kLiteNotSupport;
|
return kLiteNotSupport;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDesc *desc) {
|
CreateKernel RegisterKernel::GetCreator(const schema::Primitive *primitive, KernelDescHelper *desc) {
|
||||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||||
return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, desc);
|
if (desc == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
KernelDesc kernel_desc = {desc->data_type, desc->type, CharToString(desc->arch), CharToString(desc->provider)};
|
||||||
|
return RegistryKernelImpl::GetInstance()->GetProviderCreator(primitive, &kernel_desc);
|
||||||
#else
|
#else
|
||||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -22,19 +22,19 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace registry {
|
namespace registry {
|
||||||
Status RegisterKernelInterface::Reg(const std::string &provider, int op_type, KernelInterfaceCreator creator) {
|
Status RegisterKernelInterface::Reg(const std::vector<char> &provider, int op_type, KernelInterfaceCreator creator) {
|
||||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||||
return KernelInterfaceRegistry::Instance()->Reg(provider, op_type, creator);
|
return KernelInterfaceRegistry::Instance()->Reg(CharToString(provider), op_type, creator);
|
||||||
#else
|
#else
|
||||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||||
return kLiteNotSupport;
|
return kLiteNotSupport;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RegisterKernelInterface::CustomReg(const std::string &provider, const std::string &op_type,
|
Status RegisterKernelInterface::CustomReg(const std::vector<char> &provider, const std::vector<char> &op_type,
|
||||||
KernelInterfaceCreator creator) {
|
KernelInterfaceCreator creator) {
|
||||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||||
return KernelInterfaceRegistry::Instance()->CustomReg(provider, op_type, creator);
|
return KernelInterfaceRegistry::Instance()->CustomReg(CharToString(provider), CharToString(op_type), creator);
|
||||||
#else
|
#else
|
||||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||||
return kLiteNotSupport;
|
return kLiteNotSupport;
|
||||||
|
@ -42,9 +42,9 @@ Status RegisterKernelInterface::CustomReg(const std::string &provider, const std
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
|
std::shared_ptr<kernel::KernelInterface> RegisterKernelInterface::GetKernelInterface(
|
||||||
const std::string &provider, const schema::Primitive *primitive) {
|
const std::vector<char> &provider, const schema::Primitive *primitive) {
|
||||||
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
#ifndef CUSTOM_KERNEL_REGISTRY_CLIP
|
||||||
return KernelInterfaceRegistry::Instance()->GetKernelInterface(provider, primitive);
|
return KernelInterfaceRegistry::Instance()->GetKernelInterface(CharToString(provider), primitive);
|
||||||
#else
|
#else
|
||||||
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
MS_LOG(ERROR) << unsupport_custom_kernel_register_log;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -31,24 +31,25 @@
|
||||||
using mindspore::kernel::CLErrorCode;
|
using mindspore::kernel::CLErrorCode;
|
||||||
|
|
||||||
namespace mindspore::registry::opencl {
|
namespace mindspore::registry::opencl {
|
||||||
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
|
Status OpenCLRuntimeWrapper::LoadSource(const std::vector<char> &program_name, const std::vector<char> &source) {
|
||||||
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
|
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
|
||||||
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
|
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
|
||||||
const std::string program_name_ext = "provider_" + program_name;
|
const std::string program_name_ext = "provider_" + CharToString(program_name);
|
||||||
if (ocl_runtime->LoadSource(program_name_ext, source)) {
|
if (ocl_runtime->LoadSource(program_name_ext, CharToString(source))) {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
} else {
|
} else {
|
||||||
return kLiteError;
|
return kLiteError;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::string &program_name,
|
Status OpenCLRuntimeWrapper::BuildKernel(cl::Kernel *kernel, const std::vector<char> &program_name,
|
||||||
const std::string &kernel_name,
|
const std::vector<char> &kernel_name,
|
||||||
const std::vector<std::string> &build_options_ext) {
|
const std::vector<std::vector<char>> &build_options_ext) {
|
||||||
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
|
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
|
||||||
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
|
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
|
||||||
const std::string program_name_ext = "provider_" + program_name;
|
const std::string program_name_ext = "provider_" + CharToString(program_name);
|
||||||
if (ocl_runtime->BuildKernel(*kernel, program_name_ext, kernel_name, build_options_ext, false) == RET_OK) {
|
if (ocl_runtime->BuildKernel(*kernel, program_name_ext, CharToString(kernel_name),
|
||||||
|
VectorCharToString(build_options_ext), false) == RET_OK) {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
} else {
|
} else {
|
||||||
return kLiteError;
|
return kLiteError;
|
||||||
|
|
|
@ -285,7 +285,7 @@ class CustomAddKernel : public kernel::Kernel {
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
auto status =
|
auto status =
|
||||||
registry::RegisterKernelInterface::GetKernelInterface({}, primitive_)->Infer(&inputs_, &outputs_, primitive_);
|
registry::RegisterKernelInterface::GetKernelInterface("", primitive_)->Infer(&inputs_, &outputs_, primitive_);
|
||||||
if (status != kSuccess) {
|
if (status != kSuccess) {
|
||||||
std::cerr << "infer failed." << std::endl;
|
std::cerr << "infer failed." << std::endl;
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
|
|
|
@ -25,13 +25,15 @@ namespace {
|
||||||
std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room;
|
std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room;
|
||||||
std::mutex node_mutex;
|
std::mutex node_mutex;
|
||||||
} // namespace
|
} // namespace
|
||||||
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::string &node_type,
|
NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::vector<char> &node_type,
|
||||||
const converter::NodeParserPtr &node_parser) {
|
const converter::NodeParserPtr &node_parser) {
|
||||||
std::unique_lock<std::mutex> lock(node_mutex);
|
std::unique_lock<std::mutex> lock(node_mutex);
|
||||||
node_parser_room[fmk_type][node_type] = node_parser;
|
std::string node_type_str = CharToString(node_type);
|
||||||
|
node_parser_room[fmk_type][node_type_str] = node_parser;
|
||||||
}
|
}
|
||||||
|
|
||||||
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type, const std::string &node_type) {
|
converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fmk_type,
|
||||||
|
const std::vector<char> &node_type) {
|
||||||
auto iter_level1 = node_parser_room.find(fmk_type);
|
auto iter_level1 = node_parser_room.find(fmk_type);
|
||||||
if (iter_level1 == node_parser_room.end()) {
|
if (iter_level1 == node_parser_room.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -39,7 +41,7 @@ converter::NodeParserPtr NodeParserRegistry::GetNodeParser(converter::FmkType fm
|
||||||
if (node_type.empty()) {
|
if (node_type.empty()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto iter_level2 = iter_level1->second.find(node_type);
|
auto iter_level2 = iter_level1->second.find(CharToString(node_type));
|
||||||
if (iter_level2 == iter_level1->second.end()) {
|
if (iter_level2 == iter_level1->second.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,18 +37,21 @@ void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass) { RegPass(pass_name, pass); }
|
PassRegistry::PassRegistry(const std::vector<char> &pass_name, const PassBasePtr &pass) {
|
||||||
|
RegPass(CharToString(pass_name), pass);
|
||||||
|
}
|
||||||
|
|
||||||
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names) {
|
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names) {
|
||||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||||
external_assigned_passes[position] = names;
|
external_assigned_passes[position] = VectorCharToString(names);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition position) {
|
std::vector<std::vector<char>> PassRegistry::GetOuterScheduleTaskInner(PassPosition position) {
|
||||||
return external_assigned_passes[position];
|
return VectorStringToChar(external_assigned_passes[position]);
|
||||||
}
|
}
|
||||||
|
|
||||||
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) {
|
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::vector<char> &pass_name_char) {
|
||||||
|
std::string pass_name = CharToString(pass_name_char);
|
||||||
return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name];
|
return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name];
|
||||||
}
|
}
|
||||||
} // namespace registry
|
} // namespace registry
|
||||||
|
|
Loading…
Reference in New Issue