diff --git a/include/api/dual_abi_helper.h b/include/api/dual_abi_helper.h index 1804ad8a9e8..4c35255d5ac 100644 --- a/include/api/dual_abi_helper.h +++ b/include/api/dual_abi_helper.h @@ -67,18 +67,20 @@ inline std::set SetCharToString(const std::set> & return ret; } -inline std::map, int32_t> MapStringToChar(const std::map &s) { - std::map, int32_t> ret; +template +inline std::map, T> MapStringToChar(const std::map &s) { + std::map, T> ret; std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()), [](auto str) { - return std::pair, int32_t>(std::vector(str.first.begin(), str.first.end()), str.second); + return std::pair, T>(std::vector(str.first.begin(), str.first.end()), str.second); }); return ret; } -inline std::map MapCharToString(const std::map, int32_t> &c) { - std::map ret; +template +inline std::map MapCharToString(const std::map, T> &c) { + std::map ret; std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()), [](auto ch) { - return std::pair(std::string(ch.first.begin(), ch.first.end()), ch.second); + return std::pair(std::string(ch.first.begin(), ch.first.end()), ch.second); }); return ret; } @@ -151,24 +153,6 @@ inline std::vector, int64_t>> PairStringInt64ToPairC return ret; } -template -inline std::map, T> PadInfoStringToChar(const std::map &s_pad_info) { - std::map, T> ret; - std::transform(s_pad_info.begin(), s_pad_info.end(), std::inserter(ret, ret.begin()), [](auto str) { - return std::pair, T>(std::vector(str.first.begin(), str.first.end()), str.second); - }); - return ret; -} - -template -inline std::map PadInfoCharToString(const std::map, T> &c_pad_info) { - std::map ret; - std::transform(c_pad_info.begin(), c_pad_info.end(), std::inserter(ret, ret.begin()), [](auto ch) { - return std::pair(std::string(ch.first.begin(), ch.first.end()), ch.second); - }); - return ret; -} - template inline void TensorMapCharToString(const std::map, T> *c, std::unordered_map *s) { if (c == nullptr || s == nullptr) { diff --git a/include/api/serialization.h b/include/api/serialization.h index 27a1ede0b1c..0532583243a 100644 --- a/include/api/serialization.h +++ b/include/api/serialization.h @@ -129,7 +129,7 @@ Status Serialization::Load(const std::vector &files, ModelType mode } Status Serialization::SetParameters(const std::map ¶meters, Model *model) { - return SetParameters(PadInfoStringToChar(parameters), model); + return SetParameters(MapStringToChar(parameters), model); } Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file, diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index e77701c0310..080b86b40d2 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -446,9 +446,9 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset( if (input == nullptr) { ir_node_ = nullptr; } else { - auto ds = std::make_shared( - input->IRNode(), VectorCharToString(column_names), bucket_boundaries, bucket_batch_sizes, c_func, - PadInfoCharToString(map), pad_to_bucket_boundary, drop_remainder); + auto ds = std::make_shared(input->IRNode(), VectorCharToString(column_names), + bucket_boundaries, bucket_batch_sizes, c_func, + MapCharToString(map), pad_to_bucket_boundary, drop_remainder); ir_node_ = std::static_pointer_cast(ds); } diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h b/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h index 6381ea4f358..42aaf86b834 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h @@ -262,7 +262,7 @@ class DATASET_API Dataset : public std::enable_shared_from_this { bool pad_to_bucket_boundary = false, bool drop_remainder = false) { return std::make_shared( shared_from_this(), VectorStringToChar(column_names), bucket_boundaries, bucket_batch_sizes, - element_length_function, PadInfoStringToChar(pad_info), pad_to_bucket_boundary, drop_remainder); + element_length_function, MapStringToChar(pad_info), pad_to_bucket_boundary, drop_remainder); } /// \brief Function to create a SentencePieceVocab from source dataset. diff --git a/mindspore/lite/include/converter.h b/mindspore/lite/include/converter.h index b5eb2d713ef..dcb2fd6a031 100644 --- a/mindspore/lite/include/converter.h +++ b/mindspore/lite/include/converter.h @@ -23,26 +23,27 @@ #include "include/api/format.h" #include "include/api/status.h" #include "include/registry/converter_context.h" +#include "include/api/dual_abi_helper.h" namespace mindspore { struct ConverterPara; class MS_API Converter { public: - Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file = "", - const std::string &weight_file = ""); + inline Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file = "", + const std::string &weight_file = ""); ~Converter() = default; - void SetConfigFile(const std::string &config_file); - std::string GetConfigFile() const; + inline void SetConfigFile(const std::string &config_file); + inline std::string GetConfigFile() const; - void SetConfigInfo(const std::string §ion, const std::map &config); - std::map> GetConfigInfo() const; + inline void SetConfigInfo(const std::string §ion, const std::map &config); + inline std::map> GetConfigInfo() const; void SetWeightFp16(bool weight_fp16); bool GetWeightFp16() const; - void SetInputShape(const std::map> &input_shape); - std::map> GetInputShape() const; + inline void SetInputShape(const std::map> &input_shape); + inline std::map> GetInputShape() const; void SetInputFormat(Format format); Format GetInputFormat() const; @@ -56,17 +57,17 @@ class MS_API Converter { void SetExportMindIR(ModelType export_mindir); ModelType GetExportMindIR() const; - void SetDecryptKey(const std::string &key); - std::string GetDecryptKey() const; + inline void SetDecryptKey(const std::string &key); + inline std::string GetDecryptKey() const; - void SetDecryptMode(const std::string &mode); - std::string GetDecryptMode() const; + inline void SetDecryptMode(const std::string &mode); + inline std::string GetDecryptMode() const; void SetEnableEncryption(bool encryption); bool GetEnableEncryption() const; - void SetEncryptKey(const std::string &key); - std::string GetEncryptKey() const; + inline void SetEncryptKey(const std::string &key); + inline std::string GetEncryptKey() const; void SetInfer(bool infer); bool GetInfer() const; @@ -77,14 +78,70 @@ class MS_API Converter { void SetNoFusion(bool no_fusion); bool GetNoFusion(); - void SetDevice(const std::string &device); - std::string GetDevice(); + inline void SetDevice(const std::string &device); + inline std::string GetDevice(); Status Convert(); void *Convert(size_t *data_size); private: + Converter(converter::FmkType fmk_type, const std::vector &model_file, const std::vector &output_file, + const std::vector &weight_file); + void SetConfigFile(const std::vector &config_file); + std::vector GetConfigFileChar() const; + void SetConfigInfo(const std::vector §ion, const std::map, std::vector> &config); + std::map, std::map, std::vector>> GetConfigInfoChar() const; + void SetInputShape(const std::map, std::vector> &input_shape); + std::map, std::vector> GetInputShapeChar() const; + void SetDecryptKey(const std::vector &key); + std::vector GetDecryptKeyChar() const; + void SetDecryptMode(const std::vector &mode); + std::vector GetDecryptModeChar() const; + void SetEncryptKey(const std::vector &key); + std::vector GetEncryptKeyChar() const; + void SetDevice(const std::vector &device); + std::vector GetDeviceChar(); std::shared_ptr data_; }; + +Converter::Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file, + const std::string &weight_file) + : Converter(fmk_type, StringToChar(model_file), StringToChar(output_file), StringToChar(weight_file)) {} + +void Converter::SetConfigFile(const std::string &config_file) { SetConfigFile(StringToChar(config_file)); } + +std::string Converter::GetConfigFile() const { return CharToString(GetConfigFileChar()); } + +void Converter::SetConfigInfo(const std::string §ion, const std::map &config) { + SetConfigInfo(StringToChar(section), MapStringToVectorChar(config)); +} + +std::map> Converter::GetConfigInfo() const { + return MapMapCharToString(GetConfigInfoChar()); +} + +void Converter::SetInputShape(const std::map> &input_shape) { + SetInputShape(MapStringToChar(input_shape)); +} + +std::map> Converter::GetInputShape() const { + return MapCharToString(GetInputShapeChar()); +} + +void Converter::SetDecryptKey(const std::string &key) { SetDecryptKey(StringToChar(key)); } + +std::string Converter::GetDecryptKey() const { return CharToString(GetDecryptKeyChar()); } + +void Converter::SetDecryptMode(const std::string &mode) { SetDecryptMode(StringToChar(mode)); } + +std::string Converter::GetDecryptMode() const { return CharToString(GetDecryptModeChar()); } + +void Converter::SetEncryptKey(const std::string &key) { SetEncryptKey(StringToChar(key)); } + +std::string Converter::GetEncryptKey() const { return CharToString(GetEncryptKeyChar()); } + +void Converter::SetDevice(const std::string &device) { SetDevice(StringToChar(device)); } + +std::string Converter::GetDevice() { return CharToString(GetDeviceChar()); } } // namespace mindspore #endif // MINDSPORE_LITE_INCLUDE_CONVERTER_H_ diff --git a/mindspore/lite/python/src/converter_pybind.cc b/mindspore/lite/python/src/converter_pybind.cc index 56f1b35fc3b..6d29c72f2d1 100644 --- a/mindspore/lite/python/src/converter_pybind.cc +++ b/mindspore/lite/python/src/converter_pybind.cc @@ -32,13 +32,15 @@ void ConverterPyBind(const py::module &m) { py::class_>(m, "ConverterBind") .def(py::init()) - .def("set_config_file", &Converter::SetConfigFile) + .def("set_config_file", py::overload_cast(&Converter::SetConfigFile)) .def("get_config_file", &Converter::GetConfigFile) - .def("set_config_info", &Converter::SetConfigInfo) + .def("set_config_info", + py::overload_cast &>(&Converter::SetConfigInfo)) .def("get_config_info", &Converter::GetConfigInfo) .def("set_weight_fp16", &Converter::SetWeightFp16) .def("get_weight_fp16", &Converter::GetWeightFp16) - .def("set_input_shape", &Converter::SetInputShape) + .def("set_input_shape", + py::overload_cast> &>(&Converter::SetInputShape)) .def("get_input_shape", &Converter::GetInputShape) .def("set_input_format", &Converter::SetInputFormat) .def("get_input_format", &Converter::GetInputFormat) @@ -48,13 +50,13 @@ void ConverterPyBind(const py::module &m) { .def("get_output_data_type", &Converter::GetOutputDataType) .def("set_export_mindir", &Converter::SetExportMindIR) .def("get_export_mindir", &Converter::GetExportMindIR) - .def("set_decrypt_key", &Converter::SetDecryptKey) + .def("set_decrypt_key", py::overload_cast(&Converter::SetDecryptKey)) .def("get_decrypt_key", &Converter::GetDecryptKey) - .def("set_decrypt_mode", &Converter::SetDecryptMode) + .def("set_decrypt_mode", py::overload_cast(&Converter::SetDecryptMode)) .def("get_decrypt_mode", &Converter::GetDecryptMode) .def("set_enable_encryption", &Converter::SetEnableEncryption) .def("get_enable_encryption", &Converter::GetEnableEncryption) - .def("set_encrypt_key", &Converter::SetEncryptKey) + .def("set_encrypt_key", py::overload_cast(&Converter::SetEncryptKey)) .def("get_encrypt_key", &Converter::GetEncryptKey) .def("set_infer", &Converter::SetInfer) .def("get_infer", &Converter::GetInfer) diff --git a/mindspore/lite/tools/converter/cxx_api/converter.cc b/mindspore/lite/tools/converter/cxx_api/converter.cc index 62520bc09dd..ba168675eb2 100644 --- a/mindspore/lite/tools/converter/cxx_api/converter.cc +++ b/mindspore/lite/tools/converter/cxx_api/converter.cc @@ -28,54 +28,57 @@ constexpr size_t kMaxConfigNumPerSection = 1000; namespace lite { int RunConverter(const std::shared_ptr &data_); } -Converter::Converter(converter::FmkType fmk_type, const std::string &model_file, const std::string &output_file, - const std::string &weight_file) { +Converter::Converter(converter::FmkType fmk_type, const std::vector &model_file, + const std::vector &output_file, const std::vector &weight_file) { data_ = std::make_shared(); if (data_ != nullptr) { data_->fmk_type = fmk_type; - data_->model_file = model_file; - data_->output_file = output_file; - data_->weight_file = weight_file; + data_->model_file = CharToString(model_file); + data_->output_file = CharToString(output_file); + data_->weight_file = CharToString(weight_file); } else { MS_LOG(ERROR) << "Create ConverterPara failed"; } } -void Converter::SetConfigFile(const std::string &config_file) { +void Converter::SetConfigFile(const std::vector &config_file) { if (data_ != nullptr) { - data_->config_file = config_file; + data_->config_file = CharToString(config_file); } } -std::string Converter::GetConfigFile() const { +std::vector Converter::GetConfigFileChar() const { + std::string cfg_file = ""; if (data_ != nullptr) { - return data_->config_file; - } else { - return ""; + cfg_file = data_->config_file; } + return StringToChar(cfg_file); } -void Converter::SetConfigInfo(const std::string §ion, const std::map &config) { +void Converter::SetConfigInfo(const std::vector §ion, + const std::map, std::vector> &config) { + auto section_str = CharToString(section); + auto config_str = MapVectorCharToString(config); if (data_ != nullptr) { if (data_->config_param.size() > kMaxSectionNum) { MS_LOG(ERROR) << "Section num " << data_->config_param.size() << "exceeds max num " << kMaxSectionNum; return; } - if (data_->config_param.find(section) != data_->config_param.end()) { - MS_LOG(WARNING) << "Section " << section << "already exists, " + if (data_->config_param.find(section_str) != data_->config_param.end()) { + MS_LOG(WARNING) << "Section " << section_str << "already exists, " << "value will be overwrite."; } if (config.size() > kMaxConfigNumPerSection) { MS_LOG(ERROR) << "Config num " << config.size() << " exceeds max num " << kMaxConfigNumPerSection << " in " - << section; + << section_str; return; } - data_->config_param[section] = config; + data_->config_param[section_str] = config_str; } } -std::map> Converter::GetConfigInfo() const { - return data_->config_param; +std::map, std::map, std::vector>> Converter::GetConfigInfoChar() const { + return MapMapStringToChar(data_->config_param); } void Converter::SetWeightFp16(bool weight_fp16) { @@ -92,21 +95,22 @@ bool Converter::GetWeightFp16() const { } } -void Converter::SetInputShape(const std::map> &input_shape) { +void Converter::SetInputShape(const std::map, std::vector> &input_shape) { + auto input_shape_str = MapCharToString(input_shape); if (data_ != nullptr) { - for (auto &it : input_shape) { + for (auto &it : input_shape_str) { lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(it.first, it.second); } - data_->input_shape = input_shape; + data_->input_shape = input_shape_str; } } -std::map> Converter::GetInputShape() const { +std::map, std::vector> Converter::GetInputShapeChar() const { + std::map> input_shape = {}; if (data_ != nullptr) { - return data_->input_shape; - } else { - return {}; + input_shape = data_->input_shape; } + return MapStringToChar(input_shape); } void Converter::SetInputFormat(Format format) { @@ -168,32 +172,32 @@ ModelType Converter::GetExportMindIR() const { } } -void Converter::SetDecryptKey(const std::string &key) { +void Converter::SetDecryptKey(const std::vector &key) { if (data_ != nullptr) { - data_->decrypt_key = key; + data_->decrypt_key = CharToString(key); } } -std::string Converter::GetDecryptKey() const { +std::vector Converter::GetDecryptKeyChar() const { + std::string decrypt_key = ""; if (data_ != nullptr) { - return data_->decrypt_key; - } else { - return ""; + decrypt_key = data_->decrypt_key; + } + return StringToChar(decrypt_key); +} + +void Converter::SetDecryptMode(const std::vector &mode) { + if (data_ != nullptr) { + data_->decrypt_mode = CharToString(mode); } } -void Converter::SetDecryptMode(const std::string &mode) { +std::vector Converter::GetDecryptModeChar() const { + std::string decrypt_mode = ""; if (data_ != nullptr) { - data_->decrypt_mode = mode; - } -} - -std::string Converter::GetDecryptMode() const { - if (data_ != nullptr) { - return data_->decrypt_mode; - } else { - return ""; + decrypt_mode = data_->decrypt_mode; } + return StringToChar(decrypt_mode); } void Converter::SetEnableEncryption(bool encryption) { @@ -210,18 +214,18 @@ bool Converter::GetEnableEncryption() const { } } -void Converter::SetEncryptKey(const std::string &key) { +void Converter::SetEncryptKey(const std::vector &key) { if (data_ != nullptr) { - data_->encrypt_key = key; + data_->encrypt_key = CharToString(key); } } -std::string Converter::GetEncryptKey() const { +std::vector Converter::GetEncryptKeyChar() const { + std::string encrypt_key = ""; if (data_ != nullptr) { - return data_->encrypt_key; - } else { - return ""; + encrypt_key = data_->encrypt_key; } + return StringToChar(encrypt_key); } void Converter::SetInfer(bool infer) { @@ -266,18 +270,18 @@ bool Converter::GetNoFusion() { } } -void Converter::SetDevice(const std::string &device) { +void Converter::SetDevice(const std::vector &device) { if (data_ != nullptr) { - data_->device = device; + data_->device = CharToString(device); } } -std::string Converter::GetDevice() { +std::vector Converter::GetDeviceChar() { + std::string device = ""; if (data_ != nullptr) { - return data_->device; - } else { - return ""; + device = data_->device; } + return StringToChar(device); } Status Converter::Convert() {