!29297 [MS][LITE] model pool: GetInput && input queue

Merge pull request !29297 from yefeng/201-fix_model_pool_input_queue
i-robot 2022-01-20 02:21:14 +00:00 committed by Gitee
commit 5a8d837ff8
2 changed files with 57 additions and 38 deletions


@@ -133,35 +133,40 @@ ModelPoolContex ModelPool::CreateModelContext(const std::string &config_path) {
   return model_pool_context;
 }
-Status ModelPool::Run() {
-  std::unique_lock<std::mutex> model_lock(mtx_model_queue_);
-  while (model_pool_queue_.empty()) {
-    cv_model_.wait(model_lock);
-  }
-  auto model = model_pool_queue_.front();
-  model_pool_queue_.pop();
-  model_lock.unlock();
-  std::unique_lock<std::mutex> data_lock(mtx_data_queue_);
-  if (model_data_queue_.empty()) {
-    MS_LOG(ERROR) << "model data queue is empty";
-    return kLiteError;
-  }
-  auto model_data = model_data_queue_.front();
-  model_data_queue_.pop();
-  auto inputs = model_data->inputs;
-  auto outputs = model_data->outputs;
-  auto before = model_data->before;
-  auto after = model_data->after;
-  auto status = model->Predict(*inputs, outputs, before, after);
-  if (status != kSuccess) {
-    MS_LOG(ERROR) << "model predict failed.";
-    return status;
-  }
-  mtx_model_queue_.lock();
-  model_pool_queue_.push(model);
-  cv_model_.notify_one();
-  mtx_model_queue_.unlock();
-  return kSuccess;
-}
+void ModelPool::Run(std::shared_ptr<ModelThread> model) {
+  while (!model_pool_task_done_) {
+    std::unique_lock<std::mutex> data_lock(mtx_model_queue_);
+    while (model_data_queue_.empty() && !model_pool_task_done_) {
+      cv_in_data_.wait(data_lock);
+    }
+    if (model_pool_task_done_) {
+      cv_in_data_.notify_all();
+      break;
+    }
+    auto model_data = model_data_queue_.front();
+    model_data_queue_.pop();
+    auto inputs = model_data->inputs;
+    auto *outputs = model_data->outputs;
+    auto before = model_data->before;
+    auto after = model_data->after;
+    cv_in_data_.notify_one();
+    data_lock.unlock();
+    auto status = model->Predict(*inputs, outputs, before, after);
+    if (status != kSuccess) {
+      MS_LOG(ERROR) << "model predict failed.";
+      return;
+    }
+    auto output_size = outputs->size();
+    for (size_t i = 0; i < output_size; i++) {
+      auto copy_tensor =
+        mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
+                                          outputs->at(i).MutableData(), outputs->at(i).DataSize());
+      outputs->erase(outputs->begin());
+      outputs->push_back(*copy_tensor);
+    }
+    cv_in_data_.notify_one();
+    cv_out_data_.notify_all();
+  }
+}
 Status ModelPool::Init(const std::string &model_path, const std::string &config_path, const Key &dec_key,
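
The core of the change is in Run(): instead of a one-shot call that checks a model out of model_pool_queue_, serves a single request, and returns it, each ModelThread now runs a persistent consumer loop that blocks on cv_in_data_ until either a request lands in model_data_queue_ or model_pool_task_done_ signals shutdown. A minimal, self-contained sketch of that wait-loop pattern (names like WorkerPool and Submit are illustrative, not the MindSpore API):

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Each worker blocks until a task arrives or the done flag is set,
// then serves one task per wakeup -- the same shape as the new Run().
class WorkerPool {
 public:
  explicit WorkerPool(size_t n) {
    for (size_t i = 0; i < n; i++) {
      threads_.emplace_back([this] { Run(); });
    }
  }
  ~WorkerPool() {
    {
      std::unique_lock<std::mutex> lock(mtx_);
      done_ = true;
    }
    cv_.notify_all();  // wake every blocked worker so it can exit
    for (auto &th : threads_) {
      if (th.joinable()) th.join();
    }
  }
  void Submit(std::function<void()> task) {
    {
      std::unique_lock<std::mutex> lock(mtx_);
      tasks_.push(std::move(task));
    }
    cv_.notify_one();  // one queued task wakes one worker
  }

 private:
  void Run() {
    while (true) {
      std::unique_lock<std::mutex> lock(mtx_);
      cv_.wait(lock, [this] { return done_ || !tasks_.empty(); });
      if (done_ && tasks_.empty()) return;  // shut down after draining
      auto task = std::move(tasks_.front());
      tasks_.pop();
      lock.unlock();  // run user work outside the lock, as Run() does
      task();
    }
  }
  std::mutex mtx_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  bool done_ = false;
  std::vector<std::thread> threads_;
};

One difference worth noting: the sketch drains remaining tasks before exiting, whereas the committed loop breaks as soon as model_pool_task_done_ is set, so requests still queued at shutdown are dropped.
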
@@ -170,7 +175,7 @@ Status ModelPool::Init(const std::string &model_path, const std::string &config_
   for (size_t i = 0; i < num_models_; i++) {
     auto model = std::make_shared<ModelThread>();
     auto status = model->Init(model_path, model_pool_context[i], dec_key, dec_mode);
-    model_pool_queue_.push(model);
+    model_thread_vec_.push_back(std::thread(&ModelPool::Run, this, model));
   }
   return kSuccess;
 }
@@ -185,14 +190,26 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
   model_data->before = before;
   model_data->after = after;
   model_data_queue_.push(model_data);
-  auto future_status = std::async(std::launch::async, &ModelPool::Run, this);
-  auto status = future_status.get();
-  if (status != kSuccess) {
-    MS_LOG(ERROR) << "model run failed in model pool predict.";
-    return status;
-  }
+  cv_in_data_.notify_one();
+  {
+    std::unique_lock<std::mutex> result_loack(mtx_data_queue_);
+    while (outputs->empty()) {
+      cv_out_data_.wait(result_loack);
+    }
+  }
   return kSuccess;
 }
+ModelPool::~ModelPool() {
+  model_pool_task_done_ = true;
+  cv_in_data_.notify_all();
+  for (auto &th : model_thread_vec_) {
+    if (th.joinable()) {
+      th.join();
+    }
+  }
+  cv_in_data_.notify_one();
+}
 } // namespace mindspore
 #endif
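
Predict() is now the producer half of the hand-off: it enqueues a ModelData, wakes one worker through cv_in_data_, then blocks on cv_out_data_ until a worker has written the results back into *outputs. A compact sketch of that two-condition-variable request/response exchange (all names hypothetical):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// A caller submits one request and blocks until a worker publishes
// the result, mirroring Predict()'s wait on cv_out_data_.
struct Exchange {
  std::mutex mtx;
  std::condition_variable cv_in;   // request available
  std::condition_variable cv_out;  // result available
  bool has_request = false;
  bool has_result = false;
  int request = 0;
  int result = 0;
};

void Worker(Exchange *ex) {
  std::unique_lock<std::mutex> lock(ex->mtx);
  ex->cv_in.wait(lock, [ex] { return ex->has_request; });
  ex->result = ex->request * 2;  // stand-in for model->Predict(...)
  ex->has_result = true;
  ex->cv_out.notify_all();       // wake the blocked caller
}

int main() {
  Exchange ex;
  std::thread worker(Worker, &ex);
  {
    std::unique_lock<std::mutex> lock(ex.mtx);
    ex.request = 21;
    ex.has_request = true;
  }
  ex.cv_in.notify_one();
  {
    std::unique_lock<std::mutex> lock(ex.mtx);
    ex.cv_out.wait(lock, [&ex] { return ex.has_result; });
    std::cout << "result: " << ex.result << "\n";  // prints 42
  }
  worker.join();
  return 0;
}

The commit uses outputs->empty() as its wait predicate the way has_result is used here, which is why the worker loop must fully populate the output vector before notifying cv_out_data_.
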


@@ -29,7 +29,7 @@ namespace mindspore {
 class ModelPool {
  public:
   static ModelPool *GetInstance();
-  virtual ~ModelPool() = default;
+  ~ModelPool();
   Status Init(const std::string &model_path, const std::string &config_path, const Key &dec_key = {},
               const std::string &dec_mode = kDecModeAesGcm);
@@ -41,17 +41,19 @@ class ModelPool {
   ModelPool() = default;
   Status InitContext(const std::shared_ptr<mindspore::Context> &context,
                      std::map<std::string, std::map<std::string, std::string>> *all_config_info);
-  Status Run();
+  void Run(std::shared_ptr<ModelThread> model);
   void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
   ModelPoolContex CreateModelContext(const std::string &config_path);
   std::mutex mtx_data_queue_;
   std::mutex mtx_model_queue_;
-  std::condition_variable cv_data_;
+  std::condition_variable cv_out_data_;
+  std::condition_variable cv_in_data_;
   std::condition_variable cv_model_;
-  std::queue<std::shared_ptr<ModelThread>> model_pool_queue_;
+  std::vector<std::thread> model_thread_vec_;
   std::queue<std::shared_ptr<ModelData>> model_data_queue_;
+  bool model_pool_task_done_ = false;
   size_t num_models_ = 5;
 };
 } // namespace mindspore
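
For context, the header declares a singleton (GetInstance) whose Init() now spawns the num_models_ worker threads up front rather than running models on demand. A hedged caller-side sketch, assuming Predict() keeps the usual MSLite shape (outputs passed by pointer, optional before/after callbacks); the header path and file names are placeholders:

#include <iostream>
#include <vector>
#include "src/cxx_api/model_pool/model_pool.h"  // assumed header path

int main() {
  auto *pool = mindspore::ModelPool::GetInstance();
  // Init() builds num_models_ (default 5) contexts and launches one
  // Run() consumer thread per ModelThread.
  if (pool->Init("model.ms", "pool_config.cfg") != mindspore::kSuccess) {
    std::cerr << "model pool init failed" << std::endl;
    return -1;
  }
  std::vector<mindspore::MSTensor> inputs;  // fill with real input tensors
  std::vector<mindspore::MSTensor> outputs;
  // Each call is queued on model_data_queue_ and served by whichever
  // worker wakes on cv_in_data_ first.
  if (pool->Predict(inputs, &outputs) != mindspore::kSuccess) {
    std::cerr << "model pool predict failed" << std::endl;
    return -1;
  }
  std::cout << "got " << outputs.size() << " output tensors" << std::endl;
  return 0;
}
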