forked from mindspore-Ecosystem/mindspore
!29297 [MS][LITE] model pool: GetInput && input queue
Merge pull request !29297 from yefeng/201-fix_model_pool_input_queue
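In brief, as read from the diff below: the previous synchronous Status ModelPool::Run(), which Predict launched through std::async on every call and which checked a model out of model_pool_queue_, is replaced by a persistent per-model worker loop, void ModelPool::Run(std::shared_ptr<ModelThread>). Init now spawns one std::thread per ModelThread into the new model_thread_vec_; Predict enqueues a request on model_data_queue_, signals cv_in_data_, and blocks on cv_out_data_ until the worker publishes copies of the output tensors; a new destructor raises model_pool_task_done_ and joins the workers.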
commit 5a8d837ff8
@@ -133,35 +133,40 @@ ModelPoolContex ModelPool::CreateModelContext(const std::string &config_path) {
   return model_pool_context;
 }
 
-Status ModelPool::Run() {
-  std::unique_lock<std::mutex> model_lock(mtx_model_queue_);
-  while (model_pool_queue_.empty()) {
-    cv_model_.wait(model_lock);
-  }
-  auto model = model_pool_queue_.front();
-  model_pool_queue_.pop();
-  model_lock.unlock();
-  std::unique_lock<std::mutex> data_lock(mtx_data_queue_);
-  if (model_data_queue_.empty()) {
-    MS_LOG(ERROR) << "model data queue is empty";
-    return kLiteError;
-  }
-  auto model_data = model_data_queue_.front();
-  model_data_queue_.pop();
-  auto inputs = model_data->inputs;
-  auto outputs = model_data->outputs;
-  auto before = model_data->before;
-  auto after = model_data->after;
-  auto status = model->Predict(*inputs, outputs, before, after);
-  if (status != kSuccess) {
-    MS_LOG(ERROR) << "model predict failed.";
-    return status;
-  }
-  mtx_model_queue_.lock();
-  model_pool_queue_.push(model);
-  cv_model_.notify_one();
-  mtx_model_queue_.unlock();
-  return kSuccess;
+void ModelPool::Run(std::shared_ptr<ModelThread> model) {
+  while (!model_pool_task_done_) {
+    std::unique_lock<std::mutex> data_lock(mtx_model_queue_);
+    while (model_data_queue_.empty() && !model_pool_task_done_) {
+      cv_in_data_.wait(data_lock);
+    }
+    if (model_pool_task_done_) {
+      cv_in_data_.notify_all();
+      break;
+    }
+    auto &model_data = model_data_queue_.front();
+    model_data_queue_.pop();
+    auto inputs = model_data->inputs;
+    auto *outputs = model_data->outputs;
+    auto before = model_data->before;
+    auto after = model_data->after;
+    cv_in_data_.notify_one();
+    data_lock.unlock();
+    auto status = model->Predict(*inputs, outputs, before, after);
+    if (status != kSuccess) {
+      MS_LOG(ERROR) << "model predict failed.";
+      return;
+    }
+    auto output_size = outputs->size();
+    for (size_t i = 0; i < output_size; i++) {
+      auto copy_tensor =
+        mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
+                                          outputs->at(i).MutableData(), outputs->at(i).DataSize());
+      outputs->erase(outputs->begin());
+      outputs->push_back(*copy_tensor);
+    }
+    cv_in_data_.notify_one();
+    cv_out_data_.notify_all();
+  }
 }
 
 Status ModelPool::Init(const std::string &model_path, const std::string &config_path, const Key &dec_key,
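The new Run above is a standard condition-variable worker loop: wait until the queue is non-empty or shutdown is flagged, pop under the lock, release the lock, then do the work. A minimal standalone sketch of that shape, with illustrative names only (none of this is MindSpore API):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

int main() {
  std::mutex mtx;
  std::condition_variable cv_in;
  std::queue<int> tasks;
  bool done = false;

  // Consumer: the same shape as the new ModelPool::Run above.
  std::thread worker([&] {
    while (true) {
      std::unique_lock<std::mutex> lock(mtx);
      cv_in.wait(lock, [&] { return !tasks.empty() || done; });
      if (done) break;   // shutdown requested, cf. model_pool_task_done_;
                         // queued tasks are dropped, as in the original
      int task = tasks.front();
      tasks.pop();
      lock.unlock();     // run the task with the queue unlocked
      std::cout << "processed task " << task << std::endl;
    }
  });

  // Producer: the same shape as the queue-push side of Predict.
  for (int i = 0; i < 3; i++) {
    {
      std::lock_guard<std::mutex> lock(mtx);
      tasks.push(i);
    }
    cv_in.notify_one();
  }

  // Shutdown: the same shape as the new ~ModelPool.
  {
    std::lock_guard<std::mutex> lock(mtx);
    done = true;
  }
  cv_in.notify_all();
  worker.join();
  return 0;
}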
@@ -170,7 +175,7 @@ Status ModelPool::Init(const std::string &model_path, const std::string &config_
   for (size_t i = 0; i < num_models_; i++) {
     auto model = std::make_shared<ModelThread>();
     auto status = model->Init(model_path, model_pool_context[i], dec_key, dec_mode);
-    model_pool_queue_.push(model);
+    model_thread_vec_.push_back(std::thread(&ModelPool::Run, this, model));
   }
   return kSuccess;
 }
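This one-line change in Init is what turns the pool from pull-a-model-per-request into resident workers: each ModelThread gets its own std::thread bound to ModelPool::Run. For reference, a self-contained sketch of that pointer-to-member std::thread construction (hypothetical Pool/Run names, not the MindSpore classes):

#include <iostream>
#include <memory>
#include <thread>
#include <vector>

struct Pool {
  // Stand-in for ModelPool::Run(std::shared_ptr<ModelThread>).
  void Run(std::shared_ptr<int> item) { std::cout << "worker got " << *item << std::endl; }
};

int main() {
  Pool pool;
  std::vector<std::thread> threads;
  for (int i = 0; i < 3; i++) {
    // std::thread copies the object pointer and the shared_ptr into the new thread,
    // mirroring model_thread_vec_.push_back(std::thread(&ModelPool::Run, this, model)).
    threads.push_back(std::thread(&Pool::Run, &pool, std::make_shared<int>(i)));
  }
  for (auto &t : threads) {
    t.join();
  }
  return 0;
}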
@@ -185,14 +190,26 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
     model_data->before = before;
     model_data->after = after;
     model_data_queue_.push(model_data);
+    cv_in_data_.notify_one();
   }
-  auto future_status = std::async(std::launch::async, &ModelPool::Run, this);
-  auto status = future_status.get();
-  if (status != kSuccess) {
-    MS_LOG(ERROR) << "model run failed in model pool predict.";
-    return status;
+  {
+    std::unique_lock<std::mutex> result_loack(mtx_data_queue_);
+    while (outputs->empty()) {
+      cv_out_data_.wait(result_loack);
+    }
   }
   return kSuccess;
 }
+
+ModelPool::~ModelPool() {
+  model_pool_task_done_ = true;
+  cv_in_data_.notify_all();
+  for (auto &th : model_thread_vec_) {
+    if (th.joinable()) {
+      th.join();
+    }
+  }
+  cv_in_data_.notify_one();
+}
 }  // namespace mindspore
 #endif
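On the caller side, Predict now performs a blocking request/response handshake: push the request, wake a worker via cv_in_data_, then wait on cv_out_data_ until *outputs is non-empty. A minimal standalone sketch of that handshake (illustrative names; the real signalling members are the cv_in_data_/cv_out_data_ fields declared in the header hunks below):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  std::mutex mtx;
  std::condition_variable cv_out;
  std::vector<int> outputs;  // empty until a worker publishes the result

  // Worker side: fill the result and wake the waiting caller,
  // cf. cv_out_data_.notify_all() at the end of ModelPool::Run.
  std::thread worker([&] {
    std::lock_guard<std::mutex> lock(mtx);
    outputs.push_back(42);
    cv_out.notify_all();
  });

  // Caller side: the same wait shape as the new Predict above.
  {
    std::unique_lock<std::mutex> lock(mtx);
    while (outputs.empty()) {
      cv_out.wait(lock);
    }
  }
  std::cout << "got result " << outputs.front() << std::endl;
  worker.join();
  return 0;
}

The two remaining hunks below adjust the ModelPool class declaration (apparently the model pool header) to match the new threading model.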
@@ -29,7 +29,7 @@ namespace mindspore {
 class ModelPool {
  public:
   static ModelPool *GetInstance();
-  virtual ~ModelPool() = default;
+  ~ModelPool();
 
   Status Init(const std::string &model_path, const std::string &config_path, const Key &dec_key = {},
               const std::string &dec_mode = kDecModeAesGcm);
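The destructor can no longer be `= default`: the out-of-line ~ModelPool() added above must raise model_pool_task_done_ and join every thread in model_thread_vec_ before the queues and condition variables are destroyed.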
@@ -41,17 +41,19 @@ class ModelPool {
   ModelPool() = default;
   Status InitContext(const std::shared_ptr<mindspore::Context> &context,
                      std::map<std::string, std::map<std::string, std::string>> *all_config_info);
-  Status Run();
+  void Run(std::shared_ptr<ModelThread> model);
   void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
   ModelPoolContex CreateModelContext(const std::string &config_path);
 
   std::mutex mtx_data_queue_;
   std::mutex mtx_model_queue_;
-  std::condition_variable cv_data_;
+  std::condition_variable cv_out_data_;
+  std::condition_variable cv_in_data_;
   std::condition_variable cv_model_;
 
   std::queue<std::shared_ptr<ModelThread>> model_pool_queue_;
+  std::vector<std::thread> model_thread_vec_;
   std::queue<std::shared_ptr<ModelData>> model_data_queue_;
+  bool model_pool_task_done_ = false;
   size_t num_models_ = 5;
 };
 }  // namespace mindspore