forked from mindspore-Ecosystem/mindspore

add unique shared lock && check

parent: de757f741c
commit: 1137dbecef
--- a/mindspore/lite/src/runtime/cxx_api/model_pool/model_parallel_runner.cc
+++ b/mindspore/lite/src/runtime/cxx_api/model_pool/model_parallel_runner.cc
@@ -34,15 +34,43 @@ extern void mindspore_log_init();
 RunnerConfig::RunnerConfig() : data_(std::make_shared<Data>()) {}
 
-void RunnerConfig::SetWorkersNum(int32_t workers_num) { data_->workers_num = workers_num; }
+void RunnerConfig::SetWorkersNum(int32_t workers_num) {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Runner config data is nullptr.";
+    return;
+  }
+  data_->workers_num = workers_num;
+}
 
-void RunnerConfig::SetContext(const std::shared_ptr<Context> &context) { data_->context = context; }
+void RunnerConfig::SetContext(const std::shared_ptr<Context> &context) {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Runner config data is nullptr.";
+    return;
+  }
+  data_->context = context;
+}
 
-int32_t RunnerConfig::GetWorkersNum() const { return data_->workers_num; }
+int32_t RunnerConfig::GetWorkersNum() const {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Runner config data is nullptr.";
+    return -1;
+  }
+  return data_->workers_num;
+}
 
-std::shared_ptr<Context> RunnerConfig::GetContext() const { return data_->context; }
+std::shared_ptr<Context> RunnerConfig::GetContext() const {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Runner config data is nullptr.";
+    return nullptr;
+  }
+  return data_->context;
+}
 
 void RunnerConfig::SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Runner config data is nullptr.";
+    return;
+  }
   if (data_->config_info.size() > kMaxSectionNum) {
     return;
   }
@@ -54,6 +82,11 @@ void RunnerConfig::SetConfigInfo(const std::string &section, const std::map<std::string, std::string> &config) {
 }
 
 std::map<std::string, std::map<std::string, std::string>> RunnerConfig::GetConfigInfo() const {
+  if (data_ == nullptr) {
+    MS_LOG(ERROR) << "Runner config data is nullptr.";
+    std::map<std::string, std::map<std::string, std::string>> empty;
+    return empty;
+  }
   return data_->config_info;
 }
 
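Every RunnerConfig accessor above now refuses to touch data_ when it is null, returning a harmless default (-1, nullptr, or an empty map) instead of dereferencing a dead pointer. A minimal sketch of the same guard style, using a hypothetical Config/Data pair rather than the real MindSpore classes:

#include <cstdint>
#include <iostream>
#include <memory>

// Hypothetical stand-ins for RunnerConfig and its pimpl Data struct.
class Config {
  struct Data {
    int32_t workers_num = 0;
  };

 public:
  Config() : data_(std::make_shared<Data>()) {}

  void SetWorkersNum(int32_t workers_num) {
    if (data_ == nullptr) {  // guard before every dereference
      std::cerr << "config data is nullptr.\n";
      return;
    }
    data_->workers_num = workers_num;
  }

  int32_t GetWorkersNum() const {
    if (data_ == nullptr) {
      std::cerr << "config data is nullptr.\n";
      return -1;  // sentinel value, mirroring the commit's choice
    }
    return data_->workers_num;
  }

 private:
  std::shared_ptr<Data> data_;
};

Since std::make_shared throws rather than returning null on allocation failure, the guard mainly protects against a moved-from or externally cleared data_, but it is cheap and keeps every entry point total.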
@@ -61,38 +94,61 @@ Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
 #ifdef USE_GLOG
   mindspore::mindspore_log_init();
 #endif
-#ifdef CAPTURE_SIGNALS
-  CaptureSignal();
-#endif
+  if (model_pool_ != nullptr && model_pool_->IsInitialized()) {
+    MS_LOG(WARNING) << "ModelParallelRunner is already initialized, no need to initialize it again";
+    return kSuccess;
+  }
+  auto new_model_pool = std::make_shared<ModelPool>();
+  if (new_model_pool == nullptr) {
+    MS_LOG(ERROR) << "new model pool failed, model pool is nullptr.";
+    return kLiteNullptr;
+  }
   if (!PlatformInstructionSetSupportCheck()) {
     return kLiteNotSupport;
   }
-  model_pool_ = std::make_shared<ModelPool>();
-  if (model_pool_ == nullptr) {
-    MS_LOG(ERROR) << "model pool is nullptr.";
-    return kLiteNullptr;
-  }
-  auto status = model_pool_->Init(model_path, runner_config);
+  auto status = new_model_pool->Init(model_path, runner_config);
   if (status != kSuccess) {
-    MS_LOG(ERROR) << "model runner init failed.";
+    MS_LOG(ERROR) << "ModelParallelRunner init failed.";
     return kLiteError;
   }
+  if (model_pool_ != nullptr && model_pool_->IsInitialized()) {
+    MS_LOG(WARNING) << "ModelParallelRunner is already initialized, no need to initialize it again";
+    return kSuccess;
+  }
+  model_pool_ = new_model_pool;
+#ifdef CAPTURE_SIGNALS
+  CaptureSignal();
+#endif
   return status;
 }
 
-std::vector<MSTensor> ModelParallelRunner::GetInputs() { return model_pool_->GetInputs(); }
+std::vector<MSTensor> ModelParallelRunner::GetInputs() {
+  if (model_pool_ == nullptr) {
+    std::vector<MSTensor> empty;
+    MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetInputs API.";
+    return empty;
+  }
+  return model_pool_->GetInputs();
+}
 
-std::vector<MSTensor> ModelParallelRunner::GetOutputs() { return model_pool_->GetOutputs(); }
+std::vector<MSTensor> ModelParallelRunner::GetOutputs() {
+  if (model_pool_ == nullptr) {
+    std::vector<MSTensor> empty;
+    MS_LOG(ERROR) << "Please initialize ModelParallelRunner before calling GetOutputs API.";
+    return empty;
+  }
+  return model_pool_->GetOutputs();
+}
 
 Status ModelParallelRunner::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                                     const MSKernelCallBack &before, const MSKernelCallBack &after) {
-  if (outputs == nullptr) {
-    MS_LOG(ERROR) << "predict output is nullptr.";
+  if (outputs == nullptr || model_pool_ == nullptr) {
+    MS_LOG(ERROR) << "predict output is nullptr or ModelParallelRunner is not initialized.";
     return kLiteNullptr;
   }
   auto status = model_pool_->Predict(inputs, outputs, before, after);
   if (status != kSuccess) {
-    MS_LOG(ERROR) << "model runner predict failed.";
+    MS_LOG(ERROR) << "ModelParallelRunner predict failed.";
     return status;
   }
   return kSuccess;
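The reshaped Init above follows a create, initialize, re-check, publish order: all the heavy work happens on a local new_model_pool, and model_pool_ is only assigned after Init succeeds, so a failed initialization can never leave a half-built pool behind the member pointer. A compressed sketch of that ordering, with hypothetical Pool/Runner types standing in for ModelPool/ModelParallelRunner:

#include <memory>

// Hypothetical stand-ins; not the real MindSpore types.
struct Pool {
  bool Init() { initialized_ = true; return true; }
  bool IsInitialized() const { return initialized_; }
  bool initialized_ = false;
};

class Runner {
 public:
  bool Init() {
    // Fast path: a previous call already finished initializing.
    if (pool_ != nullptr && pool_->IsInitialized()) {
      return true;
    }
    // Build and initialize into a local; pool_ stays untouched on failure.
    auto new_pool = std::make_shared<Pool>();
    if (!new_pool->Init()) {
      return false;
    }
    // Re-check before publishing, mirroring the commit's second check.
    if (pool_ != nullptr && pool_->IsInitialized()) {
      return true;  // keep the pool that won the race
    }
    pool_ = new_pool;  // publish only a fully initialized pool
    return true;
  }

 private:
  std::shared_ptr<Pool> pool_;
};

The second IsInitialized() check narrows the window in which two concurrent callers both pass the first check; actual exclusion comes from the std::unique_lock that ModelPool::Init itself takes, shown in the next file.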
--- a/mindspore/lite/src/runtime/cxx_api/model_pool/model_pool.cc
+++ b/mindspore/lite/src/runtime/cxx_api/model_pool/model_pool.cc
@@ -701,6 +701,7 @@ ModelPoolConfig ModelPool::CreateModelPoolConfig(const std::shared_ptr<RunnerConfig> &runner_config) {
 }
 
 std::vector<MSTensor> ModelPool::GetInputs() {
+  std::shared_lock<std::shared_mutex> l(model_pool_mutex_);
   std::vector<MSTensor> inputs;
   if (inputs_info_.empty()) {
     MS_LOG(ERROR) << "model input is empty.";
@@ -723,6 +724,7 @@ std::vector<MSTensor> ModelPool::GetInputs() {
 }
 
 std::vector<MSTensor> ModelPool::GetOutputs() {
+  std::shared_lock<std::shared_mutex> l(model_pool_mutex_);
   std::vector<MSTensor> outputs;
   if (outputs_info_.empty()) {
     MS_LOG(ERROR) << "model output is empty.";
@@ -872,6 +874,7 @@ Status ModelPool::CanUseAllPhysicalResources(int *percentage) {
 }
 
 Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
+  std::unique_lock<std::shared_mutex> l(model_pool_mutex_);
   int percentage;
   auto status = CanUseAllPhysicalResources(&percentage);
   if (status != kSuccess) {
@@ -934,6 +937,7 @@ Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
   for (size_t i = 0; i < kNumMaxTaskQueueSize; i++) {
     free_tasks_id_.push(i);
   }
+  is_initialized_ = true;
   return kSuccess;
 }
 
@@ -1214,6 +1218,7 @@ void ModelPool::UpdateFreeTaskId(size_t id) {
 
 Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                           const MSKernelCallBack &before, const MSKernelCallBack &after) {
+  std::shared_lock<std::shared_mutex> l(model_pool_mutex_);
   predict_task_mutex_.lock();
   int max_wait_worker_node_id = 0;
   int max_wait_worker_num = 0;
@@ -1262,6 +1267,8 @@ Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
 }
 
 ModelPool::~ModelPool() {
+  std::unique_lock<std::shared_mutex> l(model_pool_mutex_);
+  is_initialized_ = false;
   if (predict_task_queue_ != nullptr) {
     predict_task_queue_->SetPredictTaskDone();
   }
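Taken together, the model_pool.cc hunks implement one reader-writer discipline on model_pool_mutex_: the paths that mutate pool state (Init and the destructor) hold a std::unique_lock, while the hot read paths (GetInputs, GetOutputs, Predict) hold a std::shared_lock, so predictions can run concurrently with each other but never overlap initialization or teardown. A self-contained C++17 sketch of the same discipline (illustrative names, not the real API):

#include <mutex>
#include <shared_mutex>
#include <vector>

class PoolLike {
 public:
  // Writer: exclusive access while initializing shared state.
  void Init() {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    data_ = {1, 2, 3};
    initialized_ = true;
  }

  // Writer: teardown also excludes all readers.
  ~PoolLike() {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    initialized_ = false;
    data_.clear();
  }

  // Reader: many threads may hold the shared lock at once.
  std::vector<int> Snapshot() const {
    std::shared_lock<std::shared_mutex> lock(mutex_);
    return initialized_ ? data_ : std::vector<int>{};
  }

 private:
  mutable std::shared_mutex mutex_;
  std::vector<int> data_;
  bool initialized_ = false;
};

As in the commit, the destructor flips the initialized flag under the exclusive lock, so any reader that acquires the shared lock afterwards already observes a pool marked dead.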
--- a/mindspore/lite/src/runtime/cxx_api/model_pool/model_pool.h
+++ b/mindspore/lite/src/runtime/cxx_api/model_pool/model_pool.h
@@ -23,6 +23,7 @@
 #include <queue>
 #include <map>
 #include <set>
+#include <shared_mutex>
 #include "src/runtime/dynamic_mem_allocator.h"
 #include "include/api/status.h"
 #include "include/api/context.h"
@@ -57,6 +58,8 @@ class ModelPool {
   Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                  const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
 
+  bool IsInitialized() { return is_initialized_; }
+
  private:
   ModelPoolConfig CreateModelPoolConfig(const std::shared_ptr<RunnerConfig> &runner_config);
   std::shared_ptr<Context> GetInitContext(const std::shared_ptr<RunnerConfig> &runner_config);
@@ -164,6 +167,9 @@ class ModelPool {
 
   std::vector<char *> model_bufs_;
   char *graph_buf_ = nullptr;
+
+  std::shared_mutex model_pool_mutex_;
+  bool is_initialized_ = false;
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_MODEL_POOL_MODEL_POOL_H_
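End to end, the commit makes initialization idempotent and every public entry point safe to call early. A hypothetical caller-side sketch; the header path and the "model.ms" file are assumptions for illustration, not part of this diff:

#include <memory>
#include <vector>
#include "include/api/model_parallel_runner.h"

int main() {
  auto config = std::make_shared<mindspore::RunnerConfig>();
  config->SetWorkersNum(2);  // null-guarded setter from this commit

  mindspore::ModelParallelRunner runner;
  // A second Init on an already-initialized runner now just warns and
  // returns kSuccess instead of rebuilding the pool.
  if (runner.Init("model.ms", config) != mindspore::kSuccess) {
    return 1;
  }
  auto inputs = runner.GetInputs();  // empty vector if the runner were uninitialized
  std::vector<mindspore::MSTensor> outputs;
  return runner.Predict(inputs, &outputs) == mindspore::kSuccess ? 0 : 1;
}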