parallel predict: remove dec_key/dec_mode params from Init

yefeng 2022-03-02 09:43:18 +08:00
parent 5b1772d2b3
commit 4b06963e1f
6 changed files with 9 additions and 16 deletions

View File

@@ -38,12 +38,9 @@ class MS_API ModelParallelRunner {
   ///
   /// \param[in] model_path Define the model path.
   /// \param[in] runner_config Define the config used to store options during model pool init.
-  /// \param[in] dec_key Define the key used to decrypt the ciphertext model. The key length is 16, 24, or 32.
-  /// \param[in] dec_mode Define the decryption mode. Options: AES-GCM, AES-CBC.
   ///
   /// \return Status.
-  Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
-              const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
+  Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr);
   /// \brief Obtains all input tensors information of the model.
   ///
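After this change the runner is initialized with only the model path and an optional RunnerConfig. A minimal caller sketch against the new signature (the header path, the default-constructed RunnerConfig, and the helper name are assumptions for illustration, not part of this commit):

#include <memory>
#include <string>
#include "include/api/model_parallel_runner.h"  // assumed header location

// Hypothetical helper showing the new two-parameter Init.
mindspore::Status InitRunner(mindspore::ModelParallelRunner *runner, const std::string &model_path) {
  // Assumption: a default RunnerConfig is acceptable here;
  // dec_key/dec_mode can no longer be passed through this call.
  auto config = std::make_shared<mindspore::RunnerConfig>();
  return runner->Init(model_path, config);
}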

View File

@@ -18,9 +18,8 @@
 #include "src/common/log.h"
 namespace mindspore {
-Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config,
-                                 const Key &dec_key, const std::string &dec_mode) {
-  auto status = ModelPool::GetInstance()->Init(model_path, runner_config, dec_key, dec_mode);
+Status ModelParallelRunner::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
+  auto status = ModelPool::GetInstance()->Init(model_path, runner_config);
   if (status != kSuccess) {
     MS_LOG(ERROR) << "model runner init failed.";
     return kLiteError;

View File

@@ -271,8 +271,7 @@ std::vector<MSTensor> ModelPool::GetOutputs() {
   return model_outputs_;
 }
-Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config,
-                       const Key &dec_key, const std::string &dec_mode) {
+Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config) {
   auto model_pool_context = CreateModelContext(runner_config);
   if (model_pool_context.empty()) {
     MS_LOG(ERROR) << "CreateModelContext failed, context is empty.";
@@ -309,7 +308,7 @@ Status ModelPool::Init(const std::string &model_path, const std::shared_ptr<Runn
       numa_node_id = 0;
     }
     model_thread = std::make_shared<ModelThread>();
-    auto status = model_thread->Init(graph_buf_, size, model_pool_context[i], dec_key, dec_mode, numa_node_id);
+    auto status = model_thread->Init(graph_buf_, size, model_pool_context[i], numa_node_id);
     if (status != kSuccess) {
       MS_LOG(ERROR) << " model thread init failed.";
       return kLiteError;

View File

@@ -34,8 +34,7 @@ class ModelPool {
   static ModelPool *GetInstance();
   ~ModelPool();
-  Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr,
-              const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
+  Status Init(const std::string &model_path, const std::shared_ptr<RunnerConfig> &runner_config = nullptr);
   std::vector<MSTensor> GetInputs();

View File

@@ -60,13 +60,13 @@ void ModelThread::Run(int node_id) {
 }
 Status ModelThread::Init(const char *model_buf, size_t size, const std::shared_ptr<Context> &model_context,
-                         const Key &dec_key, const std::string &dec_mode, int node_id) {
+                         int node_id) {
   model_ = std::make_shared<Model>();
   mindspore::ModelType model_type = kMindIR;
   if (node_id != -1) {
     model_->UpdateConfig(lite::kConfigServerInference, {lite::kConfigNUMANodeId, std::to_string(node_id)});
   }
-  auto status = model_->Build(model_buf, size, model_type, model_context, dec_key, dec_mode);
+  auto status = model_->Build(model_buf, size, model_type, model_context);
   if (status != kSuccess) {
     MS_LOG(ERROR) << "model build failed in ModelPool Init";
     return status;
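The worker thread now calls the plain four-argument Model::Build, so decryption no longer happens inside the pool. Callers that still need an encrypted model can invoke the decrypting Build overload that the removed line above used, with the old default mode; a hedged sketch (helper name and header path are assumptions):

#include <memory>
#include "include/api/model.h"  // assumed header location

// Hypothetical helper: decrypt-and-build a single model outside the pool.
mindspore::Status BuildEncrypted(const char *model_buf, size_t size,
                                 const std::shared_ptr<mindspore::Context> &context,
                                 const mindspore::Key &dec_key) {
  auto model = std::make_shared<mindspore::Model>();
  // Same decrypting Build overload the removed line called;
  // kDecModeAesGcm was the old default dec_mode.
  return model->Build(model_buf, size, mindspore::kMindIR, context, dec_key, mindspore::kDecModeAesGcm);
}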

View File

@@ -33,8 +33,7 @@ class ModelThread {
   ~ModelThread() = default;
   // the model pool is initialized once and can always accept model run requests
-  Status Init(const char *model_buf, size_t size, const std::shared_ptr<Context> &model_context,
-              const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm, int node_id = -1);
+  Status Init(const char *model_buf, size_t size, const std::shared_ptr<Context> &model_context, int node_id = -1);
   std::vector<MSTensor> GetInputs();