forked from mindspore-Ecosystem/mindspore
!27889 [MS][LITE]support load cache data from device model
Merge pull request !27889 from 张学同/distribution_cache_dev
This commit is contained in:
commit
39c4d6aa42
|
@ -55,6 +55,7 @@ static const int32_t DIM_DEFAULT_SIZE = 4;
|
||||||
static const char *const kMSCache = "ms_cache";
|
static const char *const kMSCache = "ms_cache";
|
||||||
static const char *const kMSCacheModelPath = "cache_model_path";
|
static const char *const kMSCacheModelPath = "cache_model_path";
|
||||||
static const char *const kMSCacheVocabSize = "vocab_size";
|
static const char *const kMSCacheVocabSize = "vocab_size";
|
||||||
|
static const char *const kMSCacheDeviceSize = "device_cache_size";
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -29,11 +29,13 @@ namespace mindspore {
|
||||||
namespace cache {
|
namespace cache {
|
||||||
class EmbeddingCache {
|
class EmbeddingCache {
|
||||||
public:
|
public:
|
||||||
EmbeddingCache(size_t vocab_size, size_t host_cache_size, size_t device_cache_size, size_t embedding_size,
|
EmbeddingCache(size_t vocab_size, size_t host_cache_size, size_t device_cache_size, size_t device_start_index,
|
||||||
size_t batch_elements, DataType data_type, void *host_addr, int rank_id, int rank_group_size)
|
size_t embedding_size, size_t batch_elements, DataType data_type, void *host_addr, int rank_id,
|
||||||
|
int rank_group_size)
|
||||||
: vocab_size_(vocab_size),
|
: vocab_size_(vocab_size),
|
||||||
host_cache_size_(host_cache_size),
|
host_cache_size_(host_cache_size),
|
||||||
device_cache_size_(device_cache_size),
|
device_cache_size_(device_cache_size),
|
||||||
|
device_start_index_(device_start_index),
|
||||||
embedding_size_(embedding_size),
|
embedding_size_(embedding_size),
|
||||||
batch_elements_(batch_elements),
|
batch_elements_(batch_elements),
|
||||||
data_type_(data_type),
|
data_type_(data_type),
|
||||||
|
@ -51,7 +53,6 @@ class EmbeddingCache {
|
||||||
default:
|
default:
|
||||||
sizeof_data_type_ = sizeof(float);
|
sizeof_data_type_ = sizeof(float);
|
||||||
}
|
}
|
||||||
device_start_index_ = device_cache_size_ * rank_id_;
|
|
||||||
MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_
|
MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_
|
||||||
<< ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_
|
<< ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_
|
||||||
<< ", device_cache_size_:" << device_cache_size_ << ", embedding_size_:" << embedding_size_
|
<< ", device_cache_size_:" << device_cache_size_ << ", embedding_size_:" << embedding_size_
|
||||||
|
|
|
@ -22,9 +22,9 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace cache {
|
namespace cache {
|
||||||
Status EmbeddingCacheManager::Init(const std::string &cache_model_path, size_t vocab_size) {
|
Status EmbeddingCacheManager::Init(const std::string &cache_model_path, size_t vocab_size, size_t device_cache_size) {
|
||||||
if (cache_model_path.empty()) {
|
if (cache_model_path.empty() || vocab_size == 0 || device_cache_size >= vocab_size) {
|
||||||
MS_LOG(INFO) << "no cache model ";
|
MS_LOG(INFO) << "no cache model , vocab_size " << vocab_size << ", device_cache_size " << device_cache_size;
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,8 +34,39 @@ Status EmbeddingCacheManager::Init(const std::string &cache_model_path, size_t v
|
||||||
return kLiteMemoryFailed;
|
return kLiteMemoryFailed;
|
||||||
}
|
}
|
||||||
auto ret = host_cache_model_->LoadCache(cache_model_path);
|
auto ret = host_cache_model_->LoadCache(cache_model_path);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "load cache failed";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
vocab_size_ = vocab_size;
|
vocab_size_ = vocab_size;
|
||||||
MS_LOG(INFO) << "cache manager init end, ret " << ret.ToString();
|
device_cache_size_ = device_cache_size;
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "cache manager init succ, cache model" << cache_model_path << " , vocab_size " << vocab_size
|
||||||
|
<< ", device_cache_size " << device_cache_size;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status EmbeddingCacheManager::Init(DelegateModel<schema::Primitive> *model, size_t vocab_size,
|
||||||
|
size_t device_cache_size) {
|
||||||
|
if (model == nullptr || vocab_size == 0 || device_cache_size >= vocab_size) {
|
||||||
|
MS_LOG(INFO) << "no cache model , vocab_size " << vocab_size << ", device_cache_size " << device_cache_size;
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
host_cache_model_ = std::make_shared<HostCacheModel>();
|
||||||
|
if (host_cache_model_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "HostCacheModel malloc failed";
|
||||||
|
return kLiteMemoryFailed;
|
||||||
|
}
|
||||||
|
auto ret = host_cache_model_->LoadCache(model);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "load cache failed";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
vocab_size_ = vocab_size;
|
||||||
|
device_cache_size_ = device_cache_size;
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "cache manager init succ, vocab_size " << vocab_size << ", device_cache_size " << device_cache_size;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,14 +101,17 @@ Status EmbeddingCacheManager::InitCacheKernel(kernel::Kernel *kernel, uint32_t d
|
||||||
return kLiteError;
|
return kLiteError;
|
||||||
}
|
}
|
||||||
size_t vocab_size = vocab_size_;
|
size_t vocab_size = vocab_size_;
|
||||||
size_t host_cache_size = host_cache_tensor.ElementNum();
|
size_t host_cache_size = host_cache_tensor.Shape()[0];
|
||||||
size_t device_cache_size = tensor.Shape()[0];
|
|
||||||
size_t embedding_size_ = tensor.Shape()[1];
|
size_t embedding_size_ = tensor.Shape()[1];
|
||||||
DataType data_type = tensor.DataType();
|
DataType data_type = tensor.DataType();
|
||||||
size_t batch_elements = kernel->inputs()[1].ElementNum();
|
size_t batch_elements = kernel->inputs()[1].ElementNum();
|
||||||
auto cache =
|
size_t device_start_index = device_cache_size_ * rank_id_;
|
||||||
std::make_shared<EmbeddingCache>(vocab_size, host_cache_size, device_cache_size, embedding_size_, batch_elements,
|
if (tensor.Shape()[0] == host_cache_tensor.Shape()[0]) {
|
||||||
data_type, host_cache_tensor.MutableData(), rank_id_, rank_group_size_);
|
device_start_index = rank_id_ * static_cast<int>(std::ceil(static_cast<float>(vocab_size_) / rank_group_size_));
|
||||||
|
}
|
||||||
|
auto cache = std::make_shared<EmbeddingCache>(vocab_size, host_cache_size, device_cache_size_, device_start_index,
|
||||||
|
embedding_size_, batch_elements, data_type,
|
||||||
|
host_cache_tensor.MutableData(), rank_id_, rank_group_size_);
|
||||||
if (cache == nullptr) {
|
if (cache == nullptr) {
|
||||||
MS_LOG(ERROR) << kernel->name() << ": malloc EmbeddingCache failed";
|
MS_LOG(ERROR) << kernel->name() << ": malloc EmbeddingCache failed";
|
||||||
return kLiteError;
|
return kLiteError;
|
||||||
|
@ -107,6 +141,23 @@ bool EmbeddingCacheManager::IsCacheTensor(mindspore::MSTensor tensor) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> EmbeddingCacheManager::GetCacheShape(mindspore::MSTensor tensor) {
|
||||||
|
std::vector<int64_t> shape = tensor.Shape();
|
||||||
|
if (shape.size() > 0 && IsCacheTensor(tensor)) {
|
||||||
|
shape[0] = device_cache_size_;
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t EmbeddingCacheManager::GetCacheDataSize(mindspore::MSTensor tensor) {
|
||||||
|
auto data_size = tensor.DataSize();
|
||||||
|
auto &shape = tensor.Shape();
|
||||||
|
if (shape.size() > 0 && IsCacheTensor(tensor) && shape[0] > 0) {
|
||||||
|
data_size = data_size * device_cache_size_ / shape[0];
|
||||||
|
}
|
||||||
|
return data_size;
|
||||||
|
}
|
||||||
|
|
||||||
Status EmbeddingCacheManager::SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size) {
|
Status EmbeddingCacheManager::SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size) {
|
||||||
auto cache_iter = caches_.find(tensor_name);
|
auto cache_iter = caches_.find(tensor_name);
|
||||||
if (cache_iter == caches_.end() || cache_iter->second == nullptr) {
|
if (cache_iter == caches_.end() || cache_iter->second == nullptr) {
|
||||||
|
|
|
@ -35,12 +35,15 @@ class EmbeddingCacheManager {
|
||||||
rank_id_ = lite::GetRankID();
|
rank_id_ = lite::GetRankID();
|
||||||
rank_group_size_ = lite::GetGPUGroupSize();
|
rank_group_size_ = lite::GetGPUGroupSize();
|
||||||
}
|
}
|
||||||
Status Init(const std::string &cache_model_path, size_t vocab_size);
|
Status Init(const std::string &cache_model_path, size_t vocab_size, size_t device_cache_size);
|
||||||
|
Status Init(DelegateModel<schema::Primitive> *model, size_t vocab_size, size_t device_cache_size);
|
||||||
bool CheckIsCacheKernel(kernel::Kernel *kernel);
|
bool CheckIsCacheKernel(kernel::Kernel *kernel);
|
||||||
Status InitCacheKernel(kernel::Kernel *kernel, uint32_t device_id, const void *context);
|
Status InitCacheKernel(kernel::Kernel *kernel, uint32_t device_id, const void *context);
|
||||||
bool IsCacheTensor(mindspore::MSTensor tensor);
|
bool IsCacheTensor(mindspore::MSTensor tensor);
|
||||||
int CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, void *device_addr);
|
int CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, void *device_addr);
|
||||||
Status SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size);
|
Status SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size);
|
||||||
|
std::vector<int64_t> GetCacheShape(mindspore::MSTensor tensor);
|
||||||
|
size_t GetCacheDataSize(mindspore::MSTensor tensor);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<std::string, std::shared_ptr<EmbeddingCache>> caches_;
|
std::map<std::string, std::shared_ptr<EmbeddingCache>> caches_;
|
||||||
|
@ -50,6 +53,7 @@ class EmbeddingCacheManager {
|
||||||
|
|
||||||
std::shared_ptr<HostCacheModel> host_cache_model_;
|
std::shared_ptr<HostCacheModel> host_cache_model_;
|
||||||
size_t vocab_size_;
|
size_t vocab_size_;
|
||||||
|
size_t device_cache_size_;
|
||||||
};
|
};
|
||||||
} // namespace cache
|
} // namespace cache
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "src/delegate/parameter_cache/load_host_cache_model.h"
|
#include "src/delegate/parameter_cache/load_host_cache_model.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "src/common/common.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
|
|
||||||
|
@ -79,6 +80,45 @@ Status HostCacheModel::LoadCache(const std::string &model_path) {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t GetVocabSize(kernel::Kernel *kernel) {
|
||||||
|
size_t vocab_size = 0;
|
||||||
|
auto cache_config = kernel->GetConfig(lite::kMSCache);
|
||||||
|
auto vocab_size_iter = cache_config.find(lite::kMSCacheVocabSize);
|
||||||
|
if (vocab_size_iter == cache_config.end()) {
|
||||||
|
return vocab_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto vocab_size_opt = lite::GenericParseValue<size_t>(vocab_size_iter->second);
|
||||||
|
if (!vocab_size_opt.IsNone()) {
|
||||||
|
vocab_size = vocab_size_opt.Get();
|
||||||
|
}
|
||||||
|
return vocab_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HostCacheModel::LoadCache(DelegateModel<schema::Primitive> *model) {
|
||||||
|
KernelIter from, end;
|
||||||
|
for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) {
|
||||||
|
kernel::Kernel *kernel = *iter;
|
||||||
|
// only support embedding cache
|
||||||
|
if (kernel->type() != schema::PrimitiveType_Gather) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
MS_ASSERT(kernel->inputs.size() == 3);
|
||||||
|
auto tensor = kernel->inputs()[0];
|
||||||
|
if (tensor.Data() == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t vocab_size = GetVocabSize(kernel);
|
||||||
|
if (vocab_size == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
cache_tensor_[tensor.Name()] = tensor;
|
||||||
|
}
|
||||||
|
return mindspore::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
bool HostCacheModel::CheckIsCacheKernel(kernel::Kernel *kernel) {
|
bool HostCacheModel::CheckIsCacheKernel(kernel::Kernel *kernel) {
|
||||||
if (GetHostCacheTensor(kernel) == nullptr) {
|
if (GetHostCacheTensor(kernel) == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "include/api/data_type.h"
|
#include "include/api/data_type.h"
|
||||||
#include "include/api/types.h"
|
#include "include/api/types.h"
|
||||||
#include "include/api/kernel.h"
|
#include "include/api/kernel.h"
|
||||||
|
#include "include/api/delegate.h"
|
||||||
#include "src/lite_model.h"
|
#include "src/lite_model.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -31,6 +32,7 @@ class HostCacheModel {
|
||||||
public:
|
public:
|
||||||
HostCacheModel() {}
|
HostCacheModel() {}
|
||||||
Status LoadCache(const std::string &model_path);
|
Status LoadCache(const std::string &model_path);
|
||||||
|
Status LoadCache(DelegateModel<schema::Primitive> *model);
|
||||||
bool CheckIsCacheKernel(kernel::Kernel *kernel);
|
bool CheckIsCacheKernel(kernel::Kernel *kernel);
|
||||||
MSTensor GetHostCacheTensor(kernel::Kernel *kernel);
|
MSTensor GetHostCacheTensor(kernel::Kernel *kernel);
|
||||||
|
|
||||||
|
|
|
@ -146,7 +146,7 @@ Status TensorRTDelegate::Init() {
|
||||||
MS_LOG(ERROR) << "malloc EmbeddingCacheManager failed.";
|
MS_LOG(ERROR) << "malloc EmbeddingCacheManager failed.";
|
||||||
return kLiteMemoryFailed;
|
return kLiteMemoryFailed;
|
||||||
}
|
}
|
||||||
auto cache_ret = cache_mgr_->Init(cache_model_path_, vocab_size_);
|
auto cache_ret = cache_mgr_->Init(cache_model_path_, vocab_size_, device_cache_size_);
|
||||||
if (cache_ret != mindspore::kSuccess) {
|
if (cache_ret != mindspore::kSuccess) {
|
||||||
MS_LOG(ERROR) << "cache_mgr_ init failed.";
|
MS_LOG(ERROR) << "cache_mgr_ init failed.";
|
||||||
return cache_ret;
|
return cache_ret;
|
||||||
|
@ -155,40 +155,37 @@ Status TensorRTDelegate::Init() {
|
||||||
return mindspore::kSuccess;
|
return mindspore::kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TensorRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
|
void TensorRTDelegate::CheckSupportResize(schema::PrimitiveType type) {
|
||||||
int ret = lite::SetCudaDevice(device_info_);
|
if (support_resize_) {
|
||||||
if (ret != RET_OK) {
|
auto opt_iter = unsupport_resize_op_list_.find(type);
|
||||||
return mindspore::kLiteError;
|
if (opt_iter != unsupport_resize_op_list_.end()) {
|
||||||
|
support_resize_ = false;
|
||||||
|
support_hw_resize_ = false;
|
||||||
|
MS_LOG(INFO) << "network has op don't support resize, op: " << type;
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if (support_hw_resize_) {
|
||||||
|
auto opt_iter = unsupport_hw_op_lists_.find(type);
|
||||||
|
if (opt_iter != unsupport_hw_op_lists_.end()) {
|
||||||
|
support_hw_resize_ = false;
|
||||||
|
MS_LOG(INFO) << "network has op don't support hw resize, op: " << type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TensorRTDelegate::BuildSubGraph(DelegateModel<schema::Primitive> *model) {
|
||||||
KernelIter from, end;
|
KernelIter from, end;
|
||||||
std::vector<TensorRTOp *> tensorrt_ops;
|
std::vector<TensorRTOp *> tensorrt_ops;
|
||||||
for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) {
|
for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) {
|
||||||
kernel::Kernel *kernel = *iter;
|
kernel::Kernel *kernel = *iter;
|
||||||
if (support_resize_) {
|
CheckSupportResize(model->GetPrimitive(kernel)->value_type());
|
||||||
for (auto no_type : unsupport_resize_op_list_) {
|
|
||||||
if (model->GetPrimitive(kernel)->value_type() == no_type) {
|
|
||||||
support_resize_ = false;
|
|
||||||
support_hw_resize_ = false;
|
|
||||||
MS_LOG(INFO) << "network has op don't support resize.";
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (support_hw_resize_) {
|
|
||||||
for (auto no_type : unsupport_hw_op_lists_) {
|
|
||||||
if (model->GetPrimitive(kernel)->value_type() == no_type) {
|
|
||||||
support_hw_resize_ = false;
|
|
||||||
MS_LOG(INFO) << "network has op don't support hw resize.";
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto tensorrt_op = FindTensorRTOp(kernel, model->GetPrimitive(kernel));
|
auto tensorrt_op = FindTensorRTOp(kernel, model->GetPrimitive(kernel));
|
||||||
if (tensorrt_op != nullptr) {
|
if (tensorrt_op != nullptr) {
|
||||||
if (cache_mgr_->CheckIsCacheKernel(kernel)) {
|
if (cache_mgr_->CheckIsCacheKernel(kernel)) {
|
||||||
auto cache_ret = cache_mgr_->InitCacheKernel(kernel, device_info_->GetDeviceID(), &stream_);
|
auto cache_ret = cache_mgr_->InitCacheKernel(kernel, device_info_->GetDeviceID(), &stream_);
|
||||||
if (cache_ret != kSuccess) {
|
if (cache_ret != kSuccess) {
|
||||||
MS_LOG(INFO) << "InitCacheKernel failed " << kernel->name();
|
MS_LOG(ERROR) << "InitCacheKernel failed " << kernel->name();
|
||||||
return cache_ret;
|
return cache_ret;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -215,7 +212,7 @@ Status TensorRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
|
||||||
if (tensorrt_ops.size() > 0) {
|
if (tensorrt_ops.size() > 0) {
|
||||||
auto tensorrt_subgraph = CreateTensorRTGraph(tensorrt_ops, model, from, end);
|
auto tensorrt_subgraph = CreateTensorRTGraph(tensorrt_ops, model, from, end);
|
||||||
if (tensorrt_subgraph == nullptr) {
|
if (tensorrt_subgraph == nullptr) {
|
||||||
MS_LOG(DEBUG) << "Create TensorRT Graph failed.";
|
MS_LOG(ERROR) << "Create TensorRT Graph failed.";
|
||||||
return mindspore::kLiteNullptr;
|
return mindspore::kLiteNullptr;
|
||||||
}
|
}
|
||||||
model->Replace(from, end + 1, tensorrt_subgraph);
|
model->Replace(from, end + 1, tensorrt_subgraph);
|
||||||
|
@ -224,6 +221,28 @@ Status TensorRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
|
||||||
return mindspore::kSuccess;
|
return mindspore::kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status TensorRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
|
||||||
|
int ret = lite::SetCudaDevice(device_info_);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
return mindspore::kLiteError;
|
||||||
|
}
|
||||||
|
if (cache_model_path_.empty() && vocab_size_ > 0) {
|
||||||
|
auto cache_ret = cache_mgr_->Init(model, vocab_size_, device_cache_size_);
|
||||||
|
if (cache_ret != mindspore::kSuccess) {
|
||||||
|
MS_LOG(ERROR) << "cache_mgr_ init failed.";
|
||||||
|
return cache_ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto build_ret = BuildSubGraph(model);
|
||||||
|
if (build_ret != kSuccess) {
|
||||||
|
MS_LOG(INFO) << "BuildSubGraph failed";
|
||||||
|
return build_ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
return mindspore::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive) {
|
TensorRTOp *TensorRTDelegate::FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive) {
|
||||||
auto in_tensors = kernel->inputs();
|
auto in_tensors = kernel->inputs();
|
||||||
auto out_tensors = kernel->outputs();
|
auto out_tensors = kernel->outputs();
|
||||||
|
|
|
@ -35,8 +35,12 @@ typedef TensorRTOp *(*TensorRTGetOp)(const schema::Primitive *primitive,
|
||||||
|
|
||||||
class TensorRTDelegate : public Delegate {
|
class TensorRTDelegate : public Delegate {
|
||||||
public:
|
public:
|
||||||
explicit TensorRTDelegate(mindspore::Context *context, const std::string &cache_model_path, size_t vocab_size)
|
explicit TensorRTDelegate(mindspore::Context *context, const std::string &cache_model_path, size_t vocab_size,
|
||||||
: context_(context), cache_model_path_(cache_model_path), vocab_size_(vocab_size) {}
|
size_t device_cache_size)
|
||||||
|
: context_(context),
|
||||||
|
cache_model_path_(cache_model_path),
|
||||||
|
vocab_size_(vocab_size),
|
||||||
|
device_cache_size_(device_cache_size) {}
|
||||||
|
|
||||||
~TensorRTDelegate() override;
|
~TensorRTDelegate() override;
|
||||||
|
|
||||||
|
@ -45,6 +49,8 @@ class TensorRTDelegate : public Delegate {
|
||||||
Status Build(DelegateModel<schema::Primitive> *model) override;
|
Status Build(DelegateModel<schema::Primitive> *model) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Status BuildSubGraph(DelegateModel<schema::Primitive> *model);
|
||||||
|
void CheckSupportResize(schema::PrimitiveType type);
|
||||||
TensorRTOp *FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive);
|
TensorRTOp *FindTensorRTOp(kernel::Kernel *kernel, const schema::Primitive *primitive);
|
||||||
|
|
||||||
TensorRTSubGraph *CreateTensorRTGraph(const std::vector<TensorRTOp *> &ops, DelegateModel<schema::Primitive> *model,
|
TensorRTSubGraph *CreateTensorRTGraph(const std::vector<TensorRTOp *> &ops, DelegateModel<schema::Primitive> *model,
|
||||||
|
@ -67,6 +73,7 @@ class TensorRTDelegate : public Delegate {
|
||||||
bool support_resize_{true};
|
bool support_resize_{true};
|
||||||
const std::string cache_model_path_;
|
const std::string cache_model_path_;
|
||||||
size_t vocab_size_;
|
size_t vocab_size_;
|
||||||
|
size_t device_cache_size_;
|
||||||
std::shared_ptr<cache::EmbeddingCacheManager> cache_mgr_{nullptr};
|
std::shared_ptr<cache::EmbeddingCacheManager> cache_mgr_{nullptr};
|
||||||
|
|
||||||
cudaStream_t stream_;
|
cudaStream_t stream_;
|
||||||
|
|
|
@ -359,8 +359,8 @@ int TensorRTSubGraph::Prepare() {
|
||||||
|
|
||||||
// malloc for cache weight tensor
|
// malloc for cache weight tensor
|
||||||
for (auto cache_tensor : cache_const_inputs_) {
|
for (auto cache_tensor : cache_const_inputs_) {
|
||||||
auto data_size = cache_tensor.DataSize();
|
size_t data_size = cache_mgr_->GetCacheDataSize(cache_tensor);
|
||||||
auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(cache_tensor, cache_tensor.DataSize());
|
auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(cache_tensor, data_size);
|
||||||
runtime_->GetAllocator()->SyncMemInHostAndDevice(cache_tensor, cache_tensor.Name().c_str(), true);
|
runtime_->GetAllocator()->SyncMemInHostAndDevice(cache_tensor, cache_tensor.Name().c_str(), true);
|
||||||
runtime_->GetAllocator()->MarkMemValid(cache_tensor.Name().c_str(), true);
|
runtime_->GetAllocator()->MarkMemValid(cache_tensor.Name().c_str(), true);
|
||||||
int index = this->engine_->getBindingIndex(cache_tensor.Name().c_str());
|
int index = this->engine_->getBindingIndex(cache_tensor.Name().c_str());
|
||||||
|
@ -596,9 +596,10 @@ int TensorRTSubGraph::HandleCacheTensor(TensorRTOp *cur_op, const mindspore::MST
|
||||||
FindCacheTensorInfo(cur_op, in_tensor);
|
FindCacheTensorInfo(cur_op, in_tensor);
|
||||||
// cache kernel weight tensor
|
// cache kernel weight tensor
|
||||||
cache_const_inputs_.push_back(in_tensor);
|
cache_const_inputs_.push_back(in_tensor);
|
||||||
|
auto shape = cache_mgr_->GetCacheShape(in_tensor);
|
||||||
MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name();
|
MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name();
|
||||||
auto cuda_dtype = ConvertDataType(in_tensor.DataType());
|
auto cuda_dtype = ConvertDataType(in_tensor.DataType());
|
||||||
nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape());
|
nvinfer1::Dims input_dims = ConvertCudaDims(shape);
|
||||||
nvinfer1::ITensor *cache_input = network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims);
|
nvinfer1::ITensor *cache_input = network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims);
|
||||||
if (cache_input == nullptr) {
|
if (cache_input == nullptr) {
|
||||||
MS_LOG(ERROR) << "add cache Weight Tensor data is nullptr.";
|
MS_LOG(ERROR) << "add cache Weight Tensor data is nullptr.";
|
||||||
|
|
|
@ -742,6 +742,7 @@ int LiteSession::CreateTensorRTDelegate() {
|
||||||
#if GPU_TENSORRT
|
#if GPU_TENSORRT
|
||||||
std::string cache_model_path;
|
std::string cache_model_path;
|
||||||
size_t vocab_size = 0;
|
size_t vocab_size = 0;
|
||||||
|
size_t device_cache_size = 0;
|
||||||
if (config_info_ != nullptr) {
|
if (config_info_ != nullptr) {
|
||||||
auto ms_cache_iter = config_info_->find(kMSCache);
|
auto ms_cache_iter = config_info_->find(kMSCache);
|
||||||
if (ms_cache_iter != config_info_->end()) {
|
if (ms_cache_iter != config_info_->end()) {
|
||||||
|
@ -758,10 +759,18 @@ int LiteSession::CreateTensorRTDelegate() {
|
||||||
vocab_size = vocab_size_opt.Get();
|
vocab_size = vocab_size_opt.Get();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto device_cache_size_iter = ms_cache.find(kMSCacheDeviceSize);
|
||||||
|
if (device_cache_size_iter != ms_cache.end()) {
|
||||||
|
auto device_cache_size_opt = GenericParseValue<size_t>(device_cache_size_iter->second);
|
||||||
|
if (!device_cache_size_opt.IsNone()) {
|
||||||
|
device_cache_size = device_cache_size_opt.Get();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
delegate_ = std::make_shared<TensorRTDelegate>(ms_context_, cache_model_path, vocab_size);
|
delegate_ = std::make_shared<TensorRTDelegate>(ms_context_, cache_model_path, vocab_size, device_cache_size);
|
||||||
if (delegate_ == nullptr) {
|
if (delegate_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "New tensorrt delegate_ failed";
|
MS_LOG(ERROR) << "New tensorrt delegate_ failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
Loading…
Reference in New Issue