!27889 [MS][LITE]support load cache data from device model

Merge pull request !27889 from 张学同/distribution_cache_dev
This commit is contained in:
i-robot 2021-12-20 07:03:01 +00:00 committed by Gitee
commit 39c4d6aa42
10 changed files with 180 additions and 45 deletions

View File

@ -55,6 +55,7 @@ static const int32_t DIM_DEFAULT_SIZE = 4;
static const char *const kMSCache = "ms_cache"; static const char *const kMSCache = "ms_cache";
static const char *const kMSCacheModelPath = "cache_model_path"; static const char *const kMSCacheModelPath = "cache_model_path";
static const char *const kMSCacheVocabSize = "vocab_size"; static const char *const kMSCacheVocabSize = "vocab_size";
static const char *const kMSCacheDeviceSize = "device_cache_size";
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -29,11 +29,13 @@ namespace mindspore {
namespace cache { namespace cache {
class EmbeddingCache { class EmbeddingCache {
public: public:
EmbeddingCache(size_t vocab_size, size_t host_cache_size, size_t device_cache_size, size_t embedding_size, EmbeddingCache(size_t vocab_size, size_t host_cache_size, size_t device_cache_size, size_t device_start_index,
size_t batch_elements, DataType data_type, void *host_addr, int rank_id, int rank_group_size) size_t embedding_size, size_t batch_elements, DataType data_type, void *host_addr, int rank_id,
int rank_group_size)
: vocab_size_(vocab_size), : vocab_size_(vocab_size),
host_cache_size_(host_cache_size), host_cache_size_(host_cache_size),
device_cache_size_(device_cache_size), device_cache_size_(device_cache_size),
device_start_index_(device_start_index),
embedding_size_(embedding_size), embedding_size_(embedding_size),
batch_elements_(batch_elements), batch_elements_(batch_elements),
data_type_(data_type), data_type_(data_type),
@ -51,7 +53,6 @@ class EmbeddingCache {
default: default:
sizeof_data_type_ = sizeof(float); sizeof_data_type_ = sizeof(float);
} }
device_start_index_ = device_cache_size_ * rank_id_;
MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_ MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_
<< ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_ << ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_
<< ", device_cache_size_:" << device_cache_size_ << ", embedding_size_:" << embedding_size_ << ", device_cache_size_:" << device_cache_size_ << ", embedding_size_:" << embedding_size_

View File

@ -22,9 +22,9 @@
namespace mindspore { namespace mindspore {
namespace cache { namespace cache {
Status EmbeddingCacheManager::Init(const std::string &cache_model_path, size_t vocab_size) { Status EmbeddingCacheManager::Init(const std::string &cache_model_path, size_t vocab_size, size_t device_cache_size) {
if (cache_model_path.empty()) { if (cache_model_path.empty() || vocab_size == 0 || device_cache_size >= vocab_size) {
MS_LOG(INFO) << "no cache model "; MS_LOG(INFO) << "no cache model , vocab_size " << vocab_size << ", device_cache_size " << device_cache_size;
return kSuccess; return kSuccess;
} }
@ -34,8 +34,39 @@ Status EmbeddingCacheManager::Init(const std::string &cache_model_path, size_t v
return kLiteMemoryFailed; return kLiteMemoryFailed;
} }
auto ret = host_cache_model_->LoadCache(cache_model_path); auto ret = host_cache_model_->LoadCache(cache_model_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "load cache failed";
return ret;
}
vocab_size_ = vocab_size; vocab_size_ = vocab_size;
MS_LOG(INFO) << "cache manager init end, ret " << ret.ToString(); device_cache_size_ = device_cache_size;
MS_LOG(INFO) << "cache manager init succ, cache model" << cache_model_path << " , vocab_size " << vocab_size
<< ", device_cache_size " << device_cache_size;
return ret;
}
Status EmbeddingCacheManager::Init(DelegateModel<schema::Primitive> *model, size_t vocab_size,
size_t device_cache_size) {
if (model == nullptr || vocab_size == 0 || device_cache_size >= vocab_size) {
MS_LOG(INFO) << "no cache model , vocab_size " << vocab_size << ", device_cache_size " << device_cache_size;
return kSuccess;
}
host_cache_model_ = std::make_shared<HostCacheModel>();
if (host_cache_model_ == nullptr) {
MS_LOG(ERROR) << "HostCacheModel malloc failed";
return kLiteMemoryFailed;
}
auto ret = host_cache_model_->LoadCache(model);
if (ret != kSuccess) {
MS_LOG(ERROR) << "load cache failed";
return ret;
}
vocab_size_ = vocab_size;
device_cache_size_ = device_cache_size;
MS_LOG(INFO) << "cache manager init succ, vocab_size " << vocab_size << ", device_cache_size " << device_cache_size;
return ret; return ret;
} }
@ -70,14 +101,17 @@ Status EmbeddingCacheManager::InitCacheKernel(kernel::Kernel *kernel, uint32_t d
return kLiteError; return kLiteError;
} }
size_t vocab_size = vocab_size_; size_t vocab_size = vocab_size_;
size_t host_cache_size = host_cache_tensor.ElementNum(); size_t host_cache_size = host_cache_tensor.Shape()[0];
size_t device_cache_size = tensor.Shape()[0];
size_t embedding_size_ = tensor.Shape()[1]; size_t embedding_size_ = tensor.Shape()[1];
DataType data_type = tensor.DataType(); DataType data_type = tensor.DataType();
size_t batch_elements = kernel->inputs()[1].ElementNum(); size_t batch_elements = kernel->inputs()[1].ElementNum();
auto cache = size_t device_start_index = device_cache_size_ * rank_id_;
std::make_shared<EmbeddingCache>(vocab_size, host_cache_size, device_cache_size, embedding_size_, batch_elements, if (tensor.Shape()[0] == host_cache_tensor.Shape()[0]) {
data_type, host_cache_tensor.MutableData(), rank_id_, rank_group_size_); device_start_index = rank_id_ * static_cast<int>(std::ceil(static_cast<float>(vocab_size_) / rank_group_size_));
}
auto cache = std::make_shared<EmbeddingCache>(vocab_size, host_cache_size, device_cache_size_, device_start_index,
embedding_size_, batch_elements, data_type,
host_cache_tensor.MutableData(), rank_id_, rank_group_size_);
if (cache == nullptr) { if (cache == nullptr) {
MS_LOG(ERROR) << kernel->name() << ": malloc EmbeddingCache failed"; MS_LOG(ERROR) << kernel->name() << ": malloc EmbeddingCache failed";
return kLiteError; return kLiteError;
@ -107,6 +141,23 @@ bool EmbeddingCacheManager::IsCacheTensor(mindspore::MSTensor tensor) {
return false; return false;
} }
std::vector<int64_t> EmbeddingCacheManager::GetCacheShape(mindspore::MSTensor tensor) {
std::vector<int64_t> shape = tensor.Shape();
if (shape.size() > 0 && IsCacheTensor(tensor)) {
shape[0] = device_cache_size_;
}
return shape;
}
size_t EmbeddingCacheManager::GetCacheDataSize(mindspore::MSTensor tensor) {
auto data_size = tensor.DataSize();
auto &shape = tensor.Shape();
if (shape.size() > 0 && IsCacheTensor(tensor) && shape[0] > 0) {
data_size = data_size * device_cache_size_ / shape[0];
}
return data_size;
}
Status EmbeddingCacheManager::SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size) { Status EmbeddingCacheManager::SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size) {
auto cache_iter = caches_.find(tensor_name); auto cache_iter = caches_.find(tensor_name);
if (cache_iter == caches_.end() || cache_iter->second == nullptr) { if (cache_iter == caches_.end() || cache_iter->second == nullptr) {

View File

@ -35,12 +35,15 @@ class EmbeddingCacheManager {
rank_id_ = lite::GetRankID(); rank_id_ = lite::GetRankID();
rank_group_size_ = lite::GetGPUGroupSize(); rank_group_size_ = lite::GetGPUGroupSize();
} }
Status Init(const std::string &cache_model_path, size_t vocab_size); Status Init(const std::string &cache_model_path, size_t vocab_size, size_t device_cache_size);
Status Init(DelegateModel<schema::Primitive> *model, size_t vocab_size, size_t device_cache_size);
bool CheckIsCacheKernel(kernel::Kernel *kernel); bool CheckIsCacheKernel(kernel::Kernel *kernel);
Status InitCacheKernel(kernel::Kernel *kernel, uint32_t device_id, const void *context); Status InitCacheKernel(kernel::Kernel *kernel, uint32_t device_id, const void *context);
bool IsCacheTensor(mindspore::MSTensor tensor); bool IsCacheTensor(mindspore::MSTensor tensor);
int CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, void *device_addr); int CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, void *device_addr);
Status SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size); Status SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size);
std::vector<int64_t> GetCacheShape(mindspore::MSTensor tensor);
size_t GetCacheDataSize(mindspore::MSTensor tensor);
private: private:
std::map<std::string, std::shared_ptr<EmbeddingCache>> caches_; std::map<std::string, std::shared_ptr<EmbeddingCache>> caches_;
@ -50,6 +53,7 @@ class EmbeddingCacheManager {
std::shared_ptr<HostCacheModel> host_cache_model_; std::shared_ptr<HostCacheModel> host_cache_model_;
size_t vocab_size_; size_t vocab_size_;
size_t device_cache_size_;
}; };
} // namespace cache } // namespace cache
} // namespace mindspore } // namespace mindspore

View File

@ -20,6 +20,7 @@
#include <vector> #include <vector>
#include "src/delegate/parameter_cache/load_host_cache_model.h" #include "src/delegate/parameter_cache/load_host_cache_model.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "src/common/common.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/common/file_utils.h" #include "src/common/file_utils.h"
@ -79,6 +80,45 @@ Status HostCacheModel::LoadCache(const std::string &model_path) {
return kSuccess; return kSuccess;
} }
size_t GetVocabSize(kernel::Kernel *kernel) {
size_t vocab_size = 0;
auto cache_config = kernel->GetConfig(lite::kMSCache);
auto vocab_size_iter = cache_config.find(lite::kMSCacheVocabSize);
if (vocab_size_iter == cache_config.end()) {
return vocab_size;
}
auto vocab_size_opt = lite::GenericParseValue<size_t>(vocab_size_iter->second);
if (!vocab_size_opt.IsNone()) {
vocab_size = vocab_size_opt.Get();
}
return vocab_size;
}
Status HostCacheModel::LoadCache(DelegateModel<schema::Primitive> *model) {
KernelIter from, end;
for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) {
kernel::Kernel *kernel = *iter;
// only support embedding cache
if (kernel->type() != schema::PrimitiveType_Gather) {
continue;
}
MS_ASSERT(kernel->inputs.size() == 3);
auto tensor = kernel->inputs()[0];
if (tensor.Data() == nullptr) {
continue;
}
size_t vocab_size = GetVocabSize(kernel);
if (vocab_size == 0) {
continue;
}
cache_tensor_[tensor.Name()] = tensor;
}
return mindspore::kSuccess;
}
bool HostCacheModel::CheckIsCacheKernel(kernel::Kernel *kernel) { bool HostCacheModel::CheckIsCacheKernel(kernel::Kernel *kernel) {
if (GetHostCacheTensor(kernel) == nullptr) { if (GetHostCacheTensor(kernel) == nullptr) {
return false; return false;

View File

@ -23,6 +23,7 @@
#include "include/api/data_type.h" #include "include/api/data_type.h"
#include "include/api/types.h" #include "include/api/types.h"
#include "include/api/kernel.h" #include "include/api/kernel.h"
#include "include/api/delegate.h"
#include "src/lite_model.h" #include "src/lite_model.h"
namespace mindspore { namespace mindspore {
@ -31,6 +32,7 @@ class HostCacheModel {
public: public:
HostCacheModel() {} HostCacheModel() {}
Status LoadCache(const std::string &model_path); Status LoadCache(const std::string &model_path);
Status LoadCache(DelegateModel<schema::Primitive> *model);
bool CheckIsCacheKernel(kernel::Kernel *kernel); bool CheckIsCacheKernel(kernel::Kernel *kernel);
MSTensor GetHostCacheTensor(kernel::Kernel *kernel); MSTensor GetHostCacheTensor(kernel::Kernel *kernel);

View File

@ -146,7 +146,7 @@ Status TensorRTDelegate::Init() {
MS_LOG(ERROR) << "malloc EmbeddingCacheManager failed."; MS_LOG(ERROR) << "malloc EmbeddingCacheManager failed.";
return kLiteMemoryFailed; return kLiteMemoryFailed;
} }
auto cache_ret = cache_mgr_->Init(cache_model_path_, vocab_size_); auto cache_ret = cache_mgr_->Init(cache_model_path_, vocab_size_, device_cache_size_);
if (cache_ret != mindspore::kSuccess) { if (cache_ret != mindspore::kSuccess) {
MS_LOG(ERROR) << "cache_mgr_ init failed."; MS_LOG(ERROR) << "cache_mgr_ init failed.";
return cache_ret; return cache_ret;
@ -155,40 +155,37 @@ Status TensorRTDelegate::Init() {
return mindspore::kSuccess; return mindspore::kSuccess;
} }
Status TensorRTDelegate::Build(DelegateModel<schema::Primitive> *model) { void TensorRTDelegate::CheckSupportResize(schema::PrimitiveType type) {
int ret = lite::SetCudaDevice(device_info_); if (support_resize_) {
if (ret != RET_OK) { auto opt_iter = unsupport_resize_op_list_.find(type);
return mindspore::kLiteError; if (opt_iter != unsupport_resize_op_list_.end()) {
support_resize_ = false;
support_hw_resize_ = false;
MS_LOG(INFO) << "network has op don't support resize, op: " << type;
return;
}
} }
if (support_hw_resize_) {
auto opt_iter = unsupport_hw_op_lists_.find(type);
if (opt_iter != unsupport_hw_op_lists_.end()) {
support_hw_resize_ = false;
MS_LOG(INFO) << "network has op don't support hw resize, op: " << type;
}
}
}
Status TensorRTDelegate::BuildSubGraph(DelegateModel<schema::Primitive> *model) {
KernelIter from, end; KernelIter from, end;
std::vector<TensorRTOp *> tensorrt_ops; std::vector<TensorRTOp *> tensorrt_ops;
for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) { for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) {
kernel::Kernel *kernel = *iter; kernel::Kernel *kernel = *iter;
if (support_resize_) { CheckSupportResize(model->GetPrimitive(kernel)->value_type());
for (auto no_type : unsupport_resize_op_list_) {
if (model->GetPrimitive(kernel)->value_type() == no_type) {
support_resize_ = false;
support_hw_resize_ = false;
MS_LOG(INFO) << "network has op don't support resize.";
continue;
}
}
}
if (support_hw_resize_) {
for (auto no_type : unsupport_hw_op_lists_) {
if (model->GetPrimitive(kernel)->value_type() == no_type) {
support_hw_resize_ = false;
MS_LOG(INFO) << "network has op don't support hw resize.";
continue;
}
}
}
auto tensorrt_op = FindTensorRTOp(kernel, model->GetPrimitive(kernel)); auto tensorrt_op = FindTensorRTOp(kernel, model->GetPrimitive(kernel));
if (tensorrt_op != nullptr) { if (tensorrt_op != nullptr) {
if (cache_mgr_->CheckIsCacheKernel(kernel)) { if (cache_mgr_->CheckIsCacheKernel(kernel)) {
auto cache_ret = cache_mgr_->InitCacheKernel(kernel, device_info_->GetDeviceID(), &stream_); auto cache_ret = cache_mgr_->InitCacheKernel(kernel, device_info_->GetDeviceID(), &stream_);
if (cache_ret != kSuccess) { if (cache_ret != kSuccess) {
MS_LOG(INFO) << "InitCacheKernel failed " << kernel->name(); MS_LOG(ERROR) << "InitCacheKernel failed " << kernel->name();
return cache_ret; return cache_ret;
} }
} }
@ -215,7 +212,7 @@ Status TensorRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
if (tensorrt_ops.size() > 0) { if (tensorrt_ops.size() > 0) {
auto tensorrt_subgraph = CreateTensorRTGraph(tensorrt_ops, model, from, end); auto tensorrt_subgraph = CreateTensorRTGraph(tensorrt_ops, model, from, end);
if (tensorrt_subgraph == nullptr) { if (tensorrt_subgraph == nullptr) {
MS_LOG(DEBUG) << "Create TensorRT Graph failed."; MS_LOG(ERROR) << "Create TensorRT Graph failed.";
return mindspore::kLiteNullptr; return mindspore::kLiteNullptr;
} }
model->Replace(from, end + 1, tensorrt_subgraph); model->Replace(from, end + 1, tensorrt_subgraph);
@ -224,6 +221,28 @@ Status TensorRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
return mindspore::kSuccess; return mindspore::kSuccess;
} }
Status TensorRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
int ret = lite::SetCudaDevice(device_info_);
if (ret != RET_OK) {
return mindspore::kLiteError;
}
if (cache_model_path_.empty() && vocab_size_ > 0) {
auto cache_ret = cache_mgr_->Init(model, vocab_size_, device_cache_size_);
if (cache_ret != mindspore::kSuccess) {
MS_LOG(ERROR) << "cache_mgr_ init failed.";
return cache_ret;
}
}
auto build_ret = BuildSubGraph(model);
if (build_ret != kSuccess) {
MS_LOG(INFO) << "BuildSubGraph failed";
return build_ret;
}
return mindspore::kSuccess;
}
TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive) { TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive) {
auto in_tensors = kernel->inputs(); auto in_tensors = kernel->inputs();
auto out_tensors = kernel->outputs(); auto out_tensors = kernel->outputs();

View File

@ -35,8 +35,12 @@ typedef TensorRTOp *(*TensorRTGetOp)(const schema::Primitive *primitive,
class TensorRTDelegate : public Delegate { class TensorRTDelegate : public Delegate {
public: public:
explicit TensorRTDelegate(mindspore::Context *context, const std::string &cache_model_path, size_t vocab_size) explicit TensorRTDelegate(mindspore::Context *context, const std::string &cache_model_path, size_t vocab_size,
: context_(context), cache_model_path_(cache_model_path), vocab_size_(vocab_size) {} size_t device_cache_size)
: context_(context),
cache_model_path_(cache_model_path),
vocab_size_(vocab_size),
device_cache_size_(device_cache_size) {}
~TensorRTDelegate() override; ~TensorRTDelegate() override;
@ -45,6 +49,8 @@ class TensorRTDelegate : public Delegate {
Status Build(DelegateModel<schema::Primitive> *model) override; Status Build(DelegateModel<schema::Primitive> *model) override;
private: private:
Status BuildSubGraph(DelegateModel<schema::Primitive> *model);
void CheckSupportResize(schema::PrimitiveType type);
TensorRTOp *FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive); TensorRTOp *FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive);
TensorRTSubGraph *CreateTensorRTGraph(const std::vector<TensorRTOp *> &ops, DelegateModel<schema::Primitive> *model, TensorRTSubGraph *CreateTensorRTGraph(const std::vector<TensorRTOp *> &ops, DelegateModel<schema::Primitive> *model,
@ -67,6 +73,7 @@ class TensorRTDelegate : public Delegate {
bool support_resize_{true}; bool support_resize_{true};
const std::string cache_model_path_; const std::string cache_model_path_;
size_t vocab_size_; size_t vocab_size_;
size_t device_cache_size_;
std::shared_ptr<cache::EmbeddingCacheManager> cache_mgr_{nullptr}; std::shared_ptr<cache::EmbeddingCacheManager> cache_mgr_{nullptr};
cudaStream_t stream_; cudaStream_t stream_;

View File

@ -359,8 +359,8 @@ int TensorRTSubGraph::Prepare() {
// malloc for cache weight tensor // malloc for cache weight tensor
for (auto cache_tensor : cache_const_inputs_) { for (auto cache_tensor : cache_const_inputs_) {
auto data_size = cache_tensor.DataSize(); size_t data_size = cache_mgr_->GetCacheDataSize(cache_tensor);
auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(cache_tensor, cache_tensor.DataSize()); auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(cache_tensor, data_size);
runtime_->GetAllocator()->SyncMemInHostAndDevice(cache_tensor, cache_tensor.Name().c_str(), true); runtime_->GetAllocator()->SyncMemInHostAndDevice(cache_tensor, cache_tensor.Name().c_str(), true);
runtime_->GetAllocator()->MarkMemValid(cache_tensor.Name().c_str(), true); runtime_->GetAllocator()->MarkMemValid(cache_tensor.Name().c_str(), true);
int index = this->engine_->getBindingIndex(cache_tensor.Name().c_str()); int index = this->engine_->getBindingIndex(cache_tensor.Name().c_str());
@ -596,9 +596,10 @@ int TensorRTSubGraph::HandleCacheTensor(TensorRTOp *cur_op, const mindspore::MST
FindCacheTensorInfo(cur_op, in_tensor); FindCacheTensorInfo(cur_op, in_tensor);
// cache kernel weight tensor // cache kernel weight tensor
cache_const_inputs_.push_back(in_tensor); cache_const_inputs_.push_back(in_tensor);
auto shape = cache_mgr_->GetCacheShape(in_tensor);
MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name(); MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name();
auto cuda_dtype = ConvertDataType(in_tensor.DataType()); auto cuda_dtype = ConvertDataType(in_tensor.DataType());
nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape()); nvinfer1::Dims input_dims = ConvertCudaDims(shape);
nvinfer1::ITensor *cache_input = network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims); nvinfer1::ITensor *cache_input = network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims);
if (cache_input == nullptr) { if (cache_input == nullptr) {
MS_LOG(ERROR) << "add cache Weight Tensor data is nullptr."; MS_LOG(ERROR) << "add cache Weight Tensor data is nullptr.";

View File

@ -742,6 +742,7 @@ int LiteSession::CreateTensorRTDelegate() {
#if GPU_TENSORRT #if GPU_TENSORRT
std::string cache_model_path; std::string cache_model_path;
size_t vocab_size = 0; size_t vocab_size = 0;
size_t device_cache_size = 0;
if (config_info_ != nullptr) { if (config_info_ != nullptr) {
auto ms_cache_iter = config_info_->find(kMSCache); auto ms_cache_iter = config_info_->find(kMSCache);
if (ms_cache_iter != config_info_->end()) { if (ms_cache_iter != config_info_->end()) {
@ -758,10 +759,18 @@ int LiteSession::CreateTensorRTDelegate() {
vocab_size = vocab_size_opt.Get(); vocab_size = vocab_size_opt.Get();
} }
} }
auto device_cache_size_iter = ms_cache.find(kMSCacheDeviceSize);
if (device_cache_size_iter != ms_cache.end()) {
auto device_cache_size_opt = GenericParseValue<size_t>(device_cache_size_iter->second);
if (!device_cache_size_opt.IsNone()) {
device_cache_size = device_cache_size_opt.Get();
}
}
} }
} }
delegate_ = std::make_shared<TensorRTDelegate>(ms_context_, cache_model_path, vocab_size); delegate_ = std::make_shared<TensorRTDelegate>(ms_context_, cache_model_path, vocab_size, device_cache_size);
if (delegate_ == nullptr) { if (delegate_ == nullptr) {
MS_LOG(ERROR) << "New tensorrt delegate_ failed"; MS_LOG(ERROR) << "New tensorrt delegate_ failed";
return RET_ERROR; return RET_ERROR;