copy from 1.6

yefeng 2022-02-22 15:18:25 +08:00
parent c0d35aa950
commit c301bf16a2
15 changed files with 429 additions and 128 deletions

View File

@@ -0,0 +1,69 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
#define MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include "include/api/status.h"
#include "include/api/context.h"
namespace mindspore {
struct RunnerConfig {
std::shared_ptr<Context> context = nullptr;
};
/// \brief The ModelParallelRunner class is used to define a MindSpore ModelParallelRunner, facilitating Model
/// management.
class MS_API ModelParallelRunner {
public:
ModelParallelRunner() = default;
~ModelParallelRunner() = default;
/// \brief Builds a model parallel runner from a model path so that it can run on a device. Only valid for Lite.
///
/// \param[in] model_path Define the model path.
/// \param[in] runner_config Define the config used to store options during model pool init.
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
///
/// \return Status.
Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
/// \brief Obtains all input tensors information of the model.
///
/// \return The vector that includes all input tensors.
std::vector<MSTensor> GetInputs();
/// \brief Obtains all output tensors information of the model.
///
/// \return The vector that includes all output tensors.
std::vector<MSTensor> GetOutputs();
/// \brief Runs inference through the ModelParallelRunner.
///
/// \param[in] inputs A vector where model inputs are arranged in sequence.
/// \param[out] outputs A pointer to a vector in which the model outputs are filled in sequence.
/// \param[in] before CallBack before predict.
/// \param[in] after CallBack after predict.
///
/// \return Status.
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
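
A minimal usage sketch for the ModelParallelRunner API declared above (illustrative, not part of the commit); the model path "model.ms", the thread count, and the input-filling step are placeholders:

#include <iostream>
#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model_parallel_runner.h"

int main() {
  // Configure a CPU context shared by all workers in the pool.
  auto context = std::make_shared<mindspore::Context>();
  context->SetThreadNum(4);
  context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());

  auto runner_config = std::make_shared<mindspore::RunnerConfig>();
  runner_config->context = context;

  mindspore::ModelParallelRunner runner;
  if (runner.Init("model.ms", runner_config) != mindspore::kSuccess) {  // placeholder model path
    std::cerr << "parallel runner init failed." << std::endl;
    return -1;
  }

  auto inputs = runner.GetInputs();
  // ... fill inputs[i] with application data here ...
  std::vector<mindspore::MSTensor> outputs;
  if (runner.Predict(inputs, &outputs) != mindspore::kSuccess) {
    std::cerr << "parallel predict failed." << std::endl;
    return -1;
  }
  return 0;
}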

View File

@@ -69,7 +69,7 @@ if(MSLITE_ENABLE_SERVER_INFERENCE)
 set(CXX_API_SRCS
 ${CXX_API_SRCS}
 ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/predict_task_queue.cc
-${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_thread.cc
+${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_worker.cc
 ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_pool.cc
 ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_parallel_runner.cc
 )

View File

@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "src/cxx_api/model_pool/model_parallel_runner.h"
+#include "include/api/model_parallel_runner.h"
 #include "src/cxx_api/model_pool/model_pool.h"
 #include "src/common/log.h"

View File

@@ -19,6 +19,9 @@
 #include "src/common/log.h"
 #include "include/lite_types.h"
 #include "src/common/config_file.h"
+#include "src/runtime/inner_allocator.h"
+#include "src/common//file_utils.h"
+#include "src/pack_weight_manager.h"
 namespace mindspore {
 namespace {
 constexpr int32_t kNumThreads = 4;
@@ -36,12 +39,11 @@ int GetCoreNum() {
 } // namespace
 void ModelPool::SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num) {
-int core_num = GetCoreNum();
 if (thread_num == 0) {
 MS_LOG(ERROR) << "thread num is zero.";
 return;
 }
-num_models_ = core_num / thread_num;
+int core_num = GetCoreNum();
 int core_id = 0;
 for (size_t i = 0; i < num_models_; i++) {
 std::vector<int> bind_id;
@@ -68,16 +70,15 @@ std::shared_ptr<Context> ModelPool::InitContext(const std::shared_ptr<RunnerConf
 return nullptr;
 }
 if (runner_config != nullptr) {
-model_context = runner_config->model_ctx;
-num_models_ = runner_config->num_model;
+model_context = runner_config->context;
 auto device_list = model_context->MutableDeviceInfo();
 if (device_list.size() != 1) {
 MS_LOG(ERROR) << "model pool only support device num 1.";
 return nullptr;
 }
 auto device = device_list.front();
-if (device->GetDeviceType() != kCPU) {
-MS_LOG(ERROR) << "model pool only support cpu type.";
+if (device->GetDeviceType() != kCPU && device->GetDeviceType() != kGPU) {
+MS_LOG(ERROR) << "model pool only support cpu or gpu type.";
 return nullptr;
 }
 auto cpu_context = device->Cast<CPUDeviceInfo>();
@@ -86,13 +87,19 @@ std::shared_ptr<Context> ModelPool::InitContext(const std::shared_ptr<RunnerConf
 MS_LOG(ERROR) << "model pool not support enable fp16.";
 return nullptr;
 }
+if (device->GetDeviceType() == kGPU) {
+num_models_ = 1;
+} else {
+num_models_ = GetCoreNum() / static_cast<int>(model_context->GetThreadNum());
+}
 } else {
 MS_LOG(DEBUG) << "use default config.";
+num_models_ = GetCoreNum() / static_cast<int>(model_context->GetThreadNum());
 model_context->SetThreadNum(kNumThreads);
-model_context->SetEnableParallel(false);
-model_context->SetThreadAffinity(lite::NO_BIND);
+model_context->SetEnableParallel(true);
+model_context->SetThreadAffinity(lite::HIGHER_CPU);
 auto &device_list = model_context->MutableDeviceInfo();
-auto device_info = std::shared_ptr<CPUDeviceInfo>();
+auto device_info = std::make_shared<CPUDeviceInfo>();
 device_info->SetEnableFP16(false);
 device_list.push_back(device_info);
 }
@@ -109,7 +116,6 @@ ModelPoolContex ModelPool::CreateModelContext(const std::shared_ptr<RunnerConfig
 MS_LOG(ERROR) << "thread num is zero.";
 return {};
 }
-num_models_ = GetCoreNum() / static_cast<int>(model_context->GetThreadNum());
 ModelPoolContex model_pool_context;
 std::vector<std::vector<int>> all_model_bind_list;
 if (model_context->GetThreadAffinityMode() == lite::HIGHER_CPU) {
@@ -165,10 +171,21 @@ Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<Runn
 MS_LOG(ERROR) << "CreateModelContext failed, context is empty.";
 return kLiteError;
 }
+size_t size = 0;
+graph_buf_ = lite::ReadFile(model_path.c_str(), &size);
+if (graph_buf_ == nullptr) {
+MS_LOG(ERROR) << "read file failed.";
+return kLiteError;
+}
+lite::PackWeightManager::GetInstance()->InitWeightManagerByBuf(graph_buf_);
 std::shared_ptr<ModelThread> model_thread = nullptr;
 for (size_t i = 0; i < num_models_; i++) {
 model_thread = std::make_shared<ModelThread>();
-auto status = model_thread->Init(model_path, model_pool_context[i], dec_key, dec_mode);
+auto status = model_thread->Init(graph_buf_, size, model_pool_context[i], dec_key, dec_mode);
+if (status != kSuccess) {
+MS_LOG(ERROR) << " model thread init failed.";
+return kLiteError;
+}
 model_thread_vec_.push_back(std::thread(&ModelThread::Run, model_thread));
 }
 if (model_thread != nullptr) {
@ -178,44 +195,65 @@ Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<Runn
return kSuccess; return kSuccess;
} }
Status ModelPool::SplitTensorByBatch(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, Status ModelPool::SplitInputTensorByBatch(const std::vector<MSTensor> &inputs,
std::vector<std::vector<MSTensor>> *new_inputs) { std::vector<std::vector<MSTensor>> *new_inputs, size_t batch_split_num) {
auto batch = inputs[0].Shape()[0]; if (batch_split_num == 0) {
if (batch % batch_split_num_ != 0) { MS_LOG(ERROR) << "batch_split_num is zero.";
MS_LOG(DEBUG) << "Can not split input tensor."; return kLiteError;
return kLiteSuccessExit;
} }
auto batch = inputs[0].Shape()[0];
std::vector<size_t> split_batch;
size_t batch_sum = 0;
size_t per_batch = batch / batch_split_num;
for (size_t i = 0; i < batch_split_num - 1; i++) {
split_batch.push_back(per_batch);
batch_sum += per_batch;
}
split_batch.push_back(batch - batch_sum);
std::vector<std::vector<std::vector<int64_t>>> all_input_shape; std::vector<std::vector<std::vector<int64_t>>> all_input_shape;
for (size_t k = 0; k < batch_split_num_; k++) { // do for batch std::vector<size_t> input_data_split_size(inputs.size(), 0);
for (size_t k = 0; k < batch_split_num; k++) { // do for batch
std::vector<std::vector<int64_t>> inputs_shape; std::vector<std::vector<int64_t>> inputs_shape;
std::vector<MSTensor> new_inputs_tensor; std::vector<MSTensor> new_inputs_tensor;
for (size_t i = 0; i < inputs.size(); i++) { // do for input for (size_t i = 0; i < inputs.size(); i++) { // do for input
std::vector<int64_t> shape; std::vector<int64_t> shape;
size_t input_size = batch / batch_split_num_; size_t input_size = split_batch[k];
shape.push_back(batch / batch_split_num_); shape.push_back(split_batch[k]);
for (size_t j = 1; j < inputs[i].Shape().size(); j++) { // do for dims for (size_t j = 1; j < inputs[i].Shape().size(); j++) { // do for dims
shape.push_back(inputs[i].Shape()[j]); shape.push_back(inputs[i].Shape()[j]);
input_size *= inputs[i].Shape()[j]; input_size *= inputs[i].Shape()[j];
} }
inputs_shape.push_back(shape); inputs_shape.push_back(shape);
if (inputs[i].DataType() == static_cast<enum DataType>(kNumberTypeFloat32)) { if (inputs[i].DataType() == static_cast<enum DataType>(kNumberTypeFloat32)) {
void *data = malloc(input_size * sizeof(float)); if (input_size * sizeof(float) > MAX_MALLOC_SIZE) {
memcpy(reinterpret_cast<float *>(data), MS_LOG(ERROR) << "malloc size is wrong.";
reinterpret_cast<float *>(const_cast<MSTensor &>(inputs[i]).MutableData()) + input_size * k, return kLiteError;
input_size * sizeof(float)); }
auto new_tensor = mindspore::MSTensor::CreateTensor( auto data =
inputs[i].Name(), static_cast<enum DataType>(kNumberTypeFloat32), shape, data, input_size * sizeof(float)); reinterpret_cast<float *>(const_cast<MSTensor &>(inputs[i]).MutableData()) + input_data_split_size[i];
new_inputs_tensor.push_back(*new_tensor); auto new_tensor = MSTensor(inputs[i].Name(), static_cast<enum DataType>(kNumberTypeFloat32), shape, data,
free(data); input_size * sizeof(float));
if (new_tensor == nullptr) {
MS_LOG(ERROR) << "create tensor failed.";
return kLiteError;
}
new_inputs_tensor.push_back(new_tensor);
input_data_split_size[i] += input_size;
} else if (inputs[i].DataType() == static_cast<enum DataType>(kNumberTypeInt32)) { } else if (inputs[i].DataType() == static_cast<enum DataType>(kNumberTypeInt32)) {
void *data = malloc(input_size * sizeof(int32_t)); if (input_size * sizeof(int32_t) > MAX_MALLOC_SIZE) {
memcpy(reinterpret_cast<int32_t *>(data), MS_LOG(ERROR) << "malloc size is wrong.";
reinterpret_cast<int32_t *>(const_cast<MSTensor &>(inputs[i]).MutableData()) + input_size * k, return kLiteError;
input_size * sizeof(int32_t)); }
auto new_tensor = mindspore::MSTensor::CreateTensor( auto data =
inputs[i].Name(), static_cast<enum DataType>(kNumberTypeInt32), shape, data, input_size * sizeof(int32_t)); reinterpret_cast<int32_t *>(const_cast<MSTensor &>(inputs[i]).MutableData()) + input_data_split_size[i];
new_inputs_tensor.push_back(*new_tensor); auto new_tensor = MSTensor(inputs[i].Name(), static_cast<enum DataType>(kNumberTypeInt32), shape, data,
free(data); input_size * sizeof(int32_t));
if (new_tensor == nullptr) {
MS_LOG(ERROR) << "create tensor failed.";
return kLiteError;
}
new_inputs_tensor.push_back(new_tensor);
input_data_split_size[i] += input_size;
} else { } else {
MS_LOG(ERROR) << "not support data type in split batch."; MS_LOG(ERROR) << "not support data type in split batch.";
return kLiteError; return kLiteError;
@ -227,71 +265,166 @@ Status ModelPool::SplitTensorByBatch(const std::vector<MSTensor> &inputs, std::v
return kSuccess; return kSuccess;
} }
Status ModelPool::SplitOutputTensorByBatch(std::vector<std::vector<MSTensor>> *new_outputs,
std::vector<MSTensor> *outputs, size_t batch_split_num) {
if (batch_split_num == 0) {
MS_LOG(ERROR) << "batch_split_num is zero.";
return kLiteError;
}
for (size_t i = 0; i < batch_split_num; i++) {
std::vector<MSTensor> new_output;
for (size_t tensor_num_idx = 0; tensor_num_idx < outputs->size(); tensor_num_idx++) {
if (outputs->at(tensor_num_idx).MutableData() != nullptr && outputs->at(tensor_num_idx).DataSize() != 0) {
is_user_data_ = true;
auto data = reinterpret_cast<float *>(outputs->at(tensor_num_idx).MutableData()) +
outputs->at(tensor_num_idx).Shape().at(0) / batch_split_num * i;
auto out_tensor =
MSTensor(outputs->at(tensor_num_idx).Name(), outputs->at(tensor_num_idx).DataType(), {}, data, 0);
new_output.push_back(out_tensor);
}
}
new_outputs->push_back(new_output);
}
return kSuccess;
}
Status ModelPool::ConcatPredictOutput(std::vector<std::vector<MSTensor>> *outputs, std::vector<MSTensor> *new_outputs) { Status ModelPool::ConcatPredictOutput(std::vector<std::vector<MSTensor>> *outputs, std::vector<MSTensor> *new_outputs) {
if (outputs->empty()) {
MS_LOG(ERROR) << "output is empty";
return kLiteError;
}
for (size_t i = 0; i < outputs->at(0).size(); i++) { for (size_t i = 0; i < outputs->at(0).size(); i++) {
std::vector<int64_t> output_tensor_shape = outputs->at(0)[i].Shape(); std::vector<int64_t> output_tensor_shape = outputs->at(0)[i].Shape();
output_tensor_shape[0] *= batch_split_num_; if (output_tensor_shape.empty()) {
MS_LOG(ERROR) << "output_tensor_shape is empty";
return kLiteError;
}
size_t all_data_size = 0;
size_t all_batch_size = 0;
std::vector<size_t> per_bacth_data_size;
for (size_t batch = 0; batch < outputs->size(); batch++) {
per_bacth_data_size.push_back(all_data_size);
all_data_size += outputs->at(batch).at(i).DataSize();
all_batch_size += outputs->at(batch).at(i).Shape().front();
}
output_tensor_shape[0] = all_batch_size;
if (is_user_data_) {
new_outputs->at(i).SetShape(output_tensor_shape);
continue;
}
auto all_out_data = malloc(all_data_size);
if (all_out_data == nullptr) {
MS_LOG(ERROR) << "all_out_data is nullptr.";
return kLiteError;
}
for (size_t j = 0; j < outputs->size(); j++) {
void *out_data = outputs->at(j).at(i).MutableData();
if (out_data == nullptr) {
free(all_out_data);
all_out_data = nullptr;
MS_LOG(ERROR) << "output data is nullptr.";
return kLiteError;
}
memcpy(reinterpret_cast<float *>(all_out_data) + per_bacth_data_size[j] / sizeof(float),
reinterpret_cast<float *>(out_data), outputs->at(j)[i].DataSize());
}
auto new_tensor = mindspore::MSTensor::CreateTensor(outputs->at(0)[i].Name(), outputs->at(0)[i].DataType(),
output_tensor_shape, all_out_data, all_data_size);
if (new_tensor == nullptr) {
MS_LOG(ERROR) << "create tensor failed.";
return kLiteError;
}
if (all_out_data != nullptr) { if (all_out_data != nullptr) {
free(all_out_data); free(all_out_data);
all_out_data = nullptr; all_out_data = nullptr;
} }
all_out_data = malloc(outputs->at(0).at(i).DataSize() * batch_split_num_);
for (size_t j = 0; j < batch_split_num_; j++) {
void *out_data = outputs->at(j)[i].MutableData();
memcpy(reinterpret_cast<float *>(all_out_data) + outputs->at(j)[i].ElementNum() * j,
reinterpret_cast<float *>(out_data), outputs->at(j)[i].DataSize());
}
auto new_tensor =
mindspore::MSTensor::CreateTensor(outputs->at(0)[i].Name(), outputs->at(i)[0].DataType(), output_tensor_shape,
all_out_data, outputs->at(0)[i].DataSize() * batch_split_num_);
new_outputs->push_back(*new_tensor); new_outputs->push_back(*new_tensor);
delete new_tensor;
}
return kSuccess;
}
Status ModelPool::FreeSplitTensor(std::vector<std::vector<MSTensor>> *new_inputs,
std::vector<std::vector<MSTensor>> *new_outputs) {
for (size_t i = 0; i < new_inputs->size(); i++) {
for (size_t j = 0; j < new_inputs->at(i).size(); j++) {
new_inputs->at(i).at(j).SetData(nullptr);
}
}
new_inputs->clear();
if (is_user_data_) {
for (size_t i = 0; i < new_outputs->size(); i++) {
for (size_t j = 0; j < new_outputs->at(i).size(); j++) {
new_outputs->at(i).at(j).SetData(nullptr);
}
}
new_outputs->clear();
} }
return kSuccess; return kSuccess;
} }
Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) { const MSKernelCallBack &before, const MSKernelCallBack &after) {
outputs->clear(); mtx_split_task_.lock();
if (PredictTaskQueue::GetInstance()->GetTaskNum() == 0 && auto wait_model_num = PredictTaskQueue::GetInstance()->GetWaitModelNum();
batch_split_num_ <= static_cast<size_t>(PredictTaskQueue::GetInstance()->GetWaitModelNum())) { auto batch = inputs[0].Shape()[0];
if (PredictTaskQueue::GetInstance()->GetTaskNum() == 0 && wait_model_num > 1 && batch >= wait_model_num) {
size_t batch_split_num = PredictTaskQueue::GetInstance()->GetWaitModelNum();
PredictTaskQueue::GetInstance()->DecreaseWaitModelNum(batch_split_num);
std::vector<std::vector<MSTensor>> new_inputs; std::vector<std::vector<MSTensor>> new_inputs;
std::vector<std::vector<MSTensor>> new_outputs; std::vector<std::vector<MSTensor>> new_outputs;
auto status = SplitTensorByBatch(inputs, outputs, &new_inputs); auto status = SplitInputTensorByBatch(inputs, &new_inputs, batch_split_num);
if (status != kSuccess) { if (status != kSuccess) {
MS_LOG(ERROR) << "model pool predict failed."; MS_LOG(ERROR) << "model pool split input tensor by batch failed.";
return kLiteError; return kLiteError;
} }
for (size_t i = 0; i < batch_split_num_; i++) { status = SplitOutputTensorByBatch(&new_outputs, outputs, batch_split_num);
std::vector<MSTensor> new_output; if (status != kSuccess) {
new_outputs.push_back(new_output); MS_LOG(ERROR) << "model pool split output tensor by batch failed.";
return kLiteError;
} }
for (size_t i = 0; i < batch_split_num_; i++) {
auto predict_task = std::make_shared<PredictTask>(&new_inputs[i], &new_outputs[i], before, after); std::vector<std::shared_ptr<PredictTask>> tasks;
for (size_t i = 0; i < batch_split_num; i++) {
auto predict_task = std::make_shared<PredictTask>(&new_inputs[i], &new_outputs.at(i), before, after);
PredictTaskQueue::GetInstance()->PushPredictTask(predict_task); PredictTaskQueue::GetInstance()->PushPredictTask(predict_task);
tasks.push_back(predict_task);
} }
for (size_t i = 0; i < batch_split_num_; i++) { mtx_split_task_.unlock();
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(&new_outputs[i]); for (size_t i = 0; i < batch_split_num; i++) {
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(tasks[i]);
} }
status = ConcatPredictOutput(&new_outputs, outputs); status = ConcatPredictOutput(&new_outputs, outputs);
if (status != kSuccess) { if (status != kSuccess) {
MS_LOG(ERROR) << "ConcatPredictOutput failed."; MS_LOG(ERROR) << "ConcatPredictOutput failed.";
return kLiteError; return kLiteError;
} }
status = FreeSplitTensor(&new_inputs, &new_outputs);
if (status != kSuccess) {
MS_LOG(ERROR) << "free split tensor failed.";
return kLiteError;
}
} else { } else {
if (wait_model_num == 1) {
PredictTaskQueue::GetInstance()->DecreaseWaitModelNum(1);
}
auto predict_task = std::make_shared<PredictTask>(&inputs, outputs, before, after); auto predict_task = std::make_shared<PredictTask>(&inputs, outputs, before, after);
PredictTaskQueue::GetInstance()->PushPredictTask(predict_task); PredictTaskQueue::GetInstance()->PushPredictTask(predict_task);
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(outputs); mtx_split_task_.unlock();
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(predict_task);
} }
return kSuccess; return kSuccess;
} }
ModelPool::~ModelPool() { ModelPool::~ModelPool() {
if (graph_buf_ != nullptr) {
delete[] graph_buf_;
graph_buf_ = nullptr;
}
for (auto &th : model_thread_vec_) { for (auto &th : model_thread_vec_) {
if (th.joinable()) { if (th.joinable()) {
th.join(); th.join();
} }
} }
free(all_out_data);
all_out_data = nullptr;
} }
} // namespace mindspore } // namespace mindspore
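
For reference (not part of the commit), the batch-splitting rule used by the new SplitInputTensorByBatch above condenses to the following; the helper name is illustrative. The first batch_split_num - 1 slices get batch / batch_split_num rows each and the last slice takes the remainder, so every row is covered even when the batch is not evenly divisible:

// Sketch of the batch-splitting rule; mirrors the logic in SplitInputTensorByBatch.
#include <cstddef>
#include <vector>

std::vector<size_t> SplitBatch(size_t batch, size_t batch_split_num) {
  std::vector<size_t> split_batch;
  size_t per_batch = batch / batch_split_num;
  size_t batch_sum = 0;
  for (size_t i = 0; i + 1 < batch_split_num; i++) {
    split_batch.push_back(per_batch);
    batch_sum += per_batch;
  }
  split_batch.push_back(batch - batch_sum);  // e.g. batch = 10, k = 4 -> {2, 2, 2, 4}
  return split_batch;
}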

View File

@@ -23,15 +23,10 @@
 #include <map>
 #include "include/api/status.h"
 #include "include/api/context.h"
-#include "src/cxx_api/model_pool/model_thread.h"
+#include "include/api/model_parallel_runner.h"
+#include "src/cxx_api/model_pool/model_worker.h"
 #include "src/cxx_api/model_pool/predict_task_queue.h"
 namespace mindspore {
-struct RunnerConfig {
-RunnerConfig(std::shared_ptr<Context> &ctx, int num) : model_ctx(ctx), num_model(num) {}
-std::shared_ptr<Context> model_ctx = nullptr;
-int num_model = 10;
-};
 class ModelPool {
 public:
 static ModelPool *GetInstance();
@@ -52,16 +47,21 @@ class ModelPool {
 void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
 ModelPoolContex CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config);
 std::shared_ptr<Context> InitContext(const std::shared_ptr<RunnerConfig> &runner_config);
-Status SplitTensorByBatch(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
-std::vector<std::vector<MSTensor>> *new_inputs);
+Status SplitInputTensorByBatch(const std::vector<MSTensor> &inputs, std::vector<std::vector<MSTensor>> *new_inputs,
+size_t batch_split_num);
+Status SplitOutputTensorByBatch(std::vector<std::vector<MSTensor>> *outputs, std::vector<MSTensor> *new_outputs,
+size_t batch_split_num);
 Status ConcatPredictOutput(std::vector<std::vector<MSTensor>> *outputs, std::vector<MSTensor> *new_outputs);
+Status FreeSplitTensor(std::vector<std::vector<MSTensor>> *new_inputs,
+std::vector<std::vector<MSTensor>> *new_outputs);
-void *all_out_data = nullptr;
 std::vector<std::thread> model_thread_vec_;
 std::vector<MSTensor> model_inputs_;
 std::vector<MSTensor> model_outputs_;
+char *graph_buf_ = nullptr;
 size_t num_models_ = 10;
-size_t batch_split_num_ = 4;
+std::mutex mtx_split_task_;
+bool is_user_data_ = false;
 };
 } // namespace mindspore
 #endif // MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H

View File

@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "src/cxx_api/model_pool/model_thread.h"
+#include "src/cxx_api/model_pool/model_worker.h"
 #include "src/common/log.h"
 #include "src/common/utils.h"
 namespace mindspore {
@@ -30,25 +30,39 @@ void ModelThread::Run() {
 auto status = Predict(*inputs, outputs, before, after);
 if (status != kSuccess) {
 MS_LOG(ERROR) << "model predict failed.";
-return;
+task->ready = true;
+PredictTaskQueue::GetInstance()->ActiveTask();
+continue;
 }
-auto output_size = outputs->size();
-for (size_t i = 0; i < output_size; i++) {
-auto copy_tensor =
-mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
-outputs->at(i).MutableData(), outputs->at(i).DataSize());
-outputs->erase(outputs->begin());
-outputs->push_back(*copy_tensor);
+if (is_copy_output_) {
+std::vector<MSTensor> new_outputs;
+auto output_size = outputs->size();
+for (size_t i = 0; i < output_size; i++) {
+auto copy_tensor =
+mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
+outputs->at(i).MutableData(), outputs->at(i).DataSize());
+if (copy_tensor == nullptr) {
+MS_LOG(ERROR) << "model thread copy output tensor failed.";
+task->ready = true;
+PredictTaskQueue::GetInstance()->ActiveTask();
+continue;
+}
+new_outputs.push_back(*copy_tensor);
+delete copy_tensor;
+}
+outputs->clear();
+outputs->insert(outputs->end(), new_outputs.begin(), new_outputs.end());
 }
+task->ready = true;
 PredictTaskQueue::GetInstance()->ActiveTask();
 }
 }
-Status ModelThread::Init(const std::string &model_path, const std::shared_ptr<Context> &model_context,
+Status ModelThread::Init(const char *model_buf, size_t size, const std::shared_ptr<Context> &model_context,
 const Key &dec_key, const std::string &dec_mode) {
 model_ = std::make_shared<Model>();
 mindspore::ModelType model_type = kMindIR;
-auto status = model_->Build(model_path, model_type, model_context, dec_key, dec_mode);
+auto status = model_->Build(model_buf, size, model_type, model_context, dec_key, dec_mode);
 if (status != kSuccess) {
 MS_LOG(ERROR) << "model build failed in ModelPool Init";
 return status;
@@ -107,11 +121,31 @@ Status ModelThread::Predict(const std::vector<MSTensor> &inputs, std::vector<MST
 return kLiteError;
 }
 }
-auto status = model_->Predict(inputs, outputs, before, after);
+auto model_output = model_->GetOutputs();
+for (size_t i = 0; i < outputs->size(); i++) {
+if (outputs->at(i).MutableData() != nullptr) {
+/* user set graph-output-tensor from outside */
+model_output[i].SetData(outputs->at(i).MutableData());
+model_output[i].SetAllocator(nullptr);
+is_copy_output_ = false;
+}
+}
+auto status = model_->Predict(inputs, &model_output, before, after);
 if (status != kSuccess) {
 MS_LOG(ERROR) << "model predict failed.";
 return status;
 }
+if (is_copy_output_) {
+outputs->clear();
+outputs->insert(outputs->end(), model_output.begin(), model_output.end());
+} else {
+model_output = model_->GetOutputs();
+for (size_t i = 0; i < outputs->size(); i++) {
+outputs->at(i).SetShape(model_output[i].Shape());
+model_output[i].SetData(nullptr);
+model_output[i].SetAllocator(nullptr);
+}
+}
 return kSuccess;
 }
 } // namespace mindspore

View File

@@ -35,8 +35,8 @@ class ModelThread {
 ~ModelThread() = default;
 // the model pool is initialized once and can always accept model run requests
-Status Init(const std::string &model_path, const std::shared_ptr<Context> &model_context, const Key &dec_key = {},
-const std::string &dec_mode = kDecModeAesGcm);
+Status Init(const char *model_buf, size_t size, const std::shared_ptr<Context> &model_context,
+const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
 std::vector<MSTensor> GetInputs();
@@ -58,6 +58,7 @@ class ModelThread {
 // num thread is configured according to the hardware
 int num_models_;
+bool is_copy_output_ = true;
 };
 } // namespace mindspore
 #endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_

View File

@@ -21,9 +21,9 @@ PredictTaskQueue::~PredictTaskQueue() {
 task_push_cond_.notify_all();
 }
-void PredictTaskQueue::WaitUntilPredictActive(std::vector<MSTensor> *outputs) {
+void PredictTaskQueue::WaitUntilPredictActive(std::shared_ptr<PredictTask> task) {
 std::unique_lock<std::mutex> result_lock(mtx_predict_task_);
-while (outputs->empty()) {
+while (!task->ready) {
 task_pop_cond_.wait(result_lock);
 }
 return;
@@ -48,7 +48,6 @@ std::shared_ptr<PredictTask> PredictTaskQueue::GetPredictTask() {
 waite_model_num_++;
 task_push_cond_.wait(task_lock);
 }
-waite_model_num_--;
 if (predict_task_done_) {
 return nullptr;
 }

View File

@@ -26,12 +26,13 @@
 namespace mindspore {
 struct PredictTask {
 PredictTask(const std::vector<MSTensor> *in, std::vector<MSTensor> *out, MSKernelCallBack before,
-MSKernelCallBack after)
-: inputs(in), outputs(out), before(before), after(after) {}
+MSKernelCallBack after, bool ready = false)
+: inputs(in), outputs(out), before(before), after(after), ready(ready) {}
 const std::vector<MSTensor> *inputs;
 std::vector<MSTensor> *outputs;
 MSKernelCallBack before;
 MSKernelCallBack after;
+bool ready;
 };
 class PredictTaskQueue {
@@ -40,12 +41,13 @@ class PredictTaskQueue {
 ~PredictTaskQueue();
 void PushPredictTask(std::shared_ptr<PredictTask> task);
-void WaitUntilPredictActive(std::vector<MSTensor> *outputs);
+void WaitUntilPredictActive(std::shared_ptr<PredictTask> task);
 std::shared_ptr<PredictTask> GetPredictTask();
 void ActiveTask();
 bool IsPredictTaskDone() { return predict_task_done_; }
 int GetTaskNum();
 int GetWaitModelNum() { return waite_model_num_; }
+void DecreaseWaitModelNum(int num) { waite_model_num_ -= num; }
 private:
 PredictTaskQueue() = default;
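
The ready flag added to PredictTask above replaces the old check on an empty output vector when waiting for a result. A condensed sketch of that wait/notify pattern (simplified names, not the actual PredictTaskQueue):

#include <condition_variable>
#include <memory>
#include <mutex>

struct Task {
  bool ready = false;
};

class TaskQueue {
 public:
  // Block the caller until a worker marks this task done.
  void WaitUntilDone(const std::shared_ptr<Task> &task) {
    std::unique_lock<std::mutex> lock(mtx_);
    cond_.wait(lock, [&] { return task->ready; });
  }
  // Called by the worker thread once the task's outputs are filled.
  void NotifyDone(const std::shared_ptr<Task> &task) {
    {
      std::lock_guard<std::mutex> lock(mtx_);
      task->ready = true;
    }
    cond_.notify_all();
  }

 private:
  std::mutex mtx_;
  std::condition_variable cond_;
};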

View File

@@ -26,6 +26,18 @@ PackWeightManager *PackWeightManager::GetInstance() {
 return &instance;
 }
+void PackWeightManager::InitWeightManagerByBuf(const char *model_buf) {
+MS_CHECK_TRUE_RET_VOID(model_buf != nullptr);
+if (buf_model_weight_.find(model_buf) == buf_model_weight_.end()) {
+auto *model_const_weight = new (std::nothrow) ModelConstWeight();
+if (model_const_weight == nullptr) {
+MS_LOG(ERROR) << "model_const_weight is nullptr.";
+return;
+}
+buf_model_weight_[model_buf] = model_const_weight;
+}
+}
 void PackWeightManager::InitWeightManagerByPath(const std::string &model_path, const char *model_buf) {
 MS_CHECK_TRUE_RET_VOID(model_buf != nullptr);
 if (path_model_buf_.find(model_path) == path_model_buf_.end()) {
@@ -49,7 +61,13 @@ STATUS PackWeightManager::StoreLiteModel(const char *model_buf, const Model *mod
 return RET_OK;
 }
 }
+{
+if (buf_model_weight_.find(model_buf) == buf_model_weight_.end()) {
+MS_LOG(ERROR) << "set model failed.";
+return RET_ERROR;
+}
+buf_model_weight_[model_buf]->lite_models.push_back(model);
+}
 return RET_OK;
 }
@@ -69,6 +87,19 @@ void *PackWeightManager::GetTensorData(const LiteModel *model, const SchemaTenso
 return nullptr;
 }
 }
+for (auto &item : buf_model_weight_) {
+auto &model_buf = item.first;
+auto &model_weight = item.second;
+auto &models = model_weight->lite_models;
+if (find(models.begin(), models.end(), model) != models.end()) {
+if (model_weight->packed_weight.find(tensor_index) != model_weight->packed_weight.end()) {
+return model_weight->packed_weight[tensor_index];
+}
+buf_model_weight_[model_buf]->origin_weight[tensor_index] = origin_tensor->data();
+buf_model_weight_[model_buf]->origin_data_index[origin_tensor->data()] = tensor_index;
+return nullptr;
+}
+}
 MS_LOG(DEBUG) << "tensor data not packed.";
 return nullptr;
 }
@@ -113,6 +144,13 @@ std::pair<PackStatus, void *> PackWeightManager::GetPackedTensor(const Tensor *t
 return packed_tensor_pair;
 }
 }
+for (auto &item : buf_model_weight_) {
+auto &model_weight = item.second;
+auto packed_tensor_pair = FindPackedTensor(model_weight, tensor, round_size);
+if (packed_tensor_pair.second != nullptr) {
+return packed_tensor_pair;
+}
+}
 MS_LOG(DEBUG) << "not const tensor, need pack in kernel.";
 return std::make_pair(MALLOC, nullptr);
 }
@@ -127,6 +165,13 @@ void PackWeightManager::DeleteSavedModelPtr(LiteModel *delete_model) {
 weight->lite_models.erase(it);
 }
 }
+for (auto &item : buf_model_weight_) {
+auto &weight = item.second;
+auto it = find(weight->lite_models.begin(), weight->lite_models.end(), delete_model);
+if (it != weight->lite_models.end()) {
+weight->lite_models.erase(it);
+}
+}
 }
 void PackWeightManager::FreePackedWeight(ModelConstWeight *weight) {
@@ -154,6 +199,10 @@ PackWeightManager::~PackWeightManager() {
 FreePackedWeight(item.second);
 path_model_weight_.erase(item.first);
 }
+for (auto &item : buf_model_weight_) {
+FreePackedWeight(item.second);
+buf_model_weight_.erase(item.first);
+}
 }
 } // namespace mindspore::lite
 #endif

View File

@@ -45,6 +45,7 @@ class PackWeightManager {
 virtual ~PackWeightManager();
 void InitWeightManagerByPath(const std::string &model_path, const char *model_buf);
+void InitWeightManagerByBuf(const char *model_buf);
 void DeleteSavedModelPtr(LiteModel *delete_model);
 STATUS StoreLiteModel(const char *model_buf, const Model *model);
 void *GetTensorData(const LiteModel *model, const SchemaTensorWrapper *origin_tensor, size_t tensor_index);
@@ -56,6 +57,7 @@ class PackWeightManager {
 void FreePackedWeight(ModelConstWeight *weight);
 std::map<const std::string, ModelConstWeight *> path_model_weight_;
+std::map<const std::string, ModelConstWeight *> buf_model_weight_;
 std::map<const std::string, std::vector<const void *>> path_model_buf_;
 std::mutex mtx_weight_;
 };

View File

@@ -164,11 +164,11 @@ int ArithmeticCPUKernel::InitIndexOffsetInfo() {
 delta = delta % batch_size[j];
 }
 if (j < last_batch_axis) {
-a_offset += (delta / batch_size[j + 1] * a_shape[j] / MSMAX(a_shape[j], b_shape[j])) * a_batch_size[j + 1];
-b_offset += (delta / batch_size[j + 1] * b_shape[j] / MSMAX(a_shape[j], b_shape[j])) * b_batch_size[j + 1];
+a_offset += (delta / batch_size[j + 1] * a_shape[j] / c_shape[j]) * a_batch_size[j + 1];
+b_offset += (delta / batch_size[j + 1] * b_shape[j] / c_shape[j]) * b_batch_size[j + 1];
 } else {
-a_offset += (delta * a_shape[j] / MSMAX(a_shape[j], b_shape[j]));
-b_offset += (delta * b_shape[j] / MSMAX(a_shape[j], b_shape[j]));
+a_offset += (delta * a_shape[j] / c_shape[j]);
+b_offset += (delta * b_shape[j] / c_shape[j]);
 }
 }
 a_offset_[i] = a_offset;
@@ -368,22 +368,22 @@ int ArithmeticCPUKernel::CalcArithmeticByBatch(int task_id) {
 int batch_per_thread = UP_DIV(out_batch_, op_parameter_->thread_num_);
 int start_batch = batch_per_thread * task_id;
 int end_batch = MSMIN(start_batch + batch_per_thread, out_batch_);
-int ret = RET_ERROR;
 for (int i = start_batch; i < end_batch; i++) {
-batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
-batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
-batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * data_type_len_;
+int ret = RET_ERROR;
+auto batch_a_ptr = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
+auto batch_b_ptr = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
+auto batch_c_ptr = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * data_type_len_;
 if (batch_scalar_) {
-ret = DoExecute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, true);
+ret = DoExecute(batch_a_ptr, batch_b_ptr, batch_c_ptr, c_stride_size_, true);
 } else {
-ret = DoExecute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, false);
+ret = DoExecute(batch_a_ptr, batch_b_ptr, batch_c_ptr, c_stride_size_, false);
 }
 if (ret != RET_OK) {
 MS_LOG(ERROR) << "failed to calculate.";
 return RET_ERROR;
 }
 }
-return ret;
+return RET_OK;
 }
int ArithmeticCPUKernel::DoArithmetic(int task_id) { int ArithmeticCPUKernel::DoArithmetic(int task_id) {

View File

@@ -140,10 +140,9 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
 AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes",
 "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
 #ifdef SERVER_INFERENCE
-AddFlag(&BenchmarkFlags::model_pool_, "modelPool", "use model pool", false);
+AddFlag(&BenchmarkFlags::enable_parallel_predict_, "enableParallelPredict", "Enable model parallel : true | false",
+false);
 AddFlag(&BenchmarkFlags::num_require_, "numRequire", "require num", 1);
-AddFlag(&BenchmarkFlags::num_model_, "numModel", "build model num", 1);
-AddFlag(&BenchmarkFlags::num_split_, "numSplit", "split for batch", 1);
 #endif
 #ifdef ENABLE_OPENGL_TEXTURE
 AddFlag(&BenchmarkFlags::enable_gl_texture_, "enableGLTexture", "Enable GlTexture2D", false);
@@ -159,10 +158,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
 public:
 // common
 #ifdef SERVER_INFERENCE
-bool model_pool_ = false;
+bool enable_parallel_predict_ = false;
 int num_require_ = 1;
-int num_model_ = 1;
-int num_split_ = 1;
 #endif
 std::string model_file_;
 std::string in_data_file_;
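
With these flags, a parallel-predict benchmark run would look roughly like the following (binary name and values are illustrative; modelFile and inputShapes are existing benchmark flags, and inputShapes must not be empty in this mode):

./benchmark --modelFile=model.ms --inputShapes=1,224,224,3 --enableParallelPredict=true --numRequire=4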

View File

@@ -42,7 +42,7 @@
 #include "include/mpi_vb.h"
 #endif
 #ifdef SERVER_INFERENCE
-#include "src/cxx_api/model_pool/model_pool.h"
+#include <thread>
 #endif
 namespace mindspore {
 constexpr size_t kDataToStringMaxNum = 40;
@@ -220,7 +220,7 @@ int BenchmarkUnifiedApi::LoadInput() {
 int BenchmarkUnifiedApi::GenerateInputData() {
 #ifdef SERVER_INFERENCE
-if (flags_->model_pool_) {
+if (flags_->enable_parallel_predict_) {
 std::vector<MSTensor> inputs;
 for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
 auto tensor_name = ms_inputs_for_api_[i].Name();
@@ -247,6 +247,7 @@ int BenchmarkUnifiedApi::GenerateInputData() {
 auto new_tensor =
 mindspore::MSTensor::CreateTensor(tensor_name, ms_inputs_for_api_[i].DataType(), shape, input_data, size);
 inputs.push_back(*new_tensor);
+delete new_tensor;
 }
 all_inputs_.push_back(inputs);
 return RET_OK;
@@ -296,7 +297,7 @@ void BenchmarkUnifiedApi::UpdateConfigInfo() {
 int BenchmarkUnifiedApi::ReadInputFile() {
 #ifdef SERVER_INFERENCE
-if (flags_->model_pool_) {
+if (flags_->enable_parallel_predict_) {
 std::vector<MSTensor> inputs;
 for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
 size_t size;
@@ -324,6 +325,7 @@ int BenchmarkUnifiedApi::ReadInputFile() {
 auto new_tensor =
 mindspore::MSTensor::CreateTensor(tensor_name, ms_inputs_for_api_[i].DataType(), shape, input_data, size);
 inputs.push_back(*new_tensor);
+delete new_tensor;
 }
 all_inputs_.push_back(inputs);
 return RET_OK;
@@ -895,7 +897,7 @@ int BenchmarkUnifiedApi::PrintInputData() {
 for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
 mindspore::MSTensor input;
 #ifdef SERVER_INFERENCE
-if (flags_->model_pool_) {
+if (flags_->enable_parallel_predict_) {
 input = all_inputs_[0][i];
 } else {
 input = ms_inputs_for_api_[i];
@@ -946,9 +948,14 @@ int BenchmarkUnifiedApi::PrintInputData() {
 }
 #ifdef SERVER_INFERENCE
 int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> context) {
+if (flags_->resize_dims_.empty()) {
+MS_LOG(ERROR) << "use parallel predict, inputShapes can not use empty.";
+return RET_ERROR;
+}
 // model pool init
 ModelParallelRunner model_pool;
-auto runner_config = std::make_shared<RunnerConfig>(context, flags_->num_model_);
+auto runner_config = std::make_shared<RunnerConfig>();
+runner_config->context = context;
 auto model_init_start = GetTimeUs();
 auto ret = model_pool.Init(flags_->model_file_, runner_config);
 if (ret != kSuccess) {
@@ -958,6 +965,10 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
 auto model_init_end = GetTimeUs();
 // load data
 ms_inputs_for_api_ = model_pool.GetInputs();
+if (ms_inputs_for_api_.empty()) {
+MS_LOG(ERROR) << "model pool input is empty.";
+return RET_ERROR;
+}
 for (int i = 0; i < flags_->num_require_ + flags_->warm_up_loop_count_; i++) {
 auto status = LoadInput();
 if (status != RET_OK) {
@@ -989,7 +1000,7 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
 MS_LOG(ERROR) << "model pool predict failed.";
 }
 auto predict_end = GetTimeUs();
-MS_LOG(ERROR) << "run predict time: " << (predict_end - predict_start) / kFloatMSEC << " ms";
+std::cout << "run predict time: " << (predict_end - predict_start) / kFloatMSEC << " ms\n";
 if (!flags_->benchmark_data_file_.empty()) {
 auto status = CompareOutputForModelPool(&output);
 if (status != RET_OK) {
@@ -1004,18 +1015,22 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
 for (auto &warm_up_thread : model_thread_warm_up) {
 warm_up_thread.join();
 }
-MS_LOG(DEBUG) << "================ end warm up ================";
+std::cout << "================ end warm up ================";
 auto all_start = GetTimeUs();
-std::vector<std::thread> model_thread_run;
-for (int i = 0; i < flags_->num_require_; i++) {
-model_thread_run.push_back(std::thread(model_pool_run, i + flags_->warm_up_loop_count_));
-}
-for (auto &run_thread : model_thread_run) {
-run_thread.join();
+for (int loop_count_num = 0; loop_count_num < flags_->loop_count_; loop_count_num++) {
+std::vector<std::thread> model_thread_run;
+for (int i = 0; i < flags_->num_require_; i++) {
+model_thread_run.push_back(std::thread(model_pool_run, i + flags_->warm_up_loop_count_));
+}
+for (auto &run_thread : model_thread_run) {
+run_thread.join();
+}
 }
 auto all_end = GetTimeUs();
+std::cout << "=================================" << std::endl;
 std::cout << "model pool init time: " << (model_init_end - model_init_start) / kFloatMSEC << " ms\n";
-std::cout << "model pool all run time: " << (all_end - all_start) / kFloatMSEC << " ms\n";
+std::cout << "model pool all run time: " << (all_end - all_start) / kFloatMSEC / flags_->loop_count_ << " ms\n";
+std::cout << "=================================" << std::endl;
 return RET_OK;
 }
 #endif
@@ -1064,7 +1079,7 @@ int BenchmarkUnifiedApi::RunBenchmark() {
 UpdateConfigInfo();
 #ifdef SERVER_INFERENCE
-if (flags_->model_pool_) {
+if (flags_->enable_parallel_predict_) {
 status = RunModelPool(context);
 if (status != RET_OK) {
 MS_LOG(ERROR) << "run model pool failed.";

View File

@@ -43,7 +43,7 @@
 #include "tools/common/opengl_util.h"
 #endif
 #ifdef SERVER_INFERENCE
-#include "src/cxx_api/model_pool/model_parallel_runner.h"
+#include "include/api/model_parallel_runner.h"
 #endif
 namespace mindspore::lite {