forked from mindspore-Ecosystem/mindspore
copy from 1.6
This commit is contained in:
parent
c0d35aa950
commit
c301bf16a2
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
|
||||
#define MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/context.h"
|
||||
namespace mindspore {
|
||||
struct RunnerConfig {
|
||||
std::shared_ptr<Context> context = nullptr;
|
||||
};
|
||||
|
||||
/// \brief The ModelParallelRunner class is used to define a MindSpore ModelParallelRunner, facilitating Model
|
||||
/// management.
|
||||
class MS_API ModelParallelRunner {
|
||||
public:
|
||||
ModelParallelRunner() = default;
|
||||
~ModelParallelRunner() = default;
|
||||
|
||||
/// \brief build a model parallel runner from model path so that it can run on a device. Only valid for Lite.
|
||||
///
|
||||
/// \param[in] model_path Define the model path.
|
||||
/// \param[in] runner_config Define the config used to store options during model pool init.
|
||||
/// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
|
||||
/// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
|
||||
///
|
||||
/// \return Status.
|
||||
Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
|
||||
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||
|
||||
/// \brief Obtains all input tensors information of the model.
|
||||
///
|
||||
/// \return The vector that includes all input tensors.
|
||||
std::vector<MSTensor> GetInputs();
|
||||
|
||||
/// \brief Obtains all output tensors information of the model.
|
||||
///
|
||||
/// \return The vector that includes all output tensors.
|
||||
std::vector<MSTensor> GetOutputs();
|
||||
|
||||
/// \brief Inference ModelParallelRunner.
|
||||
///
|
||||
/// \param[in] inputs A vector where model inputs are arranged in sequence.
|
||||
/// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence.
|
||||
/// \param[in] before CallBack before predict.
|
||||
/// \param[in] after CallBack after predict.
|
||||
///
|
||||
/// \return Status.
|
||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before = nullptr, const MSKernelCallBack &after = nullptr);
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_PARALLEL_RUNNER_H
|
|
@ -69,7 +69,7 @@ if(MSLITE_ENABLE_SERVER_INFERENCE)
|
|||
set(CXX_API_SRCS
|
||||
${CXX_API_SRCS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/predict_task_queue.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_thread.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_worker.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_pool.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cxx_api/model_pool/model_parallel_runner.cc
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/cxx_api/model_pool/model_parallel_runner.h"
|
||||
#include "include/api/model_parallel_runner.h"
|
||||
#include "src/cxx_api/model_pool/model_pool.h"
|
||||
#include "src/common/log.h"
|
||||
|
||||
|
|
|
@ -19,6 +19,9 @@
|
|||
#include "src/common/log.h"
|
||||
#include "include/lite_types.h"
|
||||
#include "src/common/config_file.h"
|
||||
#include "src/runtime/inner_allocator.h"
|
||||
#include "src/common//file_utils.h"
|
||||
#include "src/pack_weight_manager.h"
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
constexpr int32_t kNumThreads = 4;
|
||||
|
@ -36,12 +39,11 @@ int GetCoreNum() {
|
|||
} // namespace
|
||||
|
||||
void ModelPool::SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num) {
|
||||
int core_num = GetCoreNum();
|
||||
if (thread_num == 0) {
|
||||
MS_LOG(ERROR) << "thread num is zero.";
|
||||
return;
|
||||
}
|
||||
num_models_ = core_num / thread_num;
|
||||
int core_num = GetCoreNum();
|
||||
int core_id = 0;
|
||||
for (size_t i = 0; i < num_models_; i++) {
|
||||
std::vector<int> bind_id;
|
||||
|
@ -68,16 +70,15 @@ std::shared_ptr<Context> ModelPool::InitContext(const std::shared_ptr<RunnerConf
|
|||
return nullptr;
|
||||
}
|
||||
if (runner_config != nullptr) {
|
||||
model_context = runner_config->model_ctx;
|
||||
num_models_ = runner_config->num_model;
|
||||
model_context = runner_config->context;
|
||||
auto device_list = model_context->MutableDeviceInfo();
|
||||
if (device_list.size() != 1) {
|
||||
MS_LOG(ERROR) << "model pool only support device num 1.";
|
||||
return nullptr;
|
||||
}
|
||||
auto device = device_list.front();
|
||||
if (device->GetDeviceType() != kCPU) {
|
||||
MS_LOG(ERROR) << "model pool only support cpu type.";
|
||||
if (device->GetDeviceType() != kCPU && device->GetDeviceType() != kGPU) {
|
||||
MS_LOG(ERROR) << "model pool only support cpu or gpu type.";
|
||||
return nullptr;
|
||||
}
|
||||
auto cpu_context = device->Cast<CPUDeviceInfo>();
|
||||
|
@ -86,13 +87,19 @@ std::shared_ptr<Context> ModelPool::InitContext(const std::shared_ptr<RunnerConf
|
|||
MS_LOG(ERROR) << "model pool not support enable fp16.";
|
||||
return nullptr;
|
||||
}
|
||||
if (device->GetDeviceType() == kGPU) {
|
||||
num_models_ = 1;
|
||||
} else {
|
||||
num_models_ = GetCoreNum() / static_cast<int>(model_context->GetThreadNum());
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "use default config.";
|
||||
num_models_ = GetCoreNum() / static_cast<int>(model_context->GetThreadNum());
|
||||
model_context->SetThreadNum(kNumThreads);
|
||||
model_context->SetEnableParallel(false);
|
||||
model_context->SetThreadAffinity(lite::NO_BIND);
|
||||
model_context->SetEnableParallel(true);
|
||||
model_context->SetThreadAffinity(lite::HIGHER_CPU);
|
||||
auto &device_list = model_context->MutableDeviceInfo();
|
||||
auto device_info = std::shared_ptr<CPUDeviceInfo>();
|
||||
auto device_info = std::make_shared<CPUDeviceInfo>();
|
||||
device_info->SetEnableFP16(false);
|
||||
device_list.push_back(device_info);
|
||||
}
|
||||
|
@ -109,7 +116,6 @@ ModelPoolContex ModelPool::CreateModelContext(const std::shared_ptr<RunnerConfig
|
|||
MS_LOG(ERROR) << "thread num is zero.";
|
||||
return {};
|
||||
}
|
||||
num_models_ = GetCoreNum() / static_cast<int>(model_context->GetThreadNum());
|
||||
ModelPoolContex model_pool_context;
|
||||
std::vector<std::vector<int>> all_model_bind_list;
|
||||
if (model_context->GetThreadAffinityMode() == lite::HIGHER_CPU) {
|
||||
|
@ -165,10 +171,21 @@ Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<Runn
|
|||
MS_LOG(ERROR) << "CreateModelContext failed, context is empty.";
|
||||
return kLiteError;
|
||||
}
|
||||
size_t size = 0;
|
||||
graph_buf_ = lite::ReadFile(model_path.c_str(), &size);
|
||||
if (graph_buf_ == nullptr) {
|
||||
MS_LOG(ERROR) << "read file failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
lite::PackWeightManager::GetInstance()->InitWeightManagerByBuf(graph_buf_);
|
||||
std::shared_ptr<ModelThread> model_thread = nullptr;
|
||||
for (size_t i = 0; i < num_models_; i++) {
|
||||
model_thread = std::make_shared<ModelThread>();
|
||||
auto status = model_thread->Init(model_path, model_pool_context[i], dec_key, dec_mode);
|
||||
auto status = model_thread->Init(graph_buf_, size, model_pool_context[i], dec_key, dec_mode);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << " model thread init failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
model_thread_vec_.push_back(std::thread(&ModelThread::Run, model_thread));
|
||||
}
|
||||
if (model_thread != nullptr) {
|
||||
|
@ -178,44 +195,65 @@ Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<Runn
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelPool::SplitTensorByBatch(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
std::vector<std::vector<MSTensor>> *new_inputs) {
|
||||
auto batch = inputs[0].Shape()[0];
|
||||
if (batch % batch_split_num_ != 0) {
|
||||
MS_LOG(DEBUG) << "Can not split input tensor.";
|
||||
return kLiteSuccessExit;
|
||||
Status ModelPool::SplitInputTensorByBatch(const std::vector<MSTensor> &inputs,
|
||||
std::vector<std::vector<MSTensor>> *new_inputs, size_t batch_split_num) {
|
||||
if (batch_split_num == 0) {
|
||||
MS_LOG(ERROR) << "batch_split_num is zero.";
|
||||
return kLiteError;
|
||||
}
|
||||
auto batch = inputs[0].Shape()[0];
|
||||
std::vector<size_t> split_batch;
|
||||
size_t batch_sum = 0;
|
||||
size_t per_batch = batch / batch_split_num;
|
||||
for (size_t i = 0; i < batch_split_num - 1; i++) {
|
||||
split_batch.push_back(per_batch);
|
||||
batch_sum += per_batch;
|
||||
}
|
||||
split_batch.push_back(batch - batch_sum);
|
||||
std::vector<std::vector<std::vector<int64_t>>> all_input_shape;
|
||||
for (size_t k = 0; k < batch_split_num_; k++) { // do for batch
|
||||
std::vector<size_t> input_data_split_size(inputs.size(), 0);
|
||||
for (size_t k = 0; k < batch_split_num; k++) { // do for batch
|
||||
std::vector<std::vector<int64_t>> inputs_shape;
|
||||
std::vector<MSTensor> new_inputs_tensor;
|
||||
for (size_t i = 0; i < inputs.size(); i++) { // do for input
|
||||
std::vector<int64_t> shape;
|
||||
size_t input_size = batch / batch_split_num_;
|
||||
shape.push_back(batch / batch_split_num_);
|
||||
size_t input_size = split_batch[k];
|
||||
shape.push_back(split_batch[k]);
|
||||
for (size_t j = 1; j < inputs[i].Shape().size(); j++) { // do for dims
|
||||
shape.push_back(inputs[i].Shape()[j]);
|
||||
input_size *= inputs[i].Shape()[j];
|
||||
}
|
||||
inputs_shape.push_back(shape);
|
||||
if (inputs[i].DataType() == static_cast<enum DataType>(kNumberTypeFloat32)) {
|
||||
void *data = malloc(input_size * sizeof(float));
|
||||
memcpy(reinterpret_cast<float *>(data),
|
||||
reinterpret_cast<float *>(const_cast<MSTensor &>(inputs[i]).MutableData()) + input_size * k,
|
||||
input_size * sizeof(float));
|
||||
auto new_tensor = mindspore::MSTensor::CreateTensor(
|
||||
inputs[i].Name(), static_cast<enum DataType>(kNumberTypeFloat32), shape, data, input_size * sizeof(float));
|
||||
new_inputs_tensor.push_back(*new_tensor);
|
||||
free(data);
|
||||
if (input_size * sizeof(float) > MAX_MALLOC_SIZE) {
|
||||
MS_LOG(ERROR) << "malloc size is wrong.";
|
||||
return kLiteError;
|
||||
}
|
||||
auto data =
|
||||
reinterpret_cast<float *>(const_cast<MSTensor &>(inputs[i]).MutableData()) + input_data_split_size[i];
|
||||
auto new_tensor = MSTensor(inputs[i].Name(), static_cast<enum DataType>(kNumberTypeFloat32), shape, data,
|
||||
input_size * sizeof(float));
|
||||
if (new_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "create tensor failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
new_inputs_tensor.push_back(new_tensor);
|
||||
input_data_split_size[i] += input_size;
|
||||
} else if (inputs[i].DataType() == static_cast<enum DataType>(kNumberTypeInt32)) {
|
||||
void *data = malloc(input_size * sizeof(int32_t));
|
||||
memcpy(reinterpret_cast<int32_t *>(data),
|
||||
reinterpret_cast<int32_t *>(const_cast<MSTensor &>(inputs[i]).MutableData()) + input_size * k,
|
||||
input_size * sizeof(int32_t));
|
||||
auto new_tensor = mindspore::MSTensor::CreateTensor(
|
||||
inputs[i].Name(), static_cast<enum DataType>(kNumberTypeInt32), shape, data, input_size * sizeof(int32_t));
|
||||
new_inputs_tensor.push_back(*new_tensor);
|
||||
free(data);
|
||||
if (input_size * sizeof(int32_t) > MAX_MALLOC_SIZE) {
|
||||
MS_LOG(ERROR) << "malloc size is wrong.";
|
||||
return kLiteError;
|
||||
}
|
||||
auto data =
|
||||
reinterpret_cast<int32_t *>(const_cast<MSTensor &>(inputs[i]).MutableData()) + input_data_split_size[i];
|
||||
auto new_tensor = MSTensor(inputs[i].Name(), static_cast<enum DataType>(kNumberTypeInt32), shape, data,
|
||||
input_size * sizeof(int32_t));
|
||||
if (new_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "create tensor failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
new_inputs_tensor.push_back(new_tensor);
|
||||
input_data_split_size[i] += input_size;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "not support data type in split batch.";
|
||||
return kLiteError;
|
||||
|
@ -227,71 +265,166 @@ Status ModelPool::SplitTensorByBatch(const std::vector<MSTensor> &inputs, std::v
|
|||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelPool::SplitOutputTensorByBatch(std::vector<std::vector<MSTensor>> *new_outputs,
|
||||
std::vector<MSTensor> *outputs, size_t batch_split_num) {
|
||||
if (batch_split_num == 0) {
|
||||
MS_LOG(ERROR) << "batch_split_num is zero.";
|
||||
return kLiteError;
|
||||
}
|
||||
for (size_t i = 0; i < batch_split_num; i++) {
|
||||
std::vector<MSTensor> new_output;
|
||||
for (size_t tensor_num_idx = 0; tensor_num_idx < outputs->size(); tensor_num_idx++) {
|
||||
if (outputs->at(tensor_num_idx).MutableData() != nullptr && outputs->at(tensor_num_idx).DataSize() != 0) {
|
||||
is_user_data_ = true;
|
||||
auto data = reinterpret_cast<float *>(outputs->at(tensor_num_idx).MutableData()) +
|
||||
outputs->at(tensor_num_idx).Shape().at(0) / batch_split_num * i;
|
||||
auto out_tensor =
|
||||
MSTensor(outputs->at(tensor_num_idx).Name(), outputs->at(tensor_num_idx).DataType(), {}, data, 0);
|
||||
new_output.push_back(out_tensor);
|
||||
}
|
||||
}
|
||||
new_outputs->push_back(new_output);
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelPool::ConcatPredictOutput(std::vector<std::vector<MSTensor>> *outputs, std::vector<MSTensor> *new_outputs) {
|
||||
if (outputs->empty()) {
|
||||
MS_LOG(ERROR) << "output is empty";
|
||||
return kLiteError;
|
||||
}
|
||||
for (size_t i = 0; i < outputs->at(0).size(); i++) {
|
||||
std::vector<int64_t> output_tensor_shape = outputs->at(0)[i].Shape();
|
||||
output_tensor_shape[0] *= batch_split_num_;
|
||||
if (output_tensor_shape.empty()) {
|
||||
MS_LOG(ERROR) << "output_tensor_shape is empty";
|
||||
return kLiteError;
|
||||
}
|
||||
size_t all_data_size = 0;
|
||||
size_t all_batch_size = 0;
|
||||
std::vector<size_t> per_bacth_data_size;
|
||||
for (size_t batch = 0; batch < outputs->size(); batch++) {
|
||||
per_bacth_data_size.push_back(all_data_size);
|
||||
all_data_size += outputs->at(batch).at(i).DataSize();
|
||||
all_batch_size += outputs->at(batch).at(i).Shape().front();
|
||||
}
|
||||
output_tensor_shape[0] = all_batch_size;
|
||||
if (is_user_data_) {
|
||||
new_outputs->at(i).SetShape(output_tensor_shape);
|
||||
continue;
|
||||
}
|
||||
auto all_out_data = malloc(all_data_size);
|
||||
if (all_out_data == nullptr) {
|
||||
MS_LOG(ERROR) << "all_out_data is nullptr.";
|
||||
return kLiteError;
|
||||
}
|
||||
for (size_t j = 0; j < outputs->size(); j++) {
|
||||
void *out_data = outputs->at(j).at(i).MutableData();
|
||||
if (out_data == nullptr) {
|
||||
free(all_out_data);
|
||||
all_out_data = nullptr;
|
||||
MS_LOG(ERROR) << "output data is nullptr.";
|
||||
return kLiteError;
|
||||
}
|
||||
memcpy(reinterpret_cast<float *>(all_out_data) + per_bacth_data_size[j] / sizeof(float),
|
||||
reinterpret_cast<float *>(out_data), outputs->at(j)[i].DataSize());
|
||||
}
|
||||
auto new_tensor = mindspore::MSTensor::CreateTensor(outputs->at(0)[i].Name(), outputs->at(0)[i].DataType(),
|
||||
output_tensor_shape, all_out_data, all_data_size);
|
||||
if (new_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "create tensor failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
if (all_out_data != nullptr) {
|
||||
free(all_out_data);
|
||||
all_out_data = nullptr;
|
||||
}
|
||||
all_out_data = malloc(outputs->at(0).at(i).DataSize() * batch_split_num_);
|
||||
for (size_t j = 0; j < batch_split_num_; j++) {
|
||||
void *out_data = outputs->at(j)[i].MutableData();
|
||||
memcpy(reinterpret_cast<float *>(all_out_data) + outputs->at(j)[i].ElementNum() * j,
|
||||
reinterpret_cast<float *>(out_data), outputs->at(j)[i].DataSize());
|
||||
}
|
||||
auto new_tensor =
|
||||
mindspore::MSTensor::CreateTensor(outputs->at(0)[i].Name(), outputs->at(i)[0].DataType(), output_tensor_shape,
|
||||
all_out_data, outputs->at(0)[i].DataSize() * batch_split_num_);
|
||||
new_outputs->push_back(*new_tensor);
|
||||
delete new_tensor;
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelPool::FreeSplitTensor(std::vector<std::vector<MSTensor>> *new_inputs,
|
||||
std::vector<std::vector<MSTensor>> *new_outputs) {
|
||||
for (size_t i = 0; i < new_inputs->size(); i++) {
|
||||
for (size_t j = 0; j < new_inputs->at(i).size(); j++) {
|
||||
new_inputs->at(i).at(j).SetData(nullptr);
|
||||
}
|
||||
}
|
||||
new_inputs->clear();
|
||||
if (is_user_data_) {
|
||||
for (size_t i = 0; i < new_outputs->size(); i++) {
|
||||
for (size_t j = 0; j < new_outputs->at(i).size(); j++) {
|
||||
new_outputs->at(i).at(j).SetData(nullptr);
|
||||
}
|
||||
}
|
||||
new_outputs->clear();
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status ModelPool::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
const MSKernelCallBack &before, const MSKernelCallBack &after) {
|
||||
outputs->clear();
|
||||
if (PredictTaskQueue::GetInstance()->GetTaskNum() == 0 &&
|
||||
batch_split_num_ <= static_cast<size_t>(PredictTaskQueue::GetInstance()->GetWaitModelNum())) {
|
||||
mtx_split_task_.lock();
|
||||
auto wait_model_num = PredictTaskQueue::GetInstance()->GetWaitModelNum();
|
||||
auto batch = inputs[0].Shape()[0];
|
||||
if (PredictTaskQueue::GetInstance()->GetTaskNum() == 0 && wait_model_num > 1 && batch >= wait_model_num) {
|
||||
size_t batch_split_num = PredictTaskQueue::GetInstance()->GetWaitModelNum();
|
||||
PredictTaskQueue::GetInstance()->DecreaseWaitModelNum(batch_split_num);
|
||||
std::vector<std::vector<MSTensor>> new_inputs;
|
||||
std::vector<std::vector<MSTensor>> new_outputs;
|
||||
auto status = SplitTensorByBatch(inputs, outputs, &new_inputs);
|
||||
auto status = SplitInputTensorByBatch(inputs, &new_inputs, batch_split_num);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "model pool predict failed.";
|
||||
MS_LOG(ERROR) << "model pool split input tensor by batch failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
for (size_t i = 0; i < batch_split_num_; i++) {
|
||||
std::vector<MSTensor> new_output;
|
||||
new_outputs.push_back(new_output);
|
||||
status = SplitOutputTensorByBatch(&new_outputs, outputs, batch_split_num);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "model pool split output tensor by batch failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
for (size_t i = 0; i < batch_split_num_; i++) {
|
||||
auto predict_task = std::make_shared<PredictTask>(&new_inputs[i], &new_outputs[i], before, after);
|
||||
|
||||
std::vector<std::shared_ptr<PredictTask>> tasks;
|
||||
for (size_t i = 0; i < batch_split_num; i++) {
|
||||
auto predict_task = std::make_shared<PredictTask>(&new_inputs[i], &new_outputs.at(i), before, after);
|
||||
PredictTaskQueue::GetInstance()->PushPredictTask(predict_task);
|
||||
tasks.push_back(predict_task);
|
||||
}
|
||||
for (size_t i = 0; i < batch_split_num_; i++) {
|
||||
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(&new_outputs[i]);
|
||||
mtx_split_task_.unlock();
|
||||
for (size_t i = 0; i < batch_split_num; i++) {
|
||||
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(tasks[i]);
|
||||
}
|
||||
status = ConcatPredictOutput(&new_outputs, outputs);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "ConcatPredictOutput failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
status = FreeSplitTensor(&new_inputs, &new_outputs);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "free split tensor failed.";
|
||||
return kLiteError;
|
||||
}
|
||||
} else {
|
||||
if (wait_model_num == 1) {
|
||||
PredictTaskQueue::GetInstance()->DecreaseWaitModelNum(1);
|
||||
}
|
||||
auto predict_task = std::make_shared<PredictTask>(&inputs, outputs, before, after);
|
||||
PredictTaskQueue::GetInstance()->PushPredictTask(predict_task);
|
||||
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(outputs);
|
||||
mtx_split_task_.unlock();
|
||||
PredictTaskQueue::GetInstance()->WaitUntilPredictActive(predict_task);
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
ModelPool::~ModelPool() {
|
||||
if (graph_buf_ != nullptr) {
|
||||
delete[] graph_buf_;
|
||||
graph_buf_ = nullptr;
|
||||
}
|
||||
for (auto &th : model_thread_vec_) {
|
||||
if (th.joinable()) {
|
||||
th.join();
|
||||
}
|
||||
}
|
||||
free(all_out_data);
|
||||
all_out_data = nullptr;
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,15 +23,10 @@
|
|||
#include <map>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/context.h"
|
||||
#include "src/cxx_api/model_pool/model_thread.h"
|
||||
#include "include/api/model_parallel_runner.h"
|
||||
#include "src/cxx_api/model_pool/model_worker.h"
|
||||
#include "src/cxx_api/model_pool/predict_task_queue.h"
|
||||
namespace mindspore {
|
||||
struct RunnerConfig {
|
||||
RunnerConfig(std::shared_ptr<Context> &ctx, int num) : model_ctx(ctx), num_model(num) {}
|
||||
std::shared_ptr<Context> model_ctx = nullptr;
|
||||
int num_model = 10;
|
||||
};
|
||||
|
||||
class ModelPool {
|
||||
public:
|
||||
static ModelPool *GetInstance();
|
||||
|
@ -52,16 +47,21 @@ class ModelPool {
|
|||
void SetBindStrategy(std::vector<std::vector<int>> *all_model_bind_list, int thread_num);
|
||||
ModelPoolContex CreateModelContext(const std::shared_ptr<RunnerConfig> &runner_config);
|
||||
std::shared_ptr<Context> InitContext(const std::shared_ptr<RunnerConfig> &runner_config);
|
||||
Status SplitTensorByBatch(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
|
||||
std::vector<std::vector<MSTensor>> *new_inputs);
|
||||
Status SplitInputTensorByBatch(const std::vector<MSTensor> &inputs, std::vector<std::vector<MSTensor>> *new_inputs,
|
||||
size_t batch_split_num);
|
||||
Status SplitOutputTensorByBatch(std::vector<std::vector<MSTensor>> *outputs, std::vector<MSTensor> *new_outputs,
|
||||
size_t batch_split_num);
|
||||
Status ConcatPredictOutput(std::vector<std::vector<MSTensor>> *outputs, std::vector<MSTensor> *new_outputs);
|
||||
Status FreeSplitTensor(std::vector<std::vector<MSTensor>> *new_inputs,
|
||||
std::vector<std::vector<MSTensor>> *new_outputs);
|
||||
|
||||
void *all_out_data = nullptr;
|
||||
std::vector<std::thread> model_thread_vec_;
|
||||
std::vector<MSTensor> model_inputs_;
|
||||
std::vector<MSTensor> model_outputs_;
|
||||
char *graph_buf_ = nullptr;
|
||||
size_t num_models_ = 10;
|
||||
size_t batch_split_num_ = 4;
|
||||
std::mutex mtx_split_task_;
|
||||
bool is_user_data_ = false;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_POOL_MODEL_POOL_H
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/cxx_api/model_pool/model_thread.h"
|
||||
#include "src/cxx_api/model_pool/model_worker.h"
|
||||
#include "src/common/log.h"
|
||||
#include "src/common/utils.h"
|
||||
namespace mindspore {
|
||||
|
@ -30,25 +30,39 @@ void ModelThread::Run() {
|
|||
auto status = Predict(*inputs, outputs, before, after);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "model predict failed.";
|
||||
return;
|
||||
task->ready = true;
|
||||
PredictTaskQueue::GetInstance()->ActiveTask();
|
||||
continue;
|
||||
}
|
||||
auto output_size = outputs->size();
|
||||
for (size_t i = 0; i < output_size; i++) {
|
||||
auto copy_tensor =
|
||||
mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
|
||||
outputs->at(i).MutableData(), outputs->at(i).DataSize());
|
||||
outputs->erase(outputs->begin());
|
||||
outputs->push_back(*copy_tensor);
|
||||
if (is_copy_output_) {
|
||||
std::vector<MSTensor> new_outputs;
|
||||
auto output_size = outputs->size();
|
||||
for (size_t i = 0; i < output_size; i++) {
|
||||
auto copy_tensor =
|
||||
mindspore::MSTensor::CreateTensor(outputs->at(i).Name(), outputs->at(i).DataType(), outputs->at(i).Shape(),
|
||||
outputs->at(i).MutableData(), outputs->at(i).DataSize());
|
||||
if (copy_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "model thread copy output tensor failed.";
|
||||
task->ready = true;
|
||||
PredictTaskQueue::GetInstance()->ActiveTask();
|
||||
continue;
|
||||
}
|
||||
new_outputs.push_back(*copy_tensor);
|
||||
delete copy_tensor;
|
||||
}
|
||||
outputs->clear();
|
||||
outputs->insert(outputs->end(), new_outputs.begin(), new_outputs.end());
|
||||
}
|
||||
task->ready = true;
|
||||
PredictTaskQueue::GetInstance()->ActiveTask();
|
||||
}
|
||||
}
|
||||
|
||||
Status ModelThread::Init(const std::string &model_path, const std::shared_ptr<Context> &model_context,
|
||||
Status ModelThread::Init(const char *model_buf, size_t size, const std::shared_ptr<Context> &model_context,
|
||||
const Key &dec_key, const std::string &dec_mode) {
|
||||
model_ = std::make_shared<Model>();
|
||||
mindspore::ModelType model_type = kMindIR;
|
||||
auto status = model_->Build(model_path, model_type, model_context, dec_key, dec_mode);
|
||||
auto status = model_->Build(model_buf, size, model_type, model_context, dec_key, dec_mode);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "model build failed in ModelPool Init";
|
||||
return status;
|
||||
|
@ -107,11 +121,31 @@ Status ModelThread::Predict(const std::vector<MSTensor> &inputs, std::vector<MST
|
|||
return kLiteError;
|
||||
}
|
||||
}
|
||||
auto status = model_->Predict(inputs, outputs, before, after);
|
||||
auto model_output = model_->GetOutputs();
|
||||
for (size_t i = 0; i < outputs->size(); i++) {
|
||||
if (outputs->at(i).MutableData() != nullptr) {
|
||||
/* user set graph-output-tensor from outside */
|
||||
model_output[i].SetData(outputs->at(i).MutableData());
|
||||
model_output[i].SetAllocator(nullptr);
|
||||
is_copy_output_ = false;
|
||||
}
|
||||
}
|
||||
auto status = model_->Predict(inputs, &model_output, before, after);
|
||||
if (status != kSuccess) {
|
||||
MS_LOG(ERROR) << "model predict failed.";
|
||||
return status;
|
||||
}
|
||||
if (is_copy_output_) {
|
||||
outputs->clear();
|
||||
outputs->insert(outputs->end(), model_output.begin(), model_output.end());
|
||||
} else {
|
||||
model_output = model_->GetOutputs();
|
||||
for (size_t i = 0; i < outputs->size(); i++) {
|
||||
outputs->at(i).SetShape(model_output[i].Shape());
|
||||
model_output[i].SetData(nullptr);
|
||||
model_output[i].SetAllocator(nullptr);
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace mindspore
|
|
@ -35,8 +35,8 @@ class ModelThread {
|
|||
~ModelThread() = default;
|
||||
|
||||
// the model pool is initialized once and can always accept model run requests
|
||||
Status Init(const std::string &model_path, const std::shared_ptr<Context> &model_context, const Key &dec_key = {},
|
||||
const std::string &dec_mode = kDecModeAesGcm);
|
||||
Status Init(const char *model_buf, size_t size, const std::shared_ptr<Context> &model_context,
|
||||
const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
|
||||
|
||||
std::vector<MSTensor> GetInputs();
|
||||
|
||||
|
@ -58,6 +58,7 @@ class ModelThread {
|
|||
|
||||
// num thread is configured according to the hardware
|
||||
int num_models_;
|
||||
bool is_copy_output_ = true;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_CXX_API_MODEL_POOL_MODEL_THREAD_H_
|
|
@ -21,9 +21,9 @@ PredictTaskQueue::~PredictTaskQueue() {
|
|||
task_push_cond_.notify_all();
|
||||
}
|
||||
|
||||
void PredictTaskQueue::WaitUntilPredictActive(std::vector<MSTensor> *outputs) {
|
||||
void PredictTaskQueue::WaitUntilPredictActive(std::shared_ptr<PredictTask> task) {
|
||||
std::unique_lock<std::mutex> result_lock(mtx_predict_task_);
|
||||
while (outputs->empty()) {
|
||||
while (!task->ready) {
|
||||
task_pop_cond_.wait(result_lock);
|
||||
}
|
||||
return;
|
||||
|
@ -48,7 +48,6 @@ std::shared_ptr<PredictTask> PredictTaskQueue::GetPredictTask() {
|
|||
waite_model_num_++;
|
||||
task_push_cond_.wait(task_lock);
|
||||
}
|
||||
waite_model_num_--;
|
||||
if (predict_task_done_) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -26,12 +26,13 @@
|
|||
namespace mindspore {
|
||||
struct PredictTask {
|
||||
PredictTask(const std::vector<MSTensor> *in, std::vector<MSTensor> *out, MSKernelCallBack before,
|
||||
MSKernelCallBack after)
|
||||
: inputs(in), outputs(out), before(before), after(after) {}
|
||||
MSKernelCallBack after, bool ready = false)
|
||||
: inputs(in), outputs(out), before(before), after(after), ready(ready) {}
|
||||
const std::vector<MSTensor> *inputs;
|
||||
std::vector<MSTensor> *outputs;
|
||||
MSKernelCallBack before;
|
||||
MSKernelCallBack after;
|
||||
bool ready;
|
||||
};
|
||||
|
||||
class PredictTaskQueue {
|
||||
|
@ -40,12 +41,13 @@ class PredictTaskQueue {
|
|||
~PredictTaskQueue();
|
||||
|
||||
void PushPredictTask(std::shared_ptr<PredictTask> task);
|
||||
void WaitUntilPredictActive(std::vector<MSTensor> *outputs);
|
||||
void WaitUntilPredictActive(std::shared_ptr<PredictTask> task);
|
||||
std::shared_ptr<PredictTask> GetPredictTask();
|
||||
void ActiveTask();
|
||||
bool IsPredictTaskDone() { return predict_task_done_; }
|
||||
int GetTaskNum();
|
||||
int GetWaitModelNum() { return waite_model_num_; }
|
||||
void DecreaseWaitModelNum(int num) { waite_model_num_ -= num; }
|
||||
|
||||
private:
|
||||
PredictTaskQueue() = default;
|
||||
|
|
|
@ -26,6 +26,18 @@ PackWeightManager *PackWeightManager::GetInstance() {
|
|||
return &instance;
|
||||
}
|
||||
|
||||
void PackWeightManager::InitWeightManagerByBuf(const char *model_buf) {
|
||||
MS_CHECK_TRUE_RET_VOID(model_buf != nullptr);
|
||||
if (buf_model_weight_.find(model_buf) == buf_model_weight_.end()) {
|
||||
auto *model_const_weight = new (std::nothrow) ModelConstWeight();
|
||||
if (model_const_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "model_const_weight is nullptr.";
|
||||
return;
|
||||
}
|
||||
buf_model_weight_[model_buf] = model_const_weight;
|
||||
}
|
||||
}
|
||||
|
||||
void PackWeightManager::InitWeightManagerByPath(const std::string &model_path, const char *model_buf) {
|
||||
MS_CHECK_TRUE_RET_VOID(model_buf != nullptr);
|
||||
if (path_model_buf_.find(model_path) == path_model_buf_.end()) {
|
||||
|
@ -49,7 +61,13 @@ STATUS PackWeightManager::StoreLiteModel(const char *model_buf, const Model *mod
|
|||
return RET_OK;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
if (buf_model_weight_.find(model_buf) == buf_model_weight_.end()) {
|
||||
MS_LOG(ERROR) << "set model failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
buf_model_weight_[model_buf]->lite_models.push_back(model);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -69,6 +87,19 @@ void *PackWeightManager::GetTensorData(const LiteModel *model, const SchemaTenso
|
|||
return nullptr;
|
||||
}
|
||||
}
|
||||
for (auto &item : buf_model_weight_) {
|
||||
auto &model_buf = item.first;
|
||||
auto &model_weight = item.second;
|
||||
auto &models = model_weight->lite_models;
|
||||
if (find(models.begin(), models.end(), model) != models.end()) {
|
||||
if (model_weight->packed_weight.find(tensor_index) != model_weight->packed_weight.end()) {
|
||||
return model_weight->packed_weight[tensor_index];
|
||||
}
|
||||
buf_model_weight_[model_buf]->origin_weight[tensor_index] = origin_tensor->data();
|
||||
buf_model_weight_[model_buf]->origin_data_index[origin_tensor->data()] = tensor_index;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "tensor data not packed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -113,6 +144,13 @@ std::pair<PackStatus, void *> PackWeightManager::GetPackedTensor(const Tensor *t
|
|||
return packed_tensor_pair;
|
||||
}
|
||||
}
|
||||
for (auto &item : buf_model_weight_) {
|
||||
auto &model_weight = item.second;
|
||||
auto packed_tensor_pair = FindPackedTensor(model_weight, tensor, round_size);
|
||||
if (packed_tensor_pair.second != nullptr) {
|
||||
return packed_tensor_pair;
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "not const tensor, need pack in kernel.";
|
||||
return std::make_pair(MALLOC, nullptr);
|
||||
}
|
||||
|
@ -127,6 +165,13 @@ void PackWeightManager::DeleteSavedModelPtr(LiteModel *delete_model) {
|
|||
weight->lite_models.erase(it);
|
||||
}
|
||||
}
|
||||
for (auto &item : buf_model_weight_) {
|
||||
auto &weight = item.second;
|
||||
auto it = find(weight->lite_models.begin(), weight->lite_models.end(), delete_model);
|
||||
if (it != weight->lite_models.end()) {
|
||||
weight->lite_models.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PackWeightManager::FreePackedWeight(ModelConstWeight *weight) {
|
||||
|
@ -154,6 +199,10 @@ PackWeightManager::~PackWeightManager() {
|
|||
FreePackedWeight(item.second);
|
||||
path_model_weight_.erase(item.first);
|
||||
}
|
||||
for (auto &item : buf_model_weight_) {
|
||||
FreePackedWeight(item.second);
|
||||
buf_model_weight_.erase(item.first);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
#endif
|
||||
|
|
|
@ -45,6 +45,7 @@ class PackWeightManager {
|
|||
virtual ~PackWeightManager();
|
||||
|
||||
void InitWeightManagerByPath(const std::string &model_path, const char *model_buf);
|
||||
void InitWeightManagerByBuf(const char *model_buf);
|
||||
void DeleteSavedModelPtr(LiteModel *delete_model);
|
||||
STATUS StoreLiteModel(const char *model_buf, const Model *model);
|
||||
void *GetTensorData(const LiteModel *model, const SchemaTensorWrapper *origin_tensor, size_t tensor_index);
|
||||
|
@ -56,6 +57,7 @@ class PackWeightManager {
|
|||
void FreePackedWeight(ModelConstWeight *weight);
|
||||
|
||||
std::map<const std::string, ModelConstWeight *> path_model_weight_;
|
||||
std::map<const std::string, ModelConstWeight *> buf_model_weight_;
|
||||
std::map<const std::string, std::vector<const void *>> path_model_buf_;
|
||||
std::mutex mtx_weight_;
|
||||
};
|
||||
|
|
|
@ -164,11 +164,11 @@ int ArithmeticCPUKernel::InitIndexOffsetInfo() {
|
|||
delta = delta % batch_size[j];
|
||||
}
|
||||
if (j < last_batch_axis) {
|
||||
a_offset += (delta / batch_size[j + 1] * a_shape[j] / MSMAX(a_shape[j], b_shape[j])) * a_batch_size[j + 1];
|
||||
b_offset += (delta / batch_size[j + 1] * b_shape[j] / MSMAX(a_shape[j], b_shape[j])) * b_batch_size[j + 1];
|
||||
a_offset += (delta / batch_size[j + 1] * a_shape[j] / c_shape[j]) * a_batch_size[j + 1];
|
||||
b_offset += (delta / batch_size[j + 1] * b_shape[j] / c_shape[j]) * b_batch_size[j + 1];
|
||||
} else {
|
||||
a_offset += (delta * a_shape[j] / MSMAX(a_shape[j], b_shape[j]));
|
||||
b_offset += (delta * b_shape[j] / MSMAX(a_shape[j], b_shape[j]));
|
||||
a_offset += (delta * a_shape[j] / c_shape[j]);
|
||||
b_offset += (delta * b_shape[j] / c_shape[j]);
|
||||
}
|
||||
}
|
||||
a_offset_[i] = a_offset;
|
||||
|
@ -368,22 +368,22 @@ int ArithmeticCPUKernel::CalcArithmeticByBatch(int task_id) {
|
|||
int batch_per_thread = UP_DIV(out_batch_, op_parameter_->thread_num_);
|
||||
int start_batch = batch_per_thread * task_id;
|
||||
int end_batch = MSMIN(start_batch + batch_per_thread, out_batch_);
|
||||
int ret = RET_ERROR;
|
||||
for (int i = start_batch; i < end_batch; i++) {
|
||||
batch_a_ptr_ = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
|
||||
batch_b_ptr_ = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
|
||||
batch_c_ptr_ = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * data_type_len_;
|
||||
int ret = RET_ERROR;
|
||||
auto batch_a_ptr = static_cast<uint8_t *>(input0_ptr_) + a_offset_[i] * a_stride_size_ * data_type_len_;
|
||||
auto batch_b_ptr = static_cast<uint8_t *>(input1_ptr_) + b_offset_[i] * b_stride_size_ * data_type_len_;
|
||||
auto batch_c_ptr = static_cast<uint8_t *>(output_ptr_) + i * c_stride_size_ * data_type_len_;
|
||||
if (batch_scalar_) {
|
||||
ret = DoExecute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, true);
|
||||
ret = DoExecute(batch_a_ptr, batch_b_ptr, batch_c_ptr, c_stride_size_, true);
|
||||
} else {
|
||||
ret = DoExecute(batch_a_ptr_, batch_b_ptr_, batch_c_ptr_, c_stride_size_, false);
|
||||
ret = DoExecute(batch_a_ptr, batch_b_ptr, batch_c_ptr, c_stride_size_, false);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "failed to calculate.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
||||
|
|
|
@ -140,10 +140,9 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
|||
AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes",
|
||||
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
|
||||
#ifdef SERVER_INFERENCE
|
||||
AddFlag(&BenchmarkFlags::model_pool_, "modelPool", "use model pool", false);
|
||||
AddFlag(&BenchmarkFlags::enable_parallel_predict_, "enableParallelPredict", "Enable model parallel : true | false",
|
||||
false);
|
||||
AddFlag(&BenchmarkFlags::num_require_, "numRequire", "require num", 1);
|
||||
AddFlag(&BenchmarkFlags::num_model_, "numModel", "build model num", 1);
|
||||
AddFlag(&BenchmarkFlags::num_split_, "numSplit", "split for batch", 1);
|
||||
#endif
|
||||
#ifdef ENABLE_OPENGL_TEXTURE
|
||||
AddFlag(&BenchmarkFlags::enable_gl_texture_, "enableGLTexture", "Enable GlTexture2D", false);
|
||||
|
@ -159,10 +158,8 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
|
|||
public:
|
||||
// common
|
||||
#ifdef SERVER_INFERENCE
|
||||
bool model_pool_ = false;
|
||||
bool enable_parallel_predict_ = false;
|
||||
int num_require_ = 1;
|
||||
int num_model_ = 1;
|
||||
int num_split_ = 1;
|
||||
#endif
|
||||
std::string model_file_;
|
||||
std::string in_data_file_;
|
||||
|
|
|
@ -42,7 +42,7 @@
|
|||
#include "include/mpi_vb.h"
|
||||
#endif
|
||||
#ifdef SERVER_INFERENCE
|
||||
#include "src/cxx_api/model_pool/model_pool.h"
|
||||
#include <thread>
|
||||
#endif
|
||||
namespace mindspore {
|
||||
constexpr size_t kDataToStringMaxNum = 40;
|
||||
|
@ -220,7 +220,7 @@ int BenchmarkUnifiedApi::LoadInput() {
|
|||
|
||||
int BenchmarkUnifiedApi::GenerateInputData() {
|
||||
#ifdef SERVER_INFERENCE
|
||||
if (flags_->model_pool_) {
|
||||
if (flags_->enable_parallel_predict_) {
|
||||
std::vector<MSTensor> inputs;
|
||||
for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
|
||||
auto tensor_name = ms_inputs_for_api_[i].Name();
|
||||
|
@ -247,6 +247,7 @@ int BenchmarkUnifiedApi::GenerateInputData() {
|
|||
auto new_tensor =
|
||||
mindspore::MSTensor::CreateTensor(tensor_name, ms_inputs_for_api_[i].DataType(), shape, input_data, size);
|
||||
inputs.push_back(*new_tensor);
|
||||
delete new_tensor;
|
||||
}
|
||||
all_inputs_.push_back(inputs);
|
||||
return RET_OK;
|
||||
|
@ -296,7 +297,7 @@ void BenchmarkUnifiedApi::UpdateConfigInfo() {
|
|||
|
||||
int BenchmarkUnifiedApi::ReadInputFile() {
|
||||
#ifdef SERVER_INFERENCE
|
||||
if (flags_->model_pool_) {
|
||||
if (flags_->enable_parallel_predict_) {
|
||||
std::vector<MSTensor> inputs;
|
||||
for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
|
||||
size_t size;
|
||||
|
@ -324,6 +325,7 @@ int BenchmarkUnifiedApi::ReadInputFile() {
|
|||
auto new_tensor =
|
||||
mindspore::MSTensor::CreateTensor(tensor_name, ms_inputs_for_api_[i].DataType(), shape, input_data, size);
|
||||
inputs.push_back(*new_tensor);
|
||||
delete new_tensor;
|
||||
}
|
||||
all_inputs_.push_back(inputs);
|
||||
return RET_OK;
|
||||
|
@ -895,7 +897,7 @@ int BenchmarkUnifiedApi::PrintInputData() {
|
|||
for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) {
|
||||
mindspore::MSTensor input;
|
||||
#ifdef SERVER_INFERENCE
|
||||
if (flags_->model_pool_) {
|
||||
if (flags_->enable_parallel_predict_) {
|
||||
input = all_inputs_[0][i];
|
||||
} else {
|
||||
input = ms_inputs_for_api_[i];
|
||||
|
@ -946,9 +948,14 @@ int BenchmarkUnifiedApi::PrintInputData() {
|
|||
}
|
||||
#ifdef SERVER_INFERENCE
|
||||
int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> context) {
|
||||
if (flags_->resize_dims_.empty()) {
|
||||
MS_LOG(ERROR) << "use parallel predict, inputShapes can not use empty.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// model pool init
|
||||
ModelParallelRunner model_pool;
|
||||
auto runner_config = std::make_shared<RunnerConfig>(context, flags_->num_model_);
|
||||
auto runner_config = std::make_shared<RunnerConfig>();
|
||||
runner_config->context = context;
|
||||
auto model_init_start = GetTimeUs();
|
||||
auto ret = model_pool.Init(flags_->model_file_, runner_config);
|
||||
if (ret != kSuccess) {
|
||||
|
@ -958,6 +965,10 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
|
|||
auto model_init_end = GetTimeUs();
|
||||
// load data
|
||||
ms_inputs_for_api_ = model_pool.GetInputs();
|
||||
if (ms_inputs_for_api_.empty()) {
|
||||
MS_LOG(ERROR) << "model pool input is empty.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (int i = 0; i < flags_->num_require_ + flags_->warm_up_loop_count_; i++) {
|
||||
auto status = LoadInput();
|
||||
if (status != RET_OK) {
|
||||
|
@ -989,7 +1000,7 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
|
|||
MS_LOG(ERROR) << "model pool predict failed.";
|
||||
}
|
||||
auto predict_end = GetTimeUs();
|
||||
MS_LOG(ERROR) << "run predict time: " << (predict_end - predict_start) / kFloatMSEC << " ms";
|
||||
std::cout << "run predict time: " << (predict_end - predict_start) / kFloatMSEC << " ms\n";
|
||||
if (!flags_->benchmark_data_file_.empty()) {
|
||||
auto status = CompareOutputForModelPool(&output);
|
||||
if (status != RET_OK) {
|
||||
|
@ -1004,18 +1015,22 @@ int BenchmarkUnifiedApi::RunModelPool(std::shared_ptr<mindspore::Context> contex
|
|||
for (auto &warm_up_thread : model_thread_warm_up) {
|
||||
warm_up_thread.join();
|
||||
}
|
||||
MS_LOG(DEBUG) << "================ end warm up ================";
|
||||
std::cout << "================ end warm up ================";
|
||||
auto all_start = GetTimeUs();
|
||||
std::vector<std::thread> model_thread_run;
|
||||
for (int i = 0; i < flags_->num_require_; i++) {
|
||||
model_thread_run.push_back(std::thread(model_pool_run, i + flags_->warm_up_loop_count_));
|
||||
}
|
||||
for (auto &run_thread : model_thread_run) {
|
||||
run_thread.join();
|
||||
for (int loop_count_num = 0; loop_count_num < flags_->loop_count_; loop_count_num++) {
|
||||
std::vector<std::thread> model_thread_run;
|
||||
for (int i = 0; i < flags_->num_require_; i++) {
|
||||
model_thread_run.push_back(std::thread(model_pool_run, i + flags_->warm_up_loop_count_));
|
||||
}
|
||||
for (auto &run_thread : model_thread_run) {
|
||||
run_thread.join();
|
||||
}
|
||||
}
|
||||
auto all_end = GetTimeUs();
|
||||
std::cout << "=================================" << std::endl;
|
||||
std::cout << "model pool init time: " << (model_init_end - model_init_start) / kFloatMSEC << " ms\n";
|
||||
std::cout << "model pool all run time: " << (all_end - all_start) / kFloatMSEC << " ms\n";
|
||||
std::cout << "model pool all run time: " << (all_end - all_start) / kFloatMSEC / flags_->loop_count_ << " ms\n";
|
||||
std::cout << "=================================" << std::endl;
|
||||
return RET_OK;
|
||||
}
|
||||
#endif
|
||||
|
@ -1064,7 +1079,7 @@ int BenchmarkUnifiedApi::RunBenchmark() {
|
|||
|
||||
UpdateConfigInfo();
|
||||
#ifdef SERVER_INFERENCE
|
||||
if (flags_->model_pool_) {
|
||||
if (flags_->enable_parallel_predict_) {
|
||||
status = RunModelPool(context);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "run model pool failed.";
|
||||
|
|
|
@ -43,7 +43,7 @@
|
|||
#include "tools/common/opengl_util.h"
|
||||
#endif
|
||||
#ifdef SERVER_INFERENCE
|
||||
#include "src/cxx_api/model_pool/model_parallel_runner.h"
|
||||
#include "include/api/model_parallel_runner.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
|
Loading…
Reference in New Issue