support big model: read tensor data for build by buf
This commit is contained in:
parent
df9b900c04
commit
368aa3a3d0
|
@ -65,6 +65,9 @@ static const char *const kMSCacheModelPath = "cache_model_path";
|
||||||
static const char *const kMSCacheVocabSize = "vocab_size";
|
static const char *const kMSCacheVocabSize = "vocab_size";
|
||||||
static const char *const kMSCacheDeviceSize = "device_cache_size";
|
static const char *const kMSCacheDeviceSize = "device_cache_size";
|
||||||
static const char *const kMSCacheSerializePath = "serialize_path";
|
static const char *const kMSCacheSerializePath = "serialize_path";
|
||||||
|
// weight path
|
||||||
|
static const char *const kWeight = "weight";
|
||||||
|
static const char *const kWeightPath = "weight_path";
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -83,6 +83,11 @@ Status ModelWorker::ResizeInit() {
|
||||||
std::vector<std::vector<int64_t>> new_input_shape;
|
std::vector<std::vector<int64_t>> new_input_shape;
|
||||||
for (size_t input_idx = 0; input_idx < inputs.size(); input_idx++) {
|
for (size_t input_idx = 0; input_idx < inputs.size(); input_idx++) {
|
||||||
new_input_shape.push_back(inputs[input_idx].Shape());
|
new_input_shape.push_back(inputs[input_idx].Shape());
|
||||||
|
for (size_t i = 1; i < new_input_shape.size(); i++) {
|
||||||
|
if (new_input_shape[input_idx][i] == -1) {
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (new_input_shape[input_idx][0] == -1) {
|
if (new_input_shape[input_idx][0] == -1) {
|
||||||
// only support resize for batch dim
|
// only support resize for batch dim
|
||||||
new_input_shape[input_idx][0] = kNumInitBatch;
|
new_input_shape[input_idx][0] = kNumInitBatch;
|
||||||
|
|
|
@ -504,7 +504,8 @@ bool LiteModel::CheckQuantAllInit(
|
||||||
|
|
||||||
Model *ImportFromPath(const char *model_path) { return LiteImportFromPath(model_path); }
|
Model *ImportFromPath(const char *model_path) { return LiteImportFromPath(model_path); }
|
||||||
|
|
||||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, mindspore::ModelType model_type) {
|
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, mindspore::ModelType model_type,
|
||||||
|
const std::string &path) {
|
||||||
auto model_loader = mindspore::infer::ModelLoaderRegistry::GetInstance()->GetModelLoader(model_type);
|
auto model_loader = mindspore::infer::ModelLoaderRegistry::GetInstance()->GetModelLoader(model_type);
|
||||||
if (model_loader != nullptr) {
|
if (model_loader != nullptr) {
|
||||||
MS_LOG(INFO) << "import model from model loader";
|
MS_LOG(INFO) << "import model from model loader";
|
||||||
|
@ -516,7 +517,7 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, minds
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "import model from lite model";
|
MS_LOG(INFO) << "import model from lite model";
|
||||||
auto *model = new (std::nothrow) LiteModel();
|
auto *model = new (std::nothrow) LiteModel(path);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "new model fail!";
|
MS_LOG(ERROR) << "new model fail!";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -322,7 +322,8 @@ class LiteModel : public Model {
|
||||||
};
|
};
|
||||||
|
|
||||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf,
|
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf,
|
||||||
mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite);
|
mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite,
|
||||||
|
const std::string &path = "");
|
||||||
LiteModel *LiteImportFromPath(const char *model_path);
|
LiteModel *LiteImportFromPath(const char *model_path);
|
||||||
Model *ImportFromPath(const char *model_path);
|
Model *ImportFromPath(const char *model_path);
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
|
|
|
@ -1706,6 +1706,20 @@ int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string lite::LiteSession::ParseWeightPath() {
|
||||||
|
std::string weight_path = "";
|
||||||
|
if (config_info_ != nullptr) {
|
||||||
|
auto ms_weight = config_info_->find(kWeight);
|
||||||
|
if (ms_weight != config_info_->end()) {
|
||||||
|
auto ms_weight_iter = ms_weight->second;
|
||||||
|
if (ms_weight_iter.find(kWeightPath) != ms_weight_iter.end()) {
|
||||||
|
weight_path = ms_weight_iter[kWeightPath];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return weight_path;
|
||||||
|
}
|
||||||
|
|
||||||
int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type,
|
int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type,
|
||||||
const size_t &buf_size,
|
const size_t &buf_size,
|
||||||
const std::shared_ptr<mindspore::Context> &ms_context) {
|
const std::shared_ptr<mindspore::Context> &ms_context) {
|
||||||
|
@ -1716,7 +1730,8 @@ int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore
|
||||||
MS_LOG(ERROR) << "Invalid model_buf";
|
MS_LOG(ERROR) << "Invalid model_buf";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto *model = lite::ImportFromBuffer(lite_buf, lite_buf_size, true, model_type);
|
auto weight_path = ParseWeightPath();
|
||||||
|
auto *model = lite::ImportFromBuffer(lite_buf, lite_buf_size, true, model_type, weight_path);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "Import model failed";
|
MS_LOG(ERROR) << "Import model failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -1743,7 +1758,7 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path,
|
||||||
MS_LOG(ERROR) << "Read model file failed";
|
MS_LOG(ERROR) << "Read model file failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto *model = lite::ImportFromBuffer(model_buf, model_size, true, model_type);
|
auto *model = lite::ImportFromBuffer(model_buf, model_size, true, model_type, model_path);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "Import model failed";
|
MS_LOG(ERROR) << "Import model failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -1768,7 +1783,7 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path,
|
||||||
MS_LOG(ERROR) << "Read model file failed";
|
MS_LOG(ERROR) << "Read model file failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto *model = lite::ImportFromBuffer(model_buf, model_size, true, model_type);
|
auto *model = lite::ImportFromBuffer(model_buf, model_size, true, model_type, model_path);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
MS_LOG(ERROR) << "Import model failed";
|
MS_LOG(ERROR) << "Import model failed";
|
||||||
delete[] model_buf;
|
delete[] model_buf;
|
||||||
|
|
|
@ -143,6 +143,7 @@ class LiteSession {
|
||||||
const std::vector<kernel::KernelExec *> &kernels,
|
const std::vector<kernel::KernelExec *> &kernels,
|
||||||
const std::unordered_map<Tensor *, Tensor *> &isolate_input_map = std::unordered_map<Tensor *, Tensor *>());
|
const std::unordered_map<Tensor *, Tensor *> &isolate_input_map = std::unordered_map<Tensor *, Tensor *>());
|
||||||
static void FreePackOpWeight(const std::vector<kernel::KernelExec *> &kernels);
|
static void FreePackOpWeight(const std::vector<kernel::KernelExec *> &kernels);
|
||||||
|
std::string ParseWeightPath();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int PreCheck(Model *model);
|
int PreCheck(Model *model);
|
||||||
|
|
|
@ -953,6 +953,20 @@ void BenchmarkUnifiedApi::ModelParallelRunnerRun(int task_num, int parallel_idx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int BenchmarkUnifiedApi::AddConfigInfo(const std::shared_ptr<RunnerConfig> &runner_config) {
|
||||||
|
auto env = std::getenv("BENCHMARK_WEIGHT_PATH");
|
||||||
|
if (env == nullptr) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
auto weight_path = std::string(env);
|
||||||
|
if (weight_path != "") {
|
||||||
|
std::map<std::string, std::string> config;
|
||||||
|
config[kWeightPath] = weight_path;
|
||||||
|
runner_config->SetConfigInfo(kWeight, config);
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int BenchmarkUnifiedApi::ParallelInference(std::shared_ptr<mindspore::Context> context) {
|
int BenchmarkUnifiedApi::ParallelInference(std::shared_ptr<mindspore::Context> context) {
|
||||||
if (flags_->warm_up_loop_count_ > kMaxRequestNum || flags_->parallel_num_ > kMaxRequestNum) {
|
if (flags_->warm_up_loop_count_ > kMaxRequestNum || flags_->parallel_num_ > kMaxRequestNum) {
|
||||||
MS_LOG(WARNING) << "in parallel predict warm up loop count should less than" << kMaxRequestNum;
|
MS_LOG(WARNING) << "in parallel predict warm up loop count should less than" << kMaxRequestNum;
|
||||||
|
@ -965,6 +979,8 @@ int BenchmarkUnifiedApi::ParallelInference(std::shared_ptr<mindspore::Context> c
|
||||||
auto runner_config = std::make_shared<RunnerConfig>();
|
auto runner_config = std::make_shared<RunnerConfig>();
|
||||||
runner_config->SetContext(context);
|
runner_config->SetContext(context);
|
||||||
runner_config->SetWorkersNum(flags_->workers_num_);
|
runner_config->SetWorkersNum(flags_->workers_num_);
|
||||||
|
auto status = AddConfigInfo(runner_config);
|
||||||
|
MS_CHECK_FALSE_MSG(status != kSuccess, RET_ERROR, "add config info for parallel predict failed.");
|
||||||
auto model_init_start = GetTimeUs();
|
auto model_init_start = GetTimeUs();
|
||||||
auto ret = model_runner_.Init(flags_->model_file_, runner_config);
|
auto ret = model_runner_.Init(flags_->model_file_, runner_config);
|
||||||
MS_CHECK_FALSE_MSG(ret != kSuccess, RET_ERROR, "model pool init failed.");
|
MS_CHECK_FALSE_MSG(ret != kSuccess, RET_ERROR, "model pool init failed.");
|
||||||
|
@ -974,13 +990,13 @@ int BenchmarkUnifiedApi::ParallelInference(std::shared_ptr<mindspore::Context> c
|
||||||
ms_inputs_for_api_ = model_runner_.GetInputs();
|
ms_inputs_for_api_ = model_runner_.GetInputs();
|
||||||
MS_CHECK_FALSE_MSG(ms_inputs_for_api_.empty(), RET_ERROR, "model pool input is empty.");
|
MS_CHECK_FALSE_MSG(ms_inputs_for_api_.empty(), RET_ERROR, "model pool input is empty.");
|
||||||
for (int i = 0; i < flags_->parallel_num_ + flags_->warm_up_loop_count_; i++) {
|
for (int i = 0; i < flags_->parallel_num_ + flags_->warm_up_loop_count_; i++) {
|
||||||
auto status = LoadInput();
|
status = LoadInput();
|
||||||
MS_CHECK_FALSE_MSG(status != RET_OK, status, "Generate input data error");
|
MS_CHECK_FALSE_MSG(status != RET_OK, status, "Generate input data error");
|
||||||
std::vector<MSTensor> output;
|
std::vector<MSTensor> output;
|
||||||
all_outputs_.push_back(output);
|
all_outputs_.push_back(output);
|
||||||
}
|
}
|
||||||
if (!flags_->benchmark_data_file_.empty()) {
|
if (!flags_->benchmark_data_file_.empty()) {
|
||||||
auto status = PrintInputData();
|
status = PrintInputData();
|
||||||
MS_CHECK_FALSE_MSG(status != RET_OK, status, "PrintInputData error ");
|
MS_CHECK_FALSE_MSG(status != RET_OK, status, "PrintInputData error ");
|
||||||
status = ReadCalibData();
|
status = ReadCalibData();
|
||||||
MS_CHECK_FALSE_MSG(status != RET_OK, status, "ReadCalibData error ");
|
MS_CHECK_FALSE_MSG(status != RET_OK, status, "ReadCalibData error ");
|
||||||
|
|
|
@ -99,6 +99,7 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
|
||||||
void ModelParallelRunnerWarmUp(int index);
|
void ModelParallelRunnerWarmUp(int index);
|
||||||
void ModelParallelRunnerRun(int task_num, int parallel_idx);
|
void ModelParallelRunnerRun(int task_num, int parallel_idx);
|
||||||
int ParallelInference(std::shared_ptr<mindspore::Context> context);
|
int ParallelInference(std::shared_ptr<mindspore::Context> context);
|
||||||
|
int AddConfigInfo(const std::shared_ptr<RunnerConfig> &runner_config);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
Loading…
Reference in New Issue