!28896 [MS][LITE] model pool

Merge pull request !28896 from yefeng/196-model_pool
i-robot 2022-01-19 06:53:29 +00:00 committed by Gitee
commit 133f248d53
12 changed files with 577 additions and 167 deletions

@ -163,6 +163,9 @@ elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite")
set(TARGET_OHOS_LITE on)
SET_PROPERTY(GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS TRUE)
endif()
if(MSLITE_ENABLE_SERVING)
add_compile_definitions(USING_SERVING)
endif()
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0
AND NOT TARGET_HIMIX AND NOT TARGET_MIX210)

@ -30,9 +30,6 @@ endif()
if(NOT MSLITE_ENABLE_INT8)
add_compile_definitions(OP_INT8_CLIP)
endif()
if(MSLITE_ENABLE_SERVING)
add_compile_definitions(USING_SERVING)
endif()
if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64)
#for performance

@ -0,0 +1,51 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USING_SERVING
#include "src/cxx_api/model/model_parallel_runner.h"
#include "src/cxx_api/model/model_pool.h"
#include "src/common/log.h"
namespace mindspore {
Status ModelParallelRunner::Init(const std::string &model_path, const std::string &config_path, const Key &dec_key,
const std::string &dec_mode) {
auto status = ModelPool::GetInstance()->Init(model_path, config_path, dec_key, dec_mode);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner init failed.";
return kLiteError;
}
return status;
}
std::vector<MSTensor> ModelParallelRunner::GetInputs() {
std::vector<MSTensor> model_inputs = {};
return model_inputs;
}
Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (outputs == nullptr) {
MS_LOG(ERROR) << "predict output is nullptr.";
return kLiteNullptr;
}
auto status = ModelPool::GetInstance()->Predict(inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model runner predict failed.";
return kLiteError;
}
return kSuccess;
}
} // namespace mindspore
#endif

@ -0,0 +1,66 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
#define MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
#ifdef USING_SERVING
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include "include/api/status.h"
#include "include/api/context.h"
namespace mindspore {
class ModelPool;
/// \brief The ModelParallelRunner class is used to define a MindSpore model pool, facilitating parallel Model management.
class MS_API ModelParallelRunner {
public:
ModelParallelRunner() = default;
~ModelParallelRunner() = default;
/// \brief Build a model runner from a model path so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] config_path Define the path of the config file that configures the runner (for example the
/// num_thread and bind_mode settings).
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
///
/// \return Status.
Status Init(const std::string &model_path, const std::string &config_path, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
/// \brief Obtains all input tensors of the model.
///
/// \return The vector that includes all input tensors.
std::vector<MSTensor> GetInputs();
/// \brief Run inference through the underlying model pool.
///
/// \param[in] inputs A vector where model inputs are arranged in sequence.
/// \param[out] outputs A pointer to a vector into which the model outputs are filled in sequence.
/// \param[in] before CallBack before predict.
/// \param[in] after CallBack after predict.
///
/// \return Status.
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_RUNNER_H
#endif
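
A minimal caller-side sketch of the ModelParallelRunner API declared above; this is an illustration under assumptions, not part of the change set: "model.ms" and "runner.cfg" are placeholder paths, and since GetInputs() still returns an empty vector in this commit, the caller builds its own input tensors.

#include <vector>
#include "src/cxx_api/model/model_parallel_runner.h"

int RunOnce(const std::vector<mindspore::MSTensor> &inputs) {
  mindspore::ModelParallelRunner runner;
  // The config file is expected to hold a single section with num_thread and bind_mode keys.
  auto status = runner.Init("model.ms", "runner.cfg");
  if (status != mindspore::kSuccess) {
    return -1;
  }
  std::vector<mindspore::MSTensor> outputs;
  status = runner.Predict(inputs, &outputs);
  return status == mindspore::kSuccess ? 0 : -1;
}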

@ -0,0 +1,198 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USING_SERVING
#include "src/cxx_api/model/model_pool.h"
#include <unistd.h>
#include <future>
#include "src/common/log.h"
#include "include/lite_types.h"
#include "src/common/config_file.h"
namespace mindspore {
void ModelPool::SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num) {
int core_num = sysconf(_SC_NPROCESSORS_CONF);
if (thread_num == 0) {
MS_LOG(ERROR) << "thread num is zero.";
return;
}
num_models_ = core_num / thread_num;
int core_id = 0;
for (size_t i = 0; i < num_models_; i++) {
std::vector<int> bind_id;
for (int j = 0; j < thread_num; j++) {
if (core_id >= core_num) {
core_id = 0;
}
bind_id.push_back(core_id);
core_id++;
}
all_model_bind_list->push_back(bind_id);
}
}
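// Illustration, assuming an 8-core machine: with thread_num = 2, num_models_ becomes 8 / 2 = 4 and the
// bind lists are {0, 1}, {2, 3}, {4, 5}, {6, 7}. Because num_models_ is the floor of core_num / thread_num,
// the core_id >= core_num wrap-around check acts only as a defensive guard.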
ModelPool *ModelPool::GetInstance() {
static ModelPool instance;
return &instance;
}
Status ModelPool::InitContext(const std::shared_ptr<mindspore::Context> &context,
std::map<std::string, std::map<std::string, std::string>> *all_config_info) {
if (all_config_info->size() != 1) {
MS_LOG(ERROR) << "all_config_info size should be 1";
return kLiteError;
}
for (auto &item : *all_config_info) {
auto config = item.second;
auto num_thread = atoi(config["num_thread"].c_str());
auto bind_mode = atoi(config["bind_mode"].c_str());
context->SetThreadNum(num_thread);
context->SetThreadAffinity(bind_mode);
}
context->SetEnableParallel(false);
auto &device_list = context->MutableDeviceInfo();
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(false);
device_list.push_back(device_info);
return kSuccess;
}
ModelPoolContex ModelPool::CreateModelContext(const std::string &config_path) {
std::map<std::string, std::map<std::string, std::string>> all_config_info;
auto ret = lite::GetAllSectionInfoFromConfigFile(config_path, &all_config_info);
if (ret != 0) {
MS_LOG(ERROR) << "GetAllSectionInfoFromConfigFile failed.";
return {};
}
auto model_context = std::make_shared<mindspore::Context>();
if (model_context == nullptr) {
MS_LOG(ERROR) << "model context is nullptr.";
return {};
}
auto status = InitContext(model_context, &all_config_info);
if (status != kSuccess) {
MS_LOG(ERROR) << "InitMSContext failed.";
return {};
}
auto device_list = model_context->MutableDeviceInfo();
if (device_list.size() != 1) {
MS_LOG(ERROR) << "model pool only support device num 1.";
return {};
}
auto device = device_list.front();
if (device->GetDeviceType() != kCPU) {
MS_LOG(ERROR) << "model pool only support cpu type.";
return {};
}
auto cpu_context = device->Cast<CPUDeviceInfo>();
auto enable_fp16 = cpu_context->GetEnableFP16();
if (enable_fp16) {
MS_LOG(ERROR) << "model pool not support enable fp16.";
return {};
}
ModelPoolContex model_pool_context;
std::vector<std::vector<int>> all_model_bind_list;
if (model_context->GetThreadAffinityMode() == lite::HIGHER_CPU) {
SetBindStrategy(&all_model_bind_list, static_cast<int>(model_context->GetThreadNum()));
} else if (model_context->GetThreadAffinityMode() == lite::MID_CPU) {
MS_LOG(ERROR) << "not support bind MID_CPU.";
return {};
}
for (size_t i = 0; i < num_models_; i++) {
auto context = std::make_shared<Context>();
if (context == nullptr) {
MS_LOG(ERROR) << "New Context failed.";
return {};
}
context->SetThreadNum(model_context->GetThreadNum());
context->SetEnableParallel(model_context->GetEnableParallel());
if (model_context->GetThreadAffinityMode() != lite::NO_BIND) {
// bind by core id
context->SetThreadAffinity(all_model_bind_list[i]);
} else {
// not bind core
context->SetThreadAffinity(model_context->GetThreadAffinityMode());
}
auto &new_device_list = context->MutableDeviceInfo();
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
device_info->SetEnableFP16(false);
new_device_list.push_back(device_info);
model_pool_context.push_back(context);
}
return model_pool_context;
}
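// Example, assuming a config file whose single section sets num_thread = 4 and bind_mode = 1
// (lite::HIGHER_CPU): the pool then builds num_models_ = core_num / 4 contexts, each bound to its own
// group of four consecutive core ids, with EnableFP16 and EnableParallel both left false.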
Status ModelPool::Run() {
std::unique_lock<std::mutex> model_lock(mtx_model_queue_);
while (model_pool_queue_.empty()) {
cv_model_.wait(model_lock);
}
auto model = model_pool_queue_.front();
model_pool_queue_.pop();
model_lock.unlock();
std::unique_lock<std::mutex> data_lock(mtx_data_queue_);
if (model_data_queue_.empty()) {
MS_LOG(ERROR) << "model data queue is empty";
return kLiteError;
}
auto model_data = model_data_queue_.front();
model_data_queue_.pop();
auto inputs = model_data->inputs;
auto outputs = model_data->outputs;
auto before = model_data->before;
auto after = model_data->after;
auto status = model->Predict(*inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model predict failed.";
return status;
}
mtx_model_queue_.lock();
model_pool_queue_.push(model);
cv_model_.notify_one();
mtx_model_queue_.unlock();
return kSuccess;
}
Status ModelPool::Init(const std::string &model_path, const std::string &config_path, const Key &dec_key,
const std::string &dec_mode) {
auto model_pool_context = CreateModelContext(config_path);
if (model_pool_context.size() < num_models_) {
MS_LOG(ERROR) << "create model pool context failed.";
return kLiteError;
}
for (size_t i = 0; i < num_models_; i++) {
auto model = std::make_shared<ModelThread>();
auto status = model->Init(model_path, model_pool_context[i], dec_key, dec_mode);
if (status != kSuccess) {
MS_LOG(ERROR) << "model thread init failed.";
return status;
}
model_pool_queue_.push(model);
}
return kSuccess;
}
Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
{
std::unique_lock<std::mutex> data_lock(mtx_data_queue_);
auto model_data = std::make_shared<ModelData>();
model_data->inputs = &inputs;
model_data->outputs = outputs;
model_data->before = before;
model_data->after = after;
model_data_queue_.push(model_data);
}
auto future_status = std::async(std::launch::async, &ModelPool::Run, this);
auto status = future_status.get();
if (status != kSuccess) {
MS_LOG(ERROR) << "model run failed in model pool predict.";
return status;
}
return kSuccess;
}
} // namespace mindspore
#endif

@ -0,0 +1,59 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_MODEL_MODEL_POOL_H
#define MINDSPORE_INCLUDE_API_MODEL_MODEL_POOL_H
#ifdef USING_SERVING
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <queue>
#include <map>
#include "include/api/status.h"
#include "include/api/context.h"
#include "src/cxx_api/model/model_thread.h"
namespace mindspore {
class ModelPool {
public:
static ModelPool *GetInstance();
virtual ~ModelPool() = default;
Status Init(const std::string &model_path, const std::string &config_path, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
private:
ModelPool() = default;
Status InitContext(const std::shared_ptr<mindspore::Context> &context,
std::map<std::string, std::map<std::string, std::string>> *all_config_info);
Status Run();
void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
ModelPoolContex CreateModelContext(const std::string &config_path);
std::mutex mtx_data_queue_;
std::mutex mtx_model_queue_;
std::condition_variable cv_data_;
std::condition_variable cv_model_;
std::queue<std::shared_ptr<ModelThread>> model_pool_queue_;
std::queue<std::shared_ptr<ModelData>> model_data_queue_;
size_t num_models_ = 5;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_MODEL_POOL_H
#endif

@ -0,0 +1,85 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USING_SERVING
#include "src/cxx_api/model/model_thread.h"
#include "src/common/log.h"
#include "src/common/utils.h"
namespace mindspore {
Status ModelThread::Init(const std::string &model_path, const std::shared_ptr<Context> &model_context,
const Key &dec_key, const std::string &dec_mode) {
model_ = std::make_shared<Model>();
mindspore::ModelType model_type = kMindIR;
auto status = model_->Build(model_path, model_type, model_context, dec_key, dec_mode);
if (status != kSuccess) {
MS_LOG(ERROR) << "model build failed in ModelPool Init";
return status;
}
return kSuccess;
}
Status ModelThread::ModelRun(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
auto status = model_->Predict(inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model predict failed.";
return status;
}
return kSuccess;
}
std::pair<std::vector<std::vector<int64_t>>, bool> ModelThread::GetModelResize(
const std::vector<MSTensor> &model_inputs, const std::vector<MSTensor> &inputs) {
std::unique_lock<std::mutex> model_lock(mtx_model_);
std::vector<std::vector<int64_t>> dims;
bool need_resize = false;
for (size_t i = 0; i < model_inputs.size(); i++) {
for (size_t j = 0; j < model_inputs[i].Shape().size(); j++) {
if (model_inputs[i].Shape()[j] != inputs[i].Shape()[j]) {
need_resize = true;
}
}
dims.push_back(inputs[i].Shape());
}
return std::make_pair(dims, need_resize);
}
Status ModelThread::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
// model
auto model_input = model_->GetInputs();
if (model_input.size() != inputs.size()) {
MS_LOG(ERROR) << "model input size is: " << model_input.size() << ", but get input size is: " << inputs.size();
return kLiteError;
}
auto resize_pair = GetModelResize(model_input, inputs);
if (resize_pair.second) {
auto dims = resize_pair.first;
auto status = model_->Resize(model_->GetInputs(), dims);
if (status != kSuccess) {
MS_LOG(ERROR) << "model pool resize failed.";
return kLiteError;
}
}
auto status = ModelRun(inputs, outputs, before, after);
if (status != kSuccess) {
MS_LOG(ERROR) << "model predict failed in ModelPool.";
return status;
}
return kSuccess;
}
} // namespace mindspore
#endif

@ -0,0 +1,67 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_CXX_API_MODEL_THREAD_H_
#define MINDSPORE_LITE_SRC_CXX_API_MODEL_THREAD_H_
#ifdef USING_SERVING
#include <queue>
#include <string>
#include <mutex>
#include <future>
#include <vector>
#include <utility>
#include <memory>
#include "include/api/model.h"
namespace mindspore {
using ModelPoolContex = std::vector<std::shared_ptr<Context>>;
struct ModelData {
const std::vector<MSTensor> *inputs;
std::vector<MSTensor> *outputs;
MSKernelCallBack before;
MSKernelCallBack after;
};
class ModelThread {
public:
ModelThread() = default;
~ModelThread() = default;
// a model thread is initialized once and can then keep accepting model run requests
Status Init(const std::string &model_path, const std::shared_ptr<Context> &model_context, const Key &dec_key = {},
const std::string &dec_mode = kDecModeAesGcm);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
private:
std::pair<std::vector<std::vector<int64_t>>, bool> GetModelResize(const std::vector<MSTensor> &model_inputs,
const std::vector<MSTensor> &inputs);
Status ModelRun(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
private:
std::shared_ptr<mindspore::Model> model_ = nullptr;
std::mutex mtx_model_;
std::condition_variable model_cond_;
// the number of models is configured according to the hardware
int num_models_;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_THREAD_H_
#endif

@ -28,6 +28,9 @@
#include "src/common/graph_util.h"
#include "src/common/file_utils.h"
#include "src/tensor.h"
#ifdef USING_SERVING
#include "src/pack_weight_manager.h"
#endif
#ifdef ENABLE_V0
#include "src/ops/compat/compat_register.h"
#endif
@ -105,6 +108,9 @@ int LiteModel::ConvertAttrToTensors() {
#endif
void LiteModel::Free() {
#ifdef USING_SERVING
lite::PackWeightManager::GetInstance()->DeleteSavedModelPtr(this);
#endif
if (this->buf != nullptr) {
delete[](this->buf);
this->buf = nullptr;
@ -588,7 +594,9 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
MS_LOG(ERROR) << "new model fail!";
return nullptr;
}
#ifdef USING_SERVING
lite::PackWeightManager::GetInstance()->StoreLiteModel(model_buf, model);
#endif
auto status = model->ConstructModel(model_buf, size, take_buf);
if (status != RET_OK) {
MS_LOG(ERROR) << "construct model failed.";

@ -567,8 +567,7 @@ int LiteSession::IniPackWeightData(Model *model) {
src_tensor->length() == 0) {
continue;
}
lite::PackWeightManager::GetInstance()->StoreOriginTensor(lite_model, src_tensor, tensor_index);
auto data = lite::PackWeightManager::GetInstance()->GetTensorData(lite_model, tensor_index);
auto data = lite::PackWeightManager::GetInstance()->GetTensorData(lite_model, src_tensor, tensor_index);
if (data == nullptr) {
MS_LOG(DEBUG) << "data not packed.";
continue;
@ -982,9 +981,6 @@ LiteSession::~LiteSession() {
MS_LOG(ERROR) << "Not support multi-threading";
return;
}
#ifdef USING_SERVING
lite::PackWeightManager::GetInstance()->DeleteSavedSessionPtr(this);
#endif
for (auto *kernel : kernels_) {
delete kernel;
kernel = nullptr;
@ -1684,7 +1680,7 @@ const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspor
model_buf = nullptr;
}
#ifdef USING_SERVING
lite::PackWeightManager::GetInstance()->InitWeightManagerByPath(file, model_buf, nullptr);
lite::PackWeightManager::GetInstance()->InitWeightManagerByPath(file, model_buf);
#endif
return lite_buf;
}
@ -1708,7 +1704,7 @@ const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspor
model_buf = nullptr;
}
#ifdef USING_SERVING
lite::PackWeightManager::GetInstance()->InitWeightManagerByPath(file, model_buf, nullptr);
lite::PackWeightManager::GetInstance()->InitWeightManagerByPath(file, model_buf);
#endif
return lite_buf;
}
@ -1722,9 +1718,6 @@ int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore
MS_LOG(ERROR) << "Invalid model_buf";
return RET_ERROR;
}
#ifdef USING_SERVING
lite::PackWeightManager::GetInstance()->InitWeightManagerByBuf(model_buf, this);
#endif
auto *model = lite::ImportFromBuffer(lite_buf, lite_buf_size, true);
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
@ -1755,9 +1748,6 @@ int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore
MS_LOG(ERROR) << "Invalid model_buf";
return RET_ERROR;
}
#ifdef USING_SERVING
lite::PackWeightManager::GetInstance()->InitWeightManagerByBuf(model_buf, this);
#endif
auto *model = lite::ImportFromBuffer(lite_buf, lite_buf_size, true);
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";

@ -21,21 +21,7 @@ PackWeightManager *PackWeightManager::GetInstance() {
return &instance;
}
void PackWeightManager::InitWeightManagerByBuf(const char *model_buf, const LiteSession *lite_session) {
MS_CHECK_TRUE_RET_VOID(model_buf != nullptr);
MS_CHECK_TRUE_RET_VOID(lite_session != nullptr);
if (buf_model_weight_.find(model_buf) == buf_model_weight_.end()) {
auto *model_const_weight = new (std::nothrow) ModelConstWeight();
if (model_const_weight == nullptr) {
return;
}
buf_model_weight_[model_buf] = model_const_weight;
}
buf_model_weight_[model_buf]->lite_sessions.push_back(lite_session);
}
void PackWeightManager::InitWeightManagerByPath(const std::string &model_path, const char *model_buf,
const LiteSession *session) {
void PackWeightManager::InitWeightManagerByPath(const std::string &model_path, const char *model_buf) {
MS_CHECK_TRUE_RET_VOID(model_buf != nullptr);
if (path_model_buf_.find(model_path) == path_model_buf_.end()) {
auto *model_const_weight = new (std::nothrow) ModelConstWeight();
@ -44,7 +30,6 @@ void PackWeightManager::InitWeightManagerByPath(const std::string &model_path, c
}
path_model_weight_[model_path] = model_const_weight;
}
path_model_weight_[model_path]->lite_sessions.push_back(session);
path_model_buf_[model_path].push_back(model_buf);
}
@ -59,94 +44,51 @@ STATUS PackWeightManager::StoreLiteModel(const char *model_buf, const Model *mod
return RET_OK;
}
}
if (buf_model_weight_.find(model_buf) == buf_model_weight_.end()) {
MS_LOG(ERROR) << "Set model failed.";
return RET_ERROR;
}
buf_model_weight_[model_buf]->lite_models.push_back(model);
return RET_OK;
}
void PackWeightManager::StoreOriginTensor(const LiteModel *model, const SchemaTensorWrapper *origin_tensor,
size_t tensor_index) {
MS_CHECK_TRUE_RET_VOID(model != nullptr);
MS_CHECK_TRUE_RET_VOID(origin_tensor != nullptr);
for (auto &item : buf_model_weight_) {
auto &model_buf = item.first;
auto &model_weight = item.second;
for (auto &lite_model : model_weight->lite_models) {
if (model == lite_model) {
if (model_weight->origin_weight.find(tensor_index) == model_weight->origin_weight.end()) {
buf_model_weight_[model_buf]->origin_weight[tensor_index] = origin_tensor->data();
}
}
}
}
void *PackWeightManager::GetTensorData(const LiteModel *model, const SchemaTensorWrapper *origin_tensor,
size_t tensor_index) {
MS_CHECK_TRUE_RET(model != nullptr, nullptr);
for (auto &item : path_model_weight_) {
auto &path = item.first;
auto &model_weight = item.second;
for (auto &lite_model : model_weight->lite_models) {
if (model == lite_model) {
if (model_weight->origin_weight.find(tensor_index) == model_weight->origin_weight.end()) {
path_model_weight_[path]->origin_weight[tensor_index] = origin_tensor->data();
}
}
}
}
}
void *PackWeightManager::GetTensorData(const LiteModel *model, size_t tensor_index) {
MS_CHECK_TRUE_RET(model != nullptr, nullptr);
for (auto &item : buf_model_weight_) {
auto &model_weight = item.second;
auto &models = model_weight->lite_models;
if (find(models.begin(), models.end(), model) != models.end()) {
if (model_weight->packed_weight.find(tensor_index) != model_weight->packed_weight.end()) {
return model_weight->packed_weight[tensor_index];
}
}
}
for (auto &item : path_model_weight_) {
auto &model_weight = item.second;
auto &models = model_weight->lite_models;
if (find(models.begin(), models.end(), model) != models.end()) {
if (model_weight->packed_weight.find(tensor_index) != model_weight->packed_weight.end()) {
return model_weight->packed_weight[tensor_index];
}
path_model_weight_[path]->origin_weight[tensor_index] = origin_tensor->data();
path_model_weight_[path]->origin_data_index[origin_tensor->data()] = tensor_index;
return nullptr;
}
}
MS_LOG(DEBUG) << "tensor data not packed.";
return nullptr;
}
std::pair<PackStatus, void *> PackWeightManager::FindPackedTensor(PackedWeight *packed_weights,
const OriginWeight &origin_weithts,
const Tensor *tensor, const size_t size) {
std::pair<PackStatus, void *> PackWeightManager::FindPackedTensor(ModelConstWeight *weight, const Tensor *tensor,
const size_t size) {
std::unique_lock<std::mutex> weight_lock(mtx_weight_);
MS_CHECK_TRUE_RET(packed_weights != nullptr, std::make_pair(MALLOC, nullptr));
MS_CHECK_TRUE_RET(tensor != nullptr, std::make_pair(MALLOC, nullptr));
auto &packed_weights = weight->packed_weight;
if (size > MAX_MALLOC_SIZE) {
MS_LOG(ERROR) << "malloc size more than MAX_MALLOC_SIZE";
return std::make_pair(MALLOC, nullptr);
}
for (auto &packed_weight : *packed_weights) {
auto &packed_tensor = packed_weight.second;
if (packed_tensor == tensor->data()) {
return std::make_pair(PACKED, packed_tensor);
}
}
for (auto &origin_weight : origin_weithts) {
auto &origin_tensor = origin_weight.second;
auto &origin_index = origin_weight.first;
if (origin_tensor == tensor->data()) {
void *data = malloc(size);
if (data == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return std::make_pair(MALLOC, nullptr);
}
memset(data, 0, size);
packed_weights->insert(std::make_pair(origin_index, data));
return std::make_pair(NOTPACK, packed_weights->at(origin_index));
if (weight->packed_data.find(tensor->data()) != weight->packed_data.end()) {
return std::make_pair(PACKED, tensor->data());
} else if (weight->origin_data_index.find(tensor->data()) != weight->origin_data_index.end()) {
auto origin_index = weight->origin_data_index[tensor->data()];
void *data = malloc(size);
if (data == nullptr) {
MS_LOG(ERROR) << "malloc failed.";
return std::make_pair(MALLOC, nullptr);
}
weight->packed_data.insert(data);
packed_weights.insert(std::make_pair(origin_index, data));
return std::make_pair(NOTPACK, packed_weights.at(origin_index));
}
return std::make_pair(MALLOC, nullptr);
}
@ -154,25 +96,14 @@ std::pair<PackStatus, void *> PackWeightManager::FindPackedTensor(PackedWeight *
std::pair<PackStatus, void *> PackWeightManager::GetPackedTensor(const Tensor *tensor, const size_t size) {
MS_CHECK_TRUE_RET(tensor != nullptr, std::make_pair(MALLOC, nullptr));
std::pair<PackStatus, void *> packed_tensor_pair;
for (auto &item : buf_model_weight_) {
auto &model_weight = item.second;
auto &origin_weithts = model_weight->origin_weight;
auto &packed_weights = model_weight->packed_weight;
packed_tensor_pair = FindPackedTensor(&packed_weights, origin_weithts, tensor, size);
if (packed_tensor_pair.second != nullptr) {
return packed_tensor_pair;
}
}
for (auto &item : path_model_weight_) {
auto &model_weight = item.second;
auto &origin_weithts = model_weight->origin_weight;
auto &packed_weights = model_weight->packed_weight;
packed_tensor_pair = FindPackedTensor(&packed_weights, origin_weithts, tensor, size);
packed_tensor_pair = FindPackedTensor(model_weight, tensor, size);
if (packed_tensor_pair.second != nullptr) {
return packed_tensor_pair;
}
}
MS_LOG(DEBUG) << "not const tensor, need pack in kernel.";
return std::make_pair(MALLOC, nullptr);
}
@ -186,32 +117,6 @@ void PackWeightManager::DeleteSavedModelPtr(LiteModel *delete_model) {
weight->lite_models.erase(it);
}
}
for (auto &item : buf_model_weight_) {
auto &weight = item.second;
auto it = find(weight->lite_models.begin(), weight->lite_models.end(), delete_model);
if (it != weight->lite_models.end()) {
weight->lite_models.erase(it);
}
}
}
void PackWeightManager::DeleteSavedSessionPtr(LiteSession *delete_session) {
std::unique_lock<std::mutex> weight_lock(mtx_weight_);
MS_CHECK_TRUE_RET_VOID(delete_session != nullptr);
for (auto &item : path_model_weight_) {
auto &weight = item.second;
auto it = find(weight->lite_sessions.begin(), weight->lite_sessions.end(), delete_session);
if (it != weight->lite_sessions.end()) {
weight->lite_sessions.erase(it);
}
}
for (auto &item : buf_model_weight_) {
auto &weight = item.second;
auto it = find(weight->lite_sessions.begin(), weight->lite_sessions.end(), delete_session);
if (it != weight->lite_sessions.end()) {
weight->lite_sessions.erase(it);
}
}
}
void PackWeightManager::FreePackedWeight(ModelConstWeight *weight) {
@ -228,23 +133,11 @@ void PackWeightManager::FreePackedWeight(ModelConstWeight *weight) {
}
}
void PackWeightManager::FreeBufModelWeight() {
for (auto &item : buf_model_weight_) {
FreePackedWeight(item.second);
buf_model_weight_.erase(item.first);
}
}
void PackWeightManager::FreePathModelWeight() {
PackWeightManager::~PackWeightManager() {
for (auto &item : path_model_weight_) {
FreePackedWeight(item.second);
path_model_weight_.erase(item.first);
}
}
PackWeightManager::~PackWeightManager() {
FreePathModelWeight();
FreeBufModelWeight();
}
} // namespace mindspore::lite
#endif

@ -22,6 +22,7 @@
#include <algorithm>
#include <utility>
#include <vector>
#include <set>
#include <mutex>
#include "src/tensor.h"
#include "src/lite_session.h"
@ -33,8 +34,10 @@ struct ModelConstWeight {
PackedWeight packed_weight;
OriginWeight origin_weight;
std::vector<const Model *> lite_models;
std::vector<const LiteSession *> lite_sessions;
std::map<const void *, size_t> origin_data_index;
std::set<const void *> packed_data;
};
enum PackStatus : int8_t { NOTPACK = 1, PACKED = 2, MALLOC = 3 };
class PackWeightManager {
@ -42,29 +45,19 @@ class PackWeightManager {
static PackWeightManager *GetInstance();
virtual ~PackWeightManager();
void InitWeightManagerByPath(const std::string &model_path, const char *model_buf);
void DeleteSavedModelPtr(LiteModel *delete_model);
void DeleteSavedSessionPtr(LiteSession *delete_session);
void FreePathModelWeight();
void FreeBufModelWeight();
void InitWeightManagerByBuf(const char *model_buf, const LiteSession *lite_session);
void InitWeightManagerByPath(const std::string &model_path, const char *model_buf,
const LiteSession *session = nullptr);
STATUS StoreLiteModel(const char *model_buf, const Model *model);
void StoreOriginTensor(const LiteModel *model, const SchemaTensorWrapper *origin_tensor, size_t tensor_index);
void *GetTensorData(const LiteModel *model, size_t tensor_index);
void *GetTensorData(const LiteModel *model, const SchemaTensorWrapper *origin_tensor, size_t tensor_index);
std::pair<PackStatus, void *> GetPackedTensor(const Tensor *tensor, const size_t size);
void FreePackedWeight(ModelConstWeight *weight);
private:
PackWeightManager() = default;
std::pair<PackStatus, void *> FindPackedTensor(PackedWeight *packed_weights, const OriginWeight &origin_weithts,
const Tensor *tensor, const size_t size);
std::map<const char *, ModelConstWeight *> buf_model_weight_;
std::map<const std::string, std::vector<const void *>> path_model_buf_;
// path: model_buf
std::pair<PackStatus, void *> FindPackedTensor(ModelConstWeight *weight, const Tensor *tensor, const size_t size);
void FreePackedWeight(ModelConstWeight *weight);
std::map<const std::string, ModelConstWeight *> path_model_weight_;
std::map<const std::string, std::vector<const void *>> path_model_buf_;
std::mutex mtx_weight_;
};
} // namespace mindspore::lite
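
A hedged sketch of how a kernel-side packing site might consume GetPackedTensor(); the helper below is an assumption for illustration only, and just the PackStatus contract comes from the header above.

#include <cstdlib>
#include "src/pack_weight_manager.h"

namespace mindspore::lite {
void *AcquirePackedBuffer(const Tensor *weight, size_t pack_size) {
  auto packed = PackWeightManager::GetInstance()->GetPackedTensor(weight, pack_size);
  if (packed.first == PACKED) {
    // another model sharing the same origin weight already packed it; reuse the shared buffer
    return packed.second;
  }
  if (packed.first == NOTPACK && packed.second != nullptr) {
    // first user: the manager allocated a zero-filled buffer of pack_size bytes,
    // so the kernel's packing routine fills it here and later models reuse it
    return packed.second;
  }
  // MALLOC: not a managed const weight (or allocation failed), so the kernel packs privately
  return malloc(pack_size);
}
}  // namespace mindspore::lite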