!26736 [MS][LITE] Support parameter cache and distributed prediction

Merge pull request !26736 from 张学同/distribution_cache_dev
This commit is contained in:
i-robot 2021-11-25 03:35:31 +00:00 committed by Gitee
commit d0792507bd
47 changed files with 1365 additions and 102 deletions

View File

@ -18,7 +18,6 @@
#define MINDSPORE_NNACL_ALL_GATHER_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/communication_func.h"
typedef struct AllGatherParameter {
// primitive parameter
@ -26,6 +25,6 @@ typedef struct AllGatherParameter {
char group_[DEFAULT_GROUP_NAME_LEN];
// other parameter
int rank_;
int rank_size_;
} AllGatherParameter;
#endif // MINDSPORE_NNACL_ALL_GATHER_PARAMETER_H_

View File

@ -1,33 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_COMMUNICATION_FUNC_H_
#define MINDSPORE_NNACL_COMMUNICATION_FUNC_H_
#ifdef __cplusplus
extern "C" {
#endif
#define DEFAULT_GROUP_NAME_LEN 101
#define DEFAULT_GROUP_SIZE 2
static inline int get_rank(char *group) { return DEFAULT_GROUP_SIZE; }
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_COMMUNICATION_FUNC_H_

View File

@ -18,7 +18,6 @@
#include "nnacl/infer/infer_register.h"
#include "nnacl/infer/common_infer.h"
#include "nnacl/all_gather_parameter.h"
#include "nnacl/communication_func.h"
int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
@ -34,8 +33,7 @@ int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
}
AllGatherParameter *param = (AllGatherParameter *)parameter;
param->rank_ = get_rank(param->group_);
if (param->rank_ <= 0) {
if (param->rank_size_ <= 0) {
return NNACL_INFER_INVALID;
}
@ -45,7 +43,7 @@ int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
int out_shape[MAX_SHAPE_SIZE];
size_t out_shape_size = 0;
out_shape[0] = in_shape[0] * param->rank_;
out_shape[0] = in_shape[0] * param->rank_size_;
out_shape_size++;
for (int i = 1; i < input_tensor->shape_size_; i++) {
out_shape[i] = in_shape[i];

View File

@ -13,11 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/errorcode.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/infer/common_infer.h"
#include "nnacl/communication_func.h"
#include "nnacl/reduce_scatter_parameter.h"
int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
@ -31,8 +29,7 @@ int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, Te
}
ReduceScatterParameter *param = (ReduceScatterParameter *)parameter;
param->rank_ = get_rank(param->group_);
if (param->rank_ <= 0) {
if (param->rank_size_ <= 0) {
return NNACL_INFER_INVALID;
}
@ -40,19 +37,20 @@ int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, Te
const int *in_shape = input_tensor->shape_;
TensorC *out_tensor = outputs[0];
if (in_shape[0] % param->rank_ != 0) {
if (in_shape[0] % param->rank_size_ != 0) {
return NNACL_INFER_INVALID;
}
int out_shape[MAX_SHAPE_SIZE];
size_t out_shape_size = 0;
out_shape[0] = in_shape[0] / param->rank_;
out_shape[0] = in_shape[0] / param->rank_size_;
out_shape_size++;
for (int i = 1; i < input_tensor->shape_size_; i++) {
out_shape[i] = in_shape[i];
out_shape_size++;
}
SetShapeArray(out_tensor, out_shape, out_shape_size);
return NNACL_OK;
}
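
The shape math introduced above is a simple scaling of the leading dimension by the new rank_size_ field: AllGather multiplies dim 0 by the group size, while ReduceScatter divides it and rejects inputs that do not split evenly. A minimal standalone sketch of the same arithmetic, with made-up shapes and a hypothetical rank size of 2 (not the nnacl API itself):

#include <cassert>
#include <vector>

// Illustrative only: mirrors AllGatherInferShape / ReduceScatterInferShape above.
std::vector<int> AllGatherShape(std::vector<int> in, int rank_size) {
  in[0] *= rank_size;  // every rank contributes a full slice along dim 0
  return in;
}

std::vector<int> ReduceScatterShape(std::vector<int> in, int rank_size) {
  assert(in[0] % rank_size == 0);  // otherwise the op reports NNACL_INFER_INVALID
  in[0] /= rank_size;  // each rank keeps one shard of the reduced tensor
  return in;
}

int main() {
  // With rank_size = 2 and an input of shape [8, 16]:
  // AllGather infers [16, 16]; ReduceScatter infers [4, 16].
  return AllGatherShape({8, 16}, 2)[0] == 16 && ReduceScatterShape({8, 16}, 2)[0] == 4 ? 0 : 1;
}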

View File

@ -110,6 +110,7 @@
#define kDefaulLiteMaxSpinCount 300000
#define kDefaulLiteMinSpinCount 1
#define kDefaulLiteIosSpinCount 1
#define DEFAULT_GROUP_NAME_LEN 101
#if ENABLE_HIGH_PERFORMANCE
#define MS_CHECK_TRUE_RET(value, errcode)

View File

@ -18,7 +18,6 @@
#define MINDSPORE_NNACL_REDUCE_SCATTER_PARAMETER_H_
#include "nnacl/op_base.h"
#include "nnacl/communication_func.h"
typedef struct ReduceScatterParameter {
// primitive parameter
@ -27,6 +26,6 @@ typedef struct ReduceScatterParameter {
int mode_;
// other parameter
int rank_;
int rank_size_;
} ReduceScatterParameter;
#endif // MINDSPORE_NNACL_REDUCE_SCATTER_PARAMETER_H_

View File

@ -137,6 +137,10 @@ void *AscendPsCache::MallocMemory(size_t size) {
return device_addr;
}
void AscendPsCache::FreeMemory(void *device_addr) {
device::ascend::AscendMemoryPool::GetInstance().FreeTensorMem(device_addr);
}
bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) {
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
if (offset_addr_ == nullptr) {

View File

@ -52,6 +52,7 @@ class AscendPsCache : public PsCacheBasic {
~AscendPsCache() override = default;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
void FreeMemory(void *device_addr) override;
bool MallocConstantMemory(size_t cache_vocab_size) override;
bool RecordEvent() override;
bool SynchronizeEvent() override;

View File

@ -41,6 +41,10 @@ void *GPUPsCache::MallocMemory(size_t size) {
return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size);
}
void GPUPsCache::FreeMemory(void *device_addr) {
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(device_addr);
}
bool GPUPsCache::RecordEvent() {
event_.reset(new cudaEvent_t());
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);

View File

@ -30,6 +30,7 @@ class GPUPsCache : public PsCacheBasic {
~GPUPsCache() override = default;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
void FreeMemory(void *device_addr) override;
bool RecordEvent() override;
bool SynchronizeEvent() override;
bool SynchronizeStream() override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,25 +19,8 @@
#include <utility>
#include <memory>
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
#define RETURN_IF_FALSE(condition) \
do { \
if (!(condition)) { \
return false; \
} \
} while (false)
#define RETURN_IF_FALSE_WITH_LOG(condition, message) \
do { \
if (!(condition)) { \
MS_LOG(ERROR) << message; \
return false; \
} \
} while (false)
class PsCacheBasic {
public:
PsCacheBasic() = default;
@ -45,6 +28,7 @@ class PsCacheBasic {
virtual bool InitDevice(uint32_t device_id, const void *context) = 0;
virtual void *MallocMemory(size_t size) = 0;
virtual bool MallocConstantMemory(size_t cache_vocab_size) { return true; }
virtual void FreeMemory(void *buf) = 0;
virtual bool RecordEvent() = 0;
virtual bool SynchronizeEvent() = 0;
virtual bool SynchronizeStream() = 0;

View File

@ -28,6 +28,15 @@ std::string AllGather::get_group() const {
auto value_ptr = GetAttr(kGroup);
return GetValue<std::string>(value_ptr);
}
void AllGather::set_rank_size(int rank_size) {
(void)this->AddAttr(kRankSize, MakeValue(static_cast<int64_t>(rank_size)));
}
int AllGather::get_rank_size() const {
auto value_ptr = GetAttr(kRankSize);
return static_cast<int>(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameAllGather, AllGather);
} // namespace ops
} // namespace mindspore

View File

@ -35,6 +35,8 @@ class MS_CORE_API AllGather : public PrimitiveC {
void Init() {}
void set_group(const std::string &format);
std::string get_group() const;
void set_rank_size(int rank_size);
int get_rank_size() const;
};
} // namespace ops
} // namespace mindspore
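
A small sketch of how a caller would round-trip the new attribute through the accessors added above; the include path and the group size are assumptions, not part of this change:

#include "ops/all_gather.h"

// Sketch only: rank_size is stored as an int64 attr under kRankSize and
// narrowed back to int on read.
int RoundTripRankSize(mindspore::ops::AllGather *all_gather, int group_size) {
  all_gather->set_rank_size(group_size);
  return all_gather->get_rank_size();  // returns group_size
}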

View File

@ -166,6 +166,7 @@ constexpr auto kDivisorOverride = "divisor_override";
constexpr auto kPostNmsTopn = "post_nms_topn";
constexpr auto kPower = "power";
constexpr auto kPreNmsTopn = "pre_nms_topn";
constexpr auto kRankSize = "rank_size";
constexpr auto kRatio = "ratio";
constexpr auto kReduction = "reduction";
constexpr auto kRootRank = "root_rank";

View File

@ -39,6 +39,14 @@ ReduceMode ReduceScatter::get_mode() const {
return ReduceMode(GetValue<int64_t>(value_ptr));
}
void ReduceScatter::set_rank_size(int rank_size) {
(void)this->AddAttr(kRankSize, MakeValue(static_cast<int64_t>(rank_size)));
}
int ReduceScatter::get_rank_size() const {
auto value_ptr = GetAttr(kRankSize);
return static_cast<int>(GetValue<int64_t>(value_ptr));
}
REGISTER_PRIMITIVE_C(kNameReduceScatter, ReduceScatter);
} // namespace ops
} // namespace mindspore

View File

@ -37,6 +37,8 @@ class MS_CORE_API ReduceScatter : public PrimitiveC {
std::string get_group() const;
void set_mode(const ReduceMode &mode);
ReduceMode get_mode() const;
void set_rank_size(int rank_size);
int get_rank_size() const;
};
} // namespace ops
} // namespace mindspore

View File

@ -291,6 +291,21 @@ class LogWriter {
} \
} while (0)
#define RETURN_IF_FALSE(condition) \
do { \
if (!(condition)) { \
return false; \
} \
} while (false)
#define RETURN_IF_FALSE_WITH_LOG(condition, message) \
do { \
if (!(condition)) { \
MS_LOG(ERROR) << message; \
return false; \
} \
} while (false)
#ifdef DEBUG
#include <cassert>
#define MS_ASSERT(f) assert(f)

View File

@ -1207,9 +1207,11 @@ table ScatterNdUpdate {
table AllGather {
group: string;
rank_size: int;
}
table ReduceScatter {
group: string;
mode: ReduceMode;
rank_size: int;
}

View File

@ -30,7 +30,7 @@ std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeI
size_t element_size = lite::DataTypeSize(type);
for (size_t i = 0; i < shape.size(); i++) {
auto dim = shape[i];
if (dim < 0 || dim > INT_MAX || (dim != 0 && element_size > INT_MAX / static_cast<size_t>(dim))) {
if (dim < 0 || dim > INT_MAX) {
MS_LOG(ERROR) << "Invalid shape.";
return empty;
} else {

View File

@ -0,0 +1,169 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/parameter_cache/embedding_cache.h"
#include <cuda_runtime.h>
#include <memory>
#include <vector>
#include <cmath>
#include <cstring>
#include "src/delegate/parameter_cache/gpu/gpu_cache_mem.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
namespace mindspore {
namespace cache {
void LookUpTableTask(size_t indices_lens, size_t first_dim_size, const char *input_addr, const int *indices_addr,
char *output_addr, size_t embedding_len) {
for (size_t i = 0; i < indices_lens; ++i) {
int index = indices_addr[i];
if (index >= 0 && index < static_cast<int>(first_dim_size)) {
size_t pos = index * embedding_len;
std::memcpy(output_addr, input_addr + pos, embedding_len);
} else {
memset(output_addr, 0, embedding_len);
}
output_addr += embedding_len;
}
}
EmbeddingCache::~EmbeddingCache() {
if (hash_swap_value_device_addr_ != nullptr) {
device_cache_->FreeMemory(hash_swap_value_device_addr_);
hash_swap_value_device_addr_ = nullptr;
}
if (hash_swap_value_addr_ != nullptr) {
free(hash_swap_value_addr_);
hash_swap_value_addr_ = nullptr;
}
if (hash_swap_index_addr_ != nullptr) {
device_cache_->FreeMemory(hash_swap_index_addr_);
hash_swap_index_addr_ = nullptr;
}
}
Status EmbeddingCache::Init() {
cache_ = std::make_shared<LFUCacheAlgorithm>(device_cache_size_, mix_host_index_, max_host_index_);
if (cache_ == nullptr) {
MS_LOG(ERROR) << "malloc LFUCacheAlgorithm failed";
return kLiteMemoryFailed;
}
device_cache_ = std::make_shared<gpu::GPUCacheMem>();
if (device_cache_ == nullptr) {
MS_LOG(ERROR) << "get cache failed";
return kLiteMemoryFailed;
}
auto hash_swap_value_size = embedding_size_ * batch_elements_ * sizeof_data_type_;
hash_swap_value_device_addr_ = device_cache_->MallocMemory(hash_swap_value_size);
if (hash_swap_value_device_addr_ == nullptr) {
MS_LOG(ERROR) << "malloc hash_swap_value_device failed, malloc size " << hash_swap_value_size;
return kLiteMemoryFailed;
}
hash_swap_value_addr_ = malloc(hash_swap_value_size);
if (hash_swap_value_addr_ == nullptr) {
MS_LOG(ERROR) << "malloc hash_swap_value failed, malloc size " << hash_swap_value_size;
return kLiteMemoryFailed;
}
// data type of index
hash_swap_index_addr_ = static_cast<int *>(device_cache_->MallocMemory(batch_elements_ * sizeof(int)));
if (hash_swap_index_addr_ == nullptr) {
MS_LOG(ERROR) << "malloc hash_swap_index failed, malloc size " << batch_elements_ * sizeof(int);
return kLiteMemoryFailed;
}
MS_LOG(INFO) << "init succ, rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_
<< ", index begin:" << mix_host_index_ << ", index end:" << max_host_index_;
return kSuccess;
}
Status EmbeddingCache::SetHostCacheAddr(void *addr, size_t size) {
if (sizeof_data_type_ * host_cache_size_ * embedding_size_ != size) {
return kLiteParamInvalid;
}
host_addr_ = addr;
// copy part of host mem to device
auto ret =
device_cache_->CopyHostMemToDevice(device_addr_, addr, sizeof_data_type_ * device_cache_size_ * embedding_size_);
if (!ret) {
MS_LOG(ERROR) << "CopyHostMemToDevice failed, copy size "
<< sizeof_data_type_ * device_cache_size_ * embedding_size_;
return kLiteMemoryFailed;
}
// init cache
auto index_num = device_cache_size_;
for (size_t i = 0; i < index_num; i++) {
cache_->Put(mix_host_index_ + i, i);
}
return kSuccess;
}
Status EmbeddingCache::SetDeviceCacheAddr(void *device_mem_addr, size_t size) {
if (sizeof_data_type_ * device_cache_size_ * embedding_size_ != size) {
return kLiteParamInvalid;
}
device_addr_ = device_mem_addr;
SetHostCacheAddr(host_addr_, sizeof_data_type_ * host_cache_size_ * embedding_size_);
return kSuccess;
}
Status EmbeddingCache::CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index) {
std::vector<int> need_swap_indies;
std::vector<int> need_swap_indies_cache_index;
auto ret =
cache_->CheckCacheHit(batch_ids, batch_ids_len, cache_index, &need_swap_indies, &need_swap_indies_cache_index);
if (ret != kSuccess) {
MS_LOG(ERROR) << "CheckCacheHit failed";
return ret;
}
auto swap_indices_size = need_swap_indies.size();
if (swap_indices_size > 0) {
LookUpTableTask(swap_indices_size, host_cache_size_, static_cast<char *>(host_addr_), need_swap_indies.data(),
static_cast<char *>(hash_swap_value_addr_), embedding_size_ * sizeof_data_type_);
// copy data
auto device_cache_ret = device_cache_->CopyHostMemToDevice(hash_swap_value_device_addr_, hash_swap_value_addr_,
swap_indices_size * embedding_size_ * sizeof_data_type_);
if (!device_cache_ret) {
MS_LOG(ERROR) << "copy swap value to device failed";
return kLiteMemoryFailed;
}
device_cache_ret = device_cache_->CopyHostMemToDevice(hash_swap_index_addr_, need_swap_indies_cache_index.data(),
swap_indices_size * sizeof(int));
if (!device_cache_ret) {
MS_LOG(ERROR) << "copy swap indies to device failed";
return kLiteMemoryFailed;
}
device_cache_ret = device_cache_->HashSwapIn(device_addr_, hash_swap_value_device_addr_, hash_swap_index_addr_,
device_cache_size_, embedding_size_, swap_indices_size);
if (!device_cache_ret) {
MS_LOG(ERROR) << "HashSwapIn failed";
return kLiteMemoryFailed;
}
}
return kSuccess;
}
} // namespace cache
} // namespace mindspore

View File

@ -0,0 +1,95 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EMBEDDING_CACHE_H_
#define MINDSPORE_LITE_EMBEDDING_CACHE_H_
#include <cmath>
#include <algorithm>
#include <memory>
#include "src/delegate/parameter_cache/lfu_cache.h"
#include "ps/ps_cache/ps_cache_basic.h"
#include "include/api/status.h"
#include "include/api/data_type.h"
#include "src/common/log_adapter.h"
namespace mindspore {
namespace cache {
class EmbeddingCache {
public:
EmbeddingCache(size_t vocab_size, size_t host_cache_size, size_t device_cache_size, size_t embedding_size,
size_t batch_elements, DataType data_type, void *host_addr, int rank_id, int rank_group_size)
: vocab_size_(vocab_size),
host_cache_size_(host_cache_size),
device_cache_size_(device_cache_size),
embedding_size_(embedding_size),
batch_elements_(batch_elements),
data_type_(data_type),
host_addr_(host_addr),
rank_id_(rank_id),
rank_group_size_(rank_group_size) {
auto local_shard_size = static_cast<int>(std::ceil(static_cast<float>(vocab_size_) / rank_group_size_));
mix_host_index_ = local_shard_size * rank_id_;
max_host_index_ = std::min(mix_host_index_ + local_shard_size, static_cast<int>(vocab_size_));
host_cache_size_ = max_host_index_ - mix_host_index_;
switch (data_type_) {
case DataType::kNumberTypeFloat16:
sizeof_data_type_ = sizeof(int16_t);
break;
default:
sizeof_data_type_ = sizeof(float);
}
host_addr_ = static_cast<char *>(host_addr) + mix_host_index_ * embedding_size_ * sizeof_data_type_;
MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_
<< ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_
<< ", device_cache_size_:" << device_cache_size_ << ", embedding_size_:" << embedding_size_
<< ", batch_elements_:" << batch_elements_ << ", index begin:" << mix_host_index_
<< ", index end:" << max_host_index_;
}
~EmbeddingCache();
Status Init();
Status SetHostCacheAddr(void *addr, size_t size);
Status SetDeviceCacheAddr(void *host_mem_addr, size_t size);
Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
private:
std::shared_ptr<ps::PsCacheBasic> device_cache_{nullptr};
std::shared_ptr<CacheAlgorithm> cache_{nullptr};
size_t vocab_size_{0}; // total size
size_t host_cache_size_{0}; // local host size
size_t device_cache_size_{0}; // local device cache size
size_t embedding_size_{0};
size_t batch_elements_{0};
DataType data_type_{DataType::kNumberTypeFloat32};
size_t sizeof_data_type_{0};
void *device_addr_{nullptr}; // hash_info.device_address.addr
void *host_addr_{nullptr};
int *hash_swap_index_addr_{nullptr}; // embedding_device_cache_->hash_swap_index_addr_
void *hash_swap_value_addr_{nullptr};
void *hash_swap_value_device_addr_{nullptr};
int rank_id_;
int rank_group_size_;
int mix_host_index_{0};
int max_host_index_{0};
};
} // namespace cache
} // namespace mindspore
#endif // MINDSPORE_LITE_EMBEDDING_CACHE_H_
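
A hedged sketch of the EmbeddingCache lifecycle as the header above exposes it: construct over the host table, Init() to allocate the swap buffers, hand over the device weight buffer, then map batches of vocabulary ids to device cache slots. All sizes below are invented and a CUDA-capable device is assumed.

#include <cuda_runtime.h>
#include <vector>
#include "src/delegate/parameter_cache/embedding_cache.h"

mindspore::Status EmbeddingCacheSketch(void *host_embedding_table) {
  // 100000 vocab rows on the host, 4096 rows cached on the device,
  // embedding dim 64, at most 512 lookups per batch, single rank.
  mindspore::cache::EmbeddingCache cache(100000, 100000, 4096, 64, 512,
                                         mindspore::DataType::kNumberTypeFloat32,
                                         host_embedding_table, /*rank_id=*/0,
                                         /*rank_group_size=*/1);
  auto ret = cache.Init();
  if (ret != mindspore::kSuccess) {
    return ret;
  }
  // In the delegate this buffer is the TensorRT binding of the cached weight tensor.
  void *device_table = nullptr;
  cudaMalloc(&device_table, 4096 * 64 * sizeof(float));
  ret = cache.SetDeviceCacheAddr(device_table, 4096 * 64 * sizeof(float));
  if (ret != mindspore::kSuccess) {
    return ret;
  }
  // Translate host vocabulary ids into device cache slots; misses are swapped
  // in from the host table as a side effect.
  std::vector<int> ids = {3, 42, 99998};
  std::vector<int> cache_index(ids.size());
  return cache.CheckCacheHit(ids.data(), ids.size(), cache_index.data());
}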

View File

@ -0,0 +1,146 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/parameter_cache/embedding_cache_manager.h"
#include <cuda_runtime.h>
#include <cmath>
#include <cstring>
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
namespace mindspore {
namespace cache {
Status EmbeddingCacheManager::Init(const std::string &cache_model_path) {
if (cache_model_path.empty()) {
MS_LOG(INFO) << "no cache model ";
return kSuccess;
}
host_cache_model_ = std::make_shared<HostCacheModel>();
if (host_cache_model_ == nullptr) {
MS_LOG(ERROR) << "HostCacheModel malloc failed";
return kLiteMemoryFailed;
}
auto ret = host_cache_model_->LoadCache(cache_model_path);
MS_LOG(INFO) << "cache manager init end, ret " << ret.ToString();
return ret;
}
bool EmbeddingCacheManager::CheckIsCacheKernel(kernel::Kernel *kernel) {
if (host_cache_model_ == nullptr) {
return false;
}
return host_cache_model_->CheckIsCacheKernel(kernel);
}
Status EmbeddingCacheManager::InitCacheKernel(kernel::Kernel *kernel) {
if (host_cache_model_ == nullptr) {
MS_LOG(ERROR) << "cache model is nullptr, kernel " << kernel->name() << " init cache failed";
return kLiteError;
}
auto host_cache_tensor = host_cache_model_->GetHostCacheTensor(kernel);
if (host_cache_tensor == nullptr) {
MS_LOG(ERROR) << kernel->name() << ": invalid cache kernel";
return kLiteError;
}
// only support embedding cache
if (kernel->type() != schema::PrimitiveType_Gather) {
MS_LOG(ERROR) << kernel->name() << " is not embedding kernel";
return kLiteError;
}
auto tensor = kernel->inputs()[0];
if (tensor.Shape()[1] != host_cache_tensor.Shape()[1]) {
MS_LOG(ERROR) << kernel->name() << " embedding_size is invalid, device size is " << tensor.Shape()[1]
<< " host size is " << host_cache_tensor.Shape()[1];
return kLiteError;
}
size_t vocab_size = host_cache_tensor.Shape()[0];  // needs to be handled together with host_cache_size
size_t host_cache_size = vocab_size / rank_group_size_;
size_t device_cache_size = tensor.Shape()[0];
size_t embedding_size_ = tensor.Shape()[1];
DataType data_type = tensor.DataType();
size_t batch_elements = kernel->inputs()[1].ElementNum();
auto cache =
std::make_shared<EmbeddingCache>(vocab_size, host_cache_size, device_cache_size, embedding_size_, batch_elements,
data_type, host_cache_tensor.MutableData(), rank_id_, rank_group_size_);
if (cache == nullptr) {
MS_LOG(ERROR) << kernel->name() << ": malloc EmbeddingCache failed";
return kLiteError;
}
auto ret = cache->Init();
if (ret != kSuccess) {
MS_LOG(ERROR) << kernel->name() << ": EmbeddingCache init failed";
return kLiteError;
}
caches_[tensor.Name()] = cache;
MS_LOG(INFO) << kernel->name() << " is cache kernel, input tensor " << kernel->inputs()[1].Name() << ", cache tensor "
<< tensor.Name();
return kSuccess;
}
bool EmbeddingCacheManager::IsCacheTensor(mindspore::MSTensor tensor) {
if (host_cache_model_ == nullptr) {
return false;
}
auto cache = caches_.find(tensor.Name());
if (cache != caches_.end()) {
return true;
}
return false;
}
Status EmbeddingCacheManager::SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size) {
auto cache_iter = caches_.find(tensor_name);
if (cache_iter == caches_.end() || cache_iter->second == nullptr) {
MS_LOG(ERROR) << "not find cache, " << tensor_name;
return kLiteError;
}
auto cache = cache_iter->second;
return cache->SetDeviceCacheAddr(device_mem_addr, size);
}
// device_addr is model input device addr
int EmbeddingCacheManager::CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor,
void *model_input_device_addr) {
auto cache_iter = caches_.find(tensor_name);
if (cache_iter == caches_.end()) {
MS_LOG(ERROR) << "not find cache, " << tensor_name;
return lite::RET_ERROR;
}
auto cache = cache_iter->second;
hash_indices_.resize(model_input_tensor.ElementNum());
auto ret = cache->CheckCacheHit(static_cast<int *>(model_input_tensor.MutableData()), hash_indices_.size(),
hash_indices_.data());
if (ret != kSuccess) {
MS_LOG(ERROR) << "CheckCacheHit failed, " << model_input_tensor.Name();
return lite::RET_ERROR;
}
auto cuda_ret = cudaMemcpy(model_input_device_addr, model_input_tensor.MutableData(),
hash_indices_.size() * sizeof(int), cudaMemcpyHostToDevice);
if (cuda_ret != cudaSuccess) {
MS_LOG(ERROR) << "copy mem failed, " << model_input_tensor.Name();
return lite::RET_ERROR;
}
MS_LOG(INFO) << "cache handle succ, " << model_input_tensor.Name() << "," << tensor_name;
return lite::RET_OK;
}
} // namespace cache
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EMBEDDING_CACHE_MANAGER_H_
#define MINDSPORE_LITE_EMBEDDING_CACHE_MANAGER_H_
#include <memory>
#include <map>
#include <string>
#include <vector>
#include "include/api/kernel.h"
#include "src/delegate/parameter_cache/embedding_cache.h"
#include "src/delegate/parameter_cache/lfu_cache.h"
#include "ps/ps_cache/ps_cache_basic.h"
#include "src/delegate/parameter_cache/load_host_cache_model.h"
#include "include/api/status.h"
#include "include/api/data_type.h"
#include "src/delegate/tensorrt/distribution/distribution_base.h"
namespace mindspore {
namespace cache {
class EmbeddingCacheManager {
public:
EmbeddingCacheManager() {
rank_id_ = lite::GetRankID();
rank_group_size_ = lite::GetGPUGroupSize();
}
Status Init(const std::string &cache_model_path);
bool CheckIsCacheKernel(kernel::Kernel *kernel);
Status InitCacheKernel(kernel::Kernel *kernel);
bool IsCacheTensor(mindspore::MSTensor tensor);
int CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, void *device_addr);
Status SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size);
private:
std::map<std::string, std::shared_ptr<EmbeddingCache>> caches_;
std::vector<int> hash_indices_;
int rank_id_{0};
int rank_group_size_{1};
std::shared_ptr<HostCacheModel> host_cache_model_;
};
} // namespace cache
} // namespace mindspore
#endif // MINDSPORE_LITE_EMBEDDING_CACHE_MANAGER_H_
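
The manager above is the glue the TensorRT delegate uses (see the tensorrt_delegate and tensorrt_subgraph diffs further down). A condensed sketch of the intended call order; the kernel, bindings, and model path are placeholders supplied by the caller, not values from this patch:

#include <memory>
#include "src/delegate/parameter_cache/embedding_cache_manager.h"

mindspore::Status WireUpCacheSketch(mindspore::kernel::Kernel *gather_kernel,
                                    void *weight_device_binding, size_t binding_size,
                                    mindspore::MSTensor input_ids, void *input_device_addr) {
  auto mgr = std::make_shared<mindspore::cache::EmbeddingCacheManager>();
  // 1. Delegate Init: load the host-side cache model (an empty path means "no cache").
  auto ret = mgr->Init("/path/to/cache_model.ms");
  if (ret != mindspore::kSuccess) return ret;
  // 2. Delegate Build: register every Gather kernel whose table lives in the cache model.
  if (mgr->CheckIsCacheKernel(gather_kernel)) {
    ret = mgr->InitCacheKernel(gather_kernel);
    if (ret != mindspore::kSuccess) return ret;
  }
  // 3. Subgraph Prepare: hand the device-side weight binding to the cache.
  auto weight_name = gather_kernel->inputs()[0].Name();
  ret = mgr->SetDeviceCacheAddr(weight_name, weight_device_binding, binding_size);
  if (ret != mindspore::kSuccess) return ret;
  // 4. Subgraph Execute: map the lookup ids to device cache slots and refresh the device input.
  if (mgr->CacheHandle(weight_name, input_ids, input_device_addr) != 0) {
    return mindspore::kLiteError;
  }
  return mindspore::kSuccess;
}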

View File

@ -0,0 +1,143 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/parameter_cache/gpu/gpu_cache_mem.h"
#include <cuda_runtime.h>
#include "src/delegate/tensorrt/cuda_impl/hash.cuh"
#include "runtime/device/gpu/cuda_driver.h"
#include "src/common/log_adapter.h"
namespace mindspore {
namespace gpu {
#define CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(expression, message) \
do { \
cudaError_t status = (expression); \
if (status != cudaSuccess) { \
MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \
<< cudaGetErrorString(status); \
return false; \
} \
} while (0)
#define MS_ERROR_IF_NULL_W_RET_VAL(ptr, val) \
do { \
if ((ptr) == nullptr) { \
MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
return val; \
} \
} while (0)
#define MS_ERROR_IF_NULL(ptr) \
do { \
if ((ptr) == nullptr) { \
MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
return false; \
} \
} while (0)
bool GPUCacheMem::InitDevice(uint32_t device_id, const void *) {
auto ret = cudaSetDevice(static_cast<int>(device_id));
if (ret != cudaSuccess) {
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
return false;
}
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
"Cuda create stream failed");
return true;
}
void *GPUCacheMem::MallocMemory(size_t size) {
void *device_ptr = nullptr;
auto cuda_ret = cudaMalloc(&device_ptr, size);
if (cuda_ret != cudaSuccess) {
MS_LOG(ERROR) << "Cuda Malloc failed for size:" << size;
return nullptr;
}
MS_LOG(INFO) << "cudaMalloc size: " << size;
return device_ptr;
}
void GPUCacheMem::FreeMemory(void *device_addr) {
auto cuda_ret = cudaFree(device_addr);
if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) {
MS_LOG(WARNING) << "free cuda failed for " << cudaGetErrorName(cuda_ret);
}
}
bool GPUCacheMem::RecordEvent() {
event_.reset(new cudaEvent_t());
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda record event failed");
return true;
}
bool GPUCacheMem::SynchronizeEvent() {
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
return true;
}
bool GPUCacheMem::SynchronizeStream() {
MS_ERROR_IF_NULL_W_RET_VAL(stream_, false);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
"Cuda sync stream failed");
return true;
}
bool GPUCacheMem::CopyHostMemToDevice(void *dst, const void *src, size_t size) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(
cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda memcpy failed");
return true;
}
bool GPUCacheMem::CopyDeviceMemToHost(void *dst, const void *src, size_t size) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda memcpy failed");
return true;
}
bool GPUCacheMem::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
size_t embedding_size, size_t swap_out_size) {
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_out_value_addr);
MS_ERROR_IF_NULL(swap_out_index_addr);
DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr),
reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size,
reinterpret_cast<cudaStream_t>(stream_));
return true;
}
bool GPUCacheMem::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
size_t embedding_size, size_t swap_in_size) {
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_in_value_addr);
MS_ERROR_IF_NULL(swap_in_index_addr);
DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr),
reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size,
reinterpret_cast<cudaStream_t>(stream_));
return true;
}
} // namespace gpu
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_CACHE_MEM_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_CACHE_MEM_H_
#include <cuda_runtime_api.h>
#include <memory>
#include "ps/ps_cache/ps_cache_basic.h"
namespace mindspore {
namespace gpu {
class GPUCacheMem : public ps::PsCacheBasic {
public:
GPUCacheMem() = default;
~GPUCacheMem() override = default;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
void FreeMemory(void *buf) override;
bool RecordEvent() override;
bool SynchronizeEvent() override;
bool SynchronizeStream() override;
bool CopyHostMemToDevice(void *dst, const void *src, size_t size) override;
bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) override;
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size,
size_t embedding_size, size_t swap_out_size) override;
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size,
size_t embedding_size, size_t swap_in_size) override;
private:
std::unique_ptr<cudaEvent_t> event_{nullptr};
};
} // namespace gpu
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_CACHE_MEM_H_
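
For completeness, a minimal standalone exercise of the PsCacheBasic interface as GPUCacheMem implements it, including the FreeMemory counterpart this change adds; it assumes a CUDA device is present and trims error handling:

#include <vector>
#include "src/delegate/parameter_cache/gpu/gpu_cache_mem.h"

bool GpuCacheMemSketch() {
  mindspore::gpu::GPUCacheMem mem;
  if (!mem.InitDevice(/*device_id=*/0, /*context=*/nullptr)) return false;
  std::vector<float> host(1024, 1.0f);
  void *dev = mem.MallocMemory(host.size() * sizeof(float));
  if (dev == nullptr) return false;
  // Copies run asynchronously on the internal stream, so synchronize before reuse.
  bool ok = mem.CopyHostMemToDevice(dev, host.data(), host.size() * sizeof(float)) &&
            mem.SynchronizeStream() &&
            mem.CopyDeviceMemToHost(host.data(), dev, host.size() * sizeof(float)) &&
            mem.SynchronizeStream();
  mem.FreeMemory(dev);
  return ok;
}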

View File

@ -0,0 +1,230 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/common/log_adapter.h"
#include "src/delegate/parameter_cache/lfu_cache.h"
namespace mindspore {
namespace cache {
LFUCacheAlgorithm::~LFUCacheAlgorithm() {
for (auto iter : key_table_) {
delete *(iter.second);
}
key_table_.clear();
frequency_table_.clear();
}
CacheNoe *LFUCacheAlgorithm::GetNode(int key) {
auto key_table_iter = key_table_.find(key);
if (key_table_iter == key_table_.end()) {
return nullptr;
}
auto node_iter = key_table_iter->second;
auto node = *node_iter;
auto node_list_iter = frequency_table_.find(key);
if (node_list_iter == frequency_table_.end()) {
return nullptr;
}
auto &node_list = node_list_iter->second;
node_list.erase(node_iter);
if (node_list.empty()) {
frequency_table_.erase(node_list_iter);
}
node->frequency += 1;
frequency_table_[node->frequency].emplace_front(node);
key_table_[key] = frequency_table_[node->frequency].begin();
return node;
}
int LFUCacheAlgorithm::Get(int key) {
auto node = GetNode(key);
if (node != nullptr) {
return node->value;
}
return -1;
}
void LFUCacheAlgorithm::Put(int key, int value) {
auto node = GetNode(key);
if (node != nullptr) {
node->value = value;
return;
}
if (cache_size_ == 0) {
return;
}
CacheNoe *add_node = nullptr;
if (key_table_.size() == cache_size_) {
add_node = frequency_table_.begin()->second.back();
key_table_.erase(add_node->key);
frequency_table_.begin()->second.pop_back();
if (frequency_table_.begin()->second.size() == 0) {
frequency_table_.erase(frequency_table_.begin()->first);
}
add_node->value = value;
add_node->key = key;
add_node->frequency = 1;
} else {
add_node = new (std::nothrow) CacheNoe(key, 1, value);
if (add_node == nullptr) {
return;
}
}
frequency_table_[1].emplace_front(add_node);
key_table_[key] = frequency_table_[1].begin();
}
void LFUCacheAlgorithm::GetHitNodesAndSwapIndex(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
std::unordered_map<int, CacheNoe *> *hit_index_nodes,
std::unordered_map<int, std::vector<int>> *need_swap_map) {
// split the batch ids into hit indices and missed indices
for (size_t i = 0; i < batch_ids_len; i++) {
auto key = batch_ids[i];
if (key < min_host_index_ || key >= max_host_index_) {
cache_index[i] = -1;
// out range
continue;
}
auto hit_iter = hit_index_nodes->find(key);
if (hit_iter != hit_index_nodes->end()) {
auto node = hit_iter->second;
node->frequency += 1;
cache_index[i] = node->value;
continue;
}
auto swap_iter = need_swap_map->find(key);
if (swap_iter != need_swap_map->end()) {
swap_iter->second.push_back(i);
continue;
}
auto node_iter_iter = key_table_.find(key);
if (node_iter_iter == key_table_.end()) {
(*need_swap_map)[key].push_back(i);
continue;
}
auto node_iter = node_iter_iter->second;
auto node = *node_iter;
auto node_list_iter = frequency_table_.find(node->frequency);
if (node_list_iter == frequency_table_.end()) {
continue;
}
auto &node_list = node_list_iter->second;
node_list.erase(node_iter);
if (node_list.empty()) {
frequency_table_.erase(node_list_iter);
}
// hit
node->frequency += 1;
cache_index[i] = node->value;
(*hit_index_nodes)[key] = node;
}
return;
}
std::list<CacheNoe *> LFUCacheAlgorithm::GetSwapNodes(const std::unordered_map<int, std::vector<int>> &need_swap_map) {
std::list<CacheNoe *> need_swap_nodes;
auto swap_size = need_swap_map.size();
while (swap_size > 0 && !frequency_table_.empty()) {
auto node_list_iter = frequency_table_.begin();
if (node_list_iter->second.size() > swap_size) {
auto iter = node_list_iter->second.begin();
std::advance(iter, swap_size);
need_swap_nodes.splice(need_swap_nodes.end(), node_list_iter->second, node_list_iter->second.begin(), iter);
swap_size = 0;
} else {
swap_size -= node_list_iter->second.size();
need_swap_nodes.splice(need_swap_nodes.end(), node_list_iter->second);
frequency_table_.erase(node_list_iter);
}
}
return need_swap_nodes;
}
Status LFUCacheAlgorithm::CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
std::vector<int> *need_swap_indies,
std::vector<int> *need_swap_indies_cache_index) {
if (batch_ids == nullptr) {
MS_LOG(ERROR) << "batch_ids is nullptr";
return kLiteNullptr;
}
if (cache_index == nullptr) {
MS_LOG(ERROR) << "cache_index is nullptr";
return kLiteNullptr;
}
std::unordered_map<int, std::vector<int>> need_swap_map;
std::unordered_map<int, CacheNoe *> hit_index_nodes;
GetHitNodesAndSwapIndex(batch_ids, batch_ids_len, cache_index, &hit_index_nodes, &need_swap_map);
// get need_swap_indies.size() least recently used node
std::list<CacheNoe *> need_swap_nodes = GetSwapNodes(need_swap_map);
// reuse the evicted nodes: update their keys and values for the newly swapped-in indices
{
if (need_swap_map.size() != need_swap_nodes.size()) {
MS_LOG(ERROR) << " need_swap_map.size() " << need_swap_map.size() << " != need_swap_nodes.size() "
<< need_swap_nodes.size();
return kLiteError;
}
need_swap_indies_cache_index->reserve(need_swap_map.size());
auto need_swap_map_iter = need_swap_map.begin();
for (auto iter = need_swap_nodes.begin();
iter != need_swap_nodes.end() && need_swap_map_iter != need_swap_map.end(); iter++, need_swap_map_iter++) {
auto node = *iter;
key_table_.erase(node->key);
node->key = need_swap_map_iter->first;
node->frequency = 1;
for (auto index : need_swap_map_iter->second) {
cache_index[index] = node->value;
}
need_swap_indies->push_back(need_swap_map_iter->first);
need_swap_indies_cache_index->push_back(node->value);
MS_LOG(DEBUG) << "device index " << node->value << ",for host index " << need_swap_map_iter->first;
key_table_[(*iter)->key] = iter;
}
auto node_list_iter = frequency_table_.begin();
if (node_list_iter->second.size() > 0) {
auto iter = node_list_iter->second.begin();
if ((*iter)->frequency == 1) {
node_list_iter->second.splice(node_list_iter->second.begin(), need_swap_nodes);
} else {
frequency_table_[1] = need_swap_nodes;
}
} else {
frequency_table_[1] = need_swap_nodes;
}
}
for (auto node_iter : hit_index_nodes) {
auto node = node_iter.second;
frequency_table_[node->frequency].emplace_front(node);
key_table_[node->key] = frequency_table_[node->frequency].begin();
}
return kSuccess;
}
} // namespace cache
} // namespace mindspore

View File

@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_LRU_CACHE_H_
#define MINDSPORE_LITE_LRU_CACHE_H_
#include <map>
#include <unordered_map>
#include <list>
#include <vector>
#include "include/api/status.h"
namespace mindspore {
namespace cache {
struct CacheNoe {
CacheNoe(int _index, int _frequency, int _value) : key(_index), frequency(_frequency), value(_value) {}
int key; // host input index
int frequency;
int value; // cache index
};
class CacheAlgorithm {
public:
virtual ~CacheAlgorithm() {}
virtual int Get(int key) = 0;
virtual void Put(int key, int value) = 0;
virtual Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
std::vector<int> *need_swap_indies, std::vector<int> *need_swap_indies_cache_index) = 0;
};
class LFUCacheAlgorithm : public CacheAlgorithm {
public:
LFUCacheAlgorithm(size_t cache_size, int min_host_index, int max_host_index)
: cache_size_(cache_size), min_host_index_(min_host_index), max_host_index_(max_host_index) {}
~LFUCacheAlgorithm() override;
int Get(int key) override;
void Put(int key, int value) override;
Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
std::vector<int> *need_swap_indies, std::vector<int> *need_swap_indies_cache_index) override;
private:
CacheNoe *GetNode(int key);
void GetHitNodesAndSwapIndex(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
std::unordered_map<int, CacheNoe *> *hit_index_nodes,
std::unordered_map<int, std::vector<int>> *need_swap_map);
std::list<CacheNoe *> GetSwapNodes(const std::unordered_map<int, std::vector<int>> &need_swap_map);
std::unordered_map<int, std::list<CacheNoe *>::iterator> key_table_;
std::map<int, std::list<CacheNoe *>> frequency_table_;
size_t cache_size_;
int min_host_index_{0};
int max_host_index_{1};
};
} // namespace cache
} // namespace mindspore
#endif // MINDSPORE_LITE_LRU_CACHE_H_
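
A host-only sketch of how the LFU algorithm above is driven, independent of any device memory; the slot count and ids are invented for illustration:

#include <vector>
#include "src/delegate/parameter_cache/lfu_cache.h"

void LfuSketch() {
  // Four device slots caching host indices in the range [0, 100).
  mindspore::cache::LFUCacheAlgorithm lfu(/*cache_size=*/4, /*min_host_index=*/0,
                                          /*max_host_index=*/100);
  // Warm-up: host index i initially occupies device slot i, as
  // EmbeddingCache::SetHostCacheAddr does after the first host-to-device copy.
  for (int i = 0; i < 4; ++i) {
    lfu.Put(i, i);
  }
  std::vector<int> batch_ids = {1, 3, 7, 200};  // 7 is a miss, 200 is outside this rank's range
  std::vector<int> cache_index(batch_ids.size());
  std::vector<int> swap_in_host_ids;      // host rows that must be copied to the device
  std::vector<int> swap_in_device_slots;  // device slots they will overwrite
  auto ret = lfu.CheckCacheHit(batch_ids.data(), batch_ids.size(), cache_index.data(),
                               &swap_in_host_ids, &swap_in_device_slots);
  // On success cache_index holds one device slot per id (-1 for 200), and the two
  // swap vectors describe the single miss for id 7.
  (void)ret;
}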

View File

@ -0,0 +1,99 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "src/delegate/parameter_cache/load_host_cache_model.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "src/common/file_utils.h"
namespace mindspore {
namespace cache {
MSTensor *SchemaTensorToMSTensor(lite::SchemaTensorWrapper *schema_tensor_wrapper,
mindspore::schema::Tensor *schema_tensor) {
std::vector<int64_t> shape;
for (size_t j = 0; j < schema_tensor->dims()->size(); j++) {
shape.push_back(schema_tensor->dims()->data()[j]);
}
std::string tensor_name;
if (schema_tensor->name() != nullptr) {
tensor_name = schema_tensor->name()->str();
}
return MSTensor::CreateRefTensor(tensor_name, (DataType)schema_tensor->dataType(), shape,
schema_tensor_wrapper->data(), schema_tensor_wrapper->length());
}
Status HostCacheModel::LoadCache(const std::string &model_path) {
cache_model_ = lite::LiteImportFromPath(model_path.c_str());
if (cache_model_ == nullptr) {
MS_LOG(ERROR) << "Import model failed";
return kLiteGraphFileError;
}
auto allTensors = cache_model_->all_tensors_;
for (auto node : cache_model_->all_nodes_) {
// only support embedding cache
if (node == nullptr || node->node_type_ != schema::PrimitiveType_Gather) {
continue;
}
auto input_index = node->input_indices_[0];
if (input_index > allTensors.size() - 1) {
MS_LOG(ERROR) << "invalid kernel input, input_index " << input_index << ",allTensors.size() "
<< allTensors.size();
return kLiteOutOfTensorRange;
}
auto schema_tensor_wrapper = cache_model_->GetSchemaTensor(input_index);
if (schema_tensor_wrapper == nullptr) {
MS_LOG(ERROR) << "invalid kernel input, input_index " << input_index;
return kLiteOutOfTensorRange;
}
auto schema_tensor = allTensors[input_index];
if (schema_tensor != nullptr && schema_tensor_wrapper->data() != nullptr) {
auto tensor = SchemaTensorToMSTensor(schema_tensor_wrapper, schema_tensor);
if (tensor == nullptr) {
return kLiteMemoryFailed;
}
cache_tensor_[tensor->Name()] = *tensor;
MS_LOG(INFO) << tensor->Name() << " is cache tensor, and the node is [" << node->name_ << "]";
delete tensor;
}
}
return kSuccess;
}
bool HostCacheModel::CheckIsCacheKernel(kernel::Kernel *kernel) {
if (GetHostCacheTensor(kernel) == nullptr) {
return false;
}
return true;
}
MSTensor HostCacheModel::GetHostCacheTensor(kernel::Kernel *kernel) {
if (kernel != nullptr && kernel->inputs().size() > 0) {
auto iter = cache_tensor_.find(kernel->inputs()[0].Name());
if (iter != cache_tensor_.end()) {
return iter->second;
}
}
return MSTensor(nullptr);
}
} // namespace cache
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_LOAD_HOST_CACHE_DATA_H_
#define MINDSPORE_LITE_LOAD_HOST_CACHE_DATA_H_
#include <map>
#include <string>
#include "src/delegate/parameter_cache/lfu_cache.h"
#include "ps/ps_cache/ps_cache_basic.h"
#include "include/api/status.h"
#include "include/api/data_type.h"
#include "include/api/types.h"
#include "src/lite_model.h"
#include "include/api/kernel.h"
namespace mindspore {
namespace cache {
class HostCacheModel {
public:
HostCacheModel() {}
Status LoadCache(const std::string &model_path);
bool CheckIsCacheKernel(kernel::Kernel *kernel);
MSTensor GetHostCacheTensor(kernel::Kernel *kernel);
private:
std::map<std::string, MSTensor> cache_tensor_;
mindspore::lite::LiteModel *cache_model_{nullptr};
char *model_buf_{nullptr};
size_t model_size_;
};
} // namespace cache
} // namespace mindspore
#endif // MINDSPORE_LITE_LOAD_HOST_CACHE_DATA_H_

View File

@ -37,6 +37,7 @@ else()
endif()
add_dependencies(gpu_distribution_collective fbs_src)
file(GLOB TENSORRT_RUNTIME_SRC LIST_DIRECTORIES false
${CMAKE_CURRENT_SOURCE_DIR}/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/op/*.cc
@ -44,6 +45,17 @@ file(GLOB TENSORRT_RUNTIME_SRC LIST_DIRECTORIES false
${CMAKE_CURRENT_SOURCE_DIR}/../delegate_utils.cc
)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache)
set(TENSORRT_RUNTIME_SRC
${TENSORRT_RUNTIME_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/embedding_cache_manager.cc
${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/load_host_cache_model.cc
${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/lfu_cache.cc
${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/embedding_cache.cc
${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/gpu/gpu_cache_mem.cc
)
link_libraries(${CUDA_LIB_PATH}/libcudnn.so)
link_libraries(${CUDA_LIB_PATH}/libnvrtc.so)
link_libraries(${CUDA_LIB_PATH}/libcublas.so)
@ -73,3 +85,4 @@ set_source_files_properties(${CUDA_KERNEL_SRC} PROPERTIES CUDA_SOURCE_PROPERTY_F
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -fPIC")
SET(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-std=c++14;)
cuda_add_library(cuda_kernel_mid STATIC ${CUDA_KERNEL_SRC})

View File

@ -0,0 +1,64 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/delegate/tensorrt/cuda_impl/hash.cuh"
#include "src/delegate/tensorrt/cuda_impl/cuda_helper.h"
template <typename T>
__global__ void HashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
const int hash_dim) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
int hash_index = swap_out_index[i];
for (int j = 0; j < hash_dim; j++) {
swap_out_value[i * hash_dim + j] = hash_table[hash_index * hash_dim + j];
}
}
return;
}
template <typename T>
__global__ void HashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
const int hash_dim) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
int hash_index = swap_in_index[i];
for (int j = 0; j < hash_dim; j++) {
hash_table[hash_index * hash_dim + j] = swap_in_value[i * hash_dim + j];
}
}
return;
}
template <typename T>
void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
const int hash_dim, cudaStream_t cuda_stream) {
HashSwapOut<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_out_value, swap_out_index,
index_size, hash_dim);
return;
}
template <typename T>
void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
const int hash_dim, cudaStream_t cuda_stream) {
HashSwapIn<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_in_value, swap_in_index,
index_size, hash_dim);
return;
}
template void DoHashSwapOut<float>(const float *hash_table, float *swap_out_value, const int *swap_out_index,
const int index_size, const int hash_dim, cudaStream_t cuda_stream);
template void DoHashSwapIn<float>(float *hash_table, const float *swap_in_value, const int *swap_in_index,
const int index_size, const int hash_dim, cudaStream_t cuda_stream);

View File

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_HASH_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_HASH_H_
template <typename T>
void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
const int hash_dim, cudaStream_t cuda_stream);
template <typename T>
void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
const int hash_dim, cudaStream_t cuda_stream);
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_HASH_H_

View File

@ -132,6 +132,18 @@ Status TensorRTDelegate::Init() {
MS_LOG(ERROR) << "TensorRTRuntime init failed.";
return mindspore::kLiteError;
}
cache_mgr_ = std::make_shared<cache::EmbeddingCacheManager>();
if (cache_mgr_ == nullptr) {
MS_LOG(ERROR) << "malloc EmbeddingCacheManager failed.";
return kLiteMemoryFailed;
}
auto cache_ret = cache_mgr_->Init("");
if (cache_ret != mindspore::kSuccess) {
MS_LOG(ERROR) << "cache_mgr_ init failed.";
return cache_ret;
}
return mindspore::kSuccess;
}
@ -165,6 +177,14 @@ Status TensorRTDelegate::Build(DelegateModel *model) {
}
auto tensorrt_op = FindTensorRTOp(kernel, model->GetPrimitive(kernel));
if (tensorrt_op != nullptr) {
if (cache_mgr_->CheckIsCacheKernel(kernel)) {
auto cache_ret = cache_mgr_->InitCacheKernel(kernel);
if (cache_ret != kSuccess) {
MS_LOG(INFO) << "InitCacheKernel failed " << kernel->name();
return cache_ret;
}
}
// If tensorrt_ops does not equal nullptr, this kernel can be supported by delegate
if (tensorrt_ops.size() == 0) {
from = iter;
@ -219,6 +239,8 @@ TensorRTSubGraph *TensorRTDelegate::CreateTensorRTGraph(const std::vector<Tensor
MS_LOG(ERROR) << "new tensorrt_graph failed.";
return nullptr;
}
tensorrt_graph->SetCacheManager(cache_mgr_);
// 1. For every op, find pre and next ops
FindPreNextOps<TensorRTOp>(ops);

View File

@ -22,6 +22,7 @@
#include <memory>
#include "include/api/delegate.h"
#include "src/delegate/tensorrt/tensorrt_subgraph.h"
#include "src/delegate/parameter_cache/embedding_cache_manager.h"
#include "include/api/kernel.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
@ -63,6 +64,7 @@ class TensorRTDelegate : public Delegate {
bool support_hw_resize_{true};
bool support_resize_{true};
std::shared_ptr<cache::EmbeddingCacheManager> cache_mgr_{nullptr};
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_DELEGATE_

View File

@ -352,12 +352,18 @@ int TensorRTSubGraph::Prepare() {
}
// malloc for cache weight tensor
for (auto cache_tensor : cache_inputs_) {
for (auto cache_tensor : cache_const_inputs_) {
auto data_size = cache_tensor.DataSize();
auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(cache_tensor, cache_tensor.DataSize());
runtime_->GetAllocator()->SyncMemInHostAndDevice(cache_tensor, cache_tensor.Name().c_str(), true);
runtime_->GetAllocator()->MarkMemValid(cache_tensor.Name().c_str(), true);
int index = this->engine_->getBindingIndex(cache_tensor.Name().c_str());
tensor_bindings_[index] = device_ptr;
auto cache_ret = cache_mgr_->SetDeviceCacheAddr(cache_tensor.Name(), device_ptr, data_size);
if (cache_ret != kSuccess) {
MS_LOG(ERROR) << "SetDeviceCacheAddr failed, cache tensor: " << cache_tensor.Name();
return RET_ERROR;
}
}
if (!this->trt_context_->allInputDimensionsSpecified()) {
@ -466,6 +472,22 @@ int TensorRTSubGraph::Execute() {
MS_LOG(INFO) << "no need memcpy to cuda for input tensor: " << trt_in_tensor_name_[i];
continue;
}
auto iter = model_input_to_cache_tensors_.find(trt_in_tensor_name_[i]);
if (iter != model_input_to_cache_tensors_.end()) {
for (auto &cache_tensor : iter->second) {
ret = cache_mgr_->CacheHandle(cache_tensor.Name(), inputs_[i],
runtime_->GetAllocator()->GetDevicePtr(trt_in_tensor_name_[i]));
if (ret != RET_OK) {
MS_LOG(ERROR) << "handle cache failed " << trt_in_tensor_name_[i];
return RET_ERROR;
}
runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], true);
MS_LOG(DEBUG) << cache_tensor.Name() << " CacheHandle succ " << trt_in_tensor_name_[i];
}
continue;
}
ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], trt_in_tensor_name_[i], true);
if (ret != RET_OK) {
MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_in_tensor_name_[i];
@ -525,9 +547,11 @@ ITensorHelper TensorRTSubGraph::FindTensorRTInputs(TensorRTOp *cur_op, const min
}
return ITensorHelper{};
}
bool TensorRTSubGraph::IsCached(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) { return false; }
bool TensorRTSubGraph::IsCached(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) {
return cache_mgr_ != nullptr && cache_mgr_->IsCacheTensor(in_tensor);
}
void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op) {
void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op, mindspore::MSTensor device_cache_tensor) {
auto iter = network_cache_tensor_info_.find(cur_op->GetOpName());
if (iter != network_cache_tensor_info_.end()) {
return;
@ -542,6 +566,7 @@ void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op) {
for (auto in_tensor : front_op->inputs()) {
if (IsSubGraphInputTensor(this->inputs(), in_tensor)) {
iter->second.network_input_tensor_.push_back(in_tensor);
model_input_to_cache_tensors_[in_tensor.Name()].push_back(device_cache_tensor);
MS_LOG(DEBUG) << cur_op->GetOpName() << "'s network input tensor name is " << in_tensor.Name()
<< ", can cache: " << iter->second.front_op_can_cache_;
}
@ -556,9 +581,9 @@ void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op) {
bool TensorRTSubGraph::CanOpCache(TensorRTOp *cur_op) { return true; }
int TensorRTSubGraph::HandleCacheTensor(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) {
FindCacheTensorInfo(cur_op);
FindCacheTensorInfo(cur_op, in_tensor);
// cache kernel weight tensor
cache_inputs_.push_back(in_tensor);
cache_const_inputs_.push_back(in_tensor);
MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name();
auto cuda_dtype = ConvertDataType(in_tensor.DataType());
nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape());
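
Taken together, the Execute() and FindCacheTensorInfo() hunks in this file map each model input that feeds a cached embedding table to the cache tensors it indexes (model_input_to_cache_tensors_), and at run time CacheHandle() resolves the incoming ids against the cache before the usual host-to-device sync. The toy cache below only illustrates the id-to-slot resolution on the host; the real manager also moves missing rows into the device copy registered during Prepare(), and its behaviour when the cache fills up is not shown here.

#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Toy embedding cache: maps full-table row ids to slots of a smaller cache table.
class TinyEmbeddingCacheSketch {
 public:
  explicit TinyEmbeddingCacheSketch(size_t capacity) : capacity_(capacity) {}

  // Rewrites ids in place so they index the cache table, admitting missing rows first.
  // Returns false once the cache is full; a real cache would evict or swap instead.
  bool CacheHandle(std::vector<int32_t> *ids) {
    for (auto &id : *ids) {
      auto it = slot_of_row_.find(id);
      if (it == slot_of_row_.end()) {
        if (slot_of_row_.size() >= capacity_) return false;
        it = slot_of_row_.emplace(id, static_cast<int32_t>(slot_of_row_.size())).first;
        // Here the real code would copy the row's weights into the device cache tensor.
      }
      id = it->second;
    }
    return true;
  }

 private:
  size_t capacity_;
  std::unordered_map<int32_t, int32_t> slot_of_row_;
};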

View File

@ -24,6 +24,7 @@
#include "include/api/kernel.h"
#include "src/delegate/tensorrt/tensorrt_runtime.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
#include "src/delegate/parameter_cache/embedding_cache_manager.h"
#include "include/api/context.h"
namespace mindspore::lite {
@ -74,6 +75,8 @@ class TensorRTSubGraph : public kernel::Kernel {
int Init();
void SetCacheManager(const std::shared_ptr<cache::EmbeddingCacheManager> &cache_mgr) { cache_mgr_ = cache_mgr; }
private:
int BuildEngine();
@ -89,7 +92,7 @@ class TensorRTSubGraph : public kernel::Kernel {
bool IsCached(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor);
void FindCacheTensorInfo(TensorRTOp *cur_op);
void FindCacheTensorInfo(TensorRTOp *cur_op, mindspore::MSTensor device_cache_tensor);
bool CanOpCache(TensorRTOp *cur_op);
@ -113,7 +116,7 @@ class TensorRTSubGraph : public kernel::Kernel {
std::vector<std::string> trt_in_tensor_name_;
std::vector<std::string> trt_out_tensor_name_;
std::vector<mindspore::MSTensor> cache_inputs_;
std::vector<mindspore::MSTensor> cache_const_inputs_;
std::map<std::string, CacheTensorInfo> network_cache_tensor_info_;
nvinfer1::INetworkDefinition *network_{nullptr};
@ -126,6 +129,10 @@ class TensorRTSubGraph : public kernel::Kernel {
int input_batchsize_index_{0};
int output_batchsize_index_{0};
int input_hw_index_{0};
std::map<std::string, std::vector<mindspore::MSTensor>> model_input_to_cache_tensors_;
std::shared_ptr<cache::EmbeddingCacheManager> cache_mgr_{nullptr};
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_SUBGTAPH_H_

View File

@ -483,7 +483,7 @@ SchemaTensorWrapper *LiteModel::GetSchemaTensor(const size_t &tensor_index) cons
return this->inner_all_tensors_.at(tensor_index);
}
Model *ImportFromPath(const char *model_path) {
LiteModel *LiteImportFromPath(const char *model_path) {
if (model_path == nullptr) {
MS_LOG(ERROR) << "The model path is nullptr";
return nullptr;
@ -508,6 +508,8 @@ Model *ImportFromPath(const char *model_path) {
return model;
}
Model *ImportFromPath(const char *model_path) { return LiteImportFromPath(model_path); }
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
auto *model = new (std::nothrow) LiteModel();
if (model == nullptr) {
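
The split above is a small factory-wrapper pattern: LiteImportFromPath() returns the concrete LiteModel * for callers that need Lite-specific state, while ImportFromPath() keeps its public Model * signature and simply forwards. A self-contained sketch of the shape of that change (the file name and members are hypothetical):

#include <iostream>

struct Model { virtual ~Model() = default; };
struct LiteModel : Model { /* Lite-specific members live here */ };

// Concrete factory: callers that need LiteModel internals use this directly.
LiteModel *LiteImportFromPath(const char *model_path) {
  if (model_path == nullptr) return nullptr;
  return new LiteModel();
}

// Public factory keeps its original signature and behaviour by forwarding.
Model *ImportFromPath(const char *model_path) { return LiteImportFromPath(model_path); }

int main() {
  LiteModel *lite = LiteImportFromPath("net.ms");  // typed access, no downcast needed
  Model *base = ImportFromPath("net.ms");
  std::cout << std::boolalpha << (lite != nullptr) << " " << (base != nullptr) << std::endl;
  delete lite;
  delete base;
  return 0;
}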

View File

@ -29,6 +29,7 @@
#include "src/common/version_manager.h"
#include "src/schema_tensor_wrapper.h"
#include "nnacl/op_base.h"
#include "src/common/prim_util.h"
#ifdef ENABLE_V0
#include "schema/model_v0_generated.h"
#endif
@ -79,6 +80,7 @@ class LiteModel : public Model {
MS_CHECK_TRUE_MSG(node != nullptr, false, "new node fail!");
auto c_node = meta_graph.nodes()->template GetAs<U>(i);
MS_CHECK_TRUE_MSG(c_node != nullptr, false, "get as cnode fail!");
node->node_type_ = GetPrimitiveType(c_node->primitive(), schema_version_);
#ifdef ENABLE_MODEL_OBF
auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
auto src_prim_type = src_prim->value_type();
@ -286,6 +288,8 @@ class LiteModel : public Model {
};
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
LiteModel *LiteImportFromPath(const char *model_path);
Model *ImportFromPath(const char *model_path);
} // namespace lite
} // namespace mindspore

View File

@ -242,6 +242,7 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
uint32_t tensor_count = model->all_tensors_.size();
auto model_input_indices = model->input_indices_;
auto model_output_indices = model->output_indices_;
for (uint32_t i = 0; i < tensor_count; ++i) {
auto *src_tensor = model->all_tensors_[i];
if (src_tensor == nullptr) {
@ -279,6 +280,7 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
dst_tensor->set_category(Category::GRAPH_OUTPUT);
}
}
this->tensors_.emplace_back(dst_tensor);
}
return RET_OK;
@ -1483,13 +1485,7 @@ int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::Mode
int lite::LiteSession::CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
session::LiteSession *session) {
size_t model_size;
auto model_buf = LoadModelByPath(model_path, model_type, &model_size);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "Read model file failed";
return RET_ERROR;
}
auto *model = lite::ImportFromBuffer(model_buf, model_size, true);
auto *model = lite::ImportFromPath(model_path.c_str());
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
return RET_ERROR;

View File

@ -1207,9 +1207,11 @@ OP_SCHEMA_DEF_END(ScatterNdUpdate)
OP_SCHEMA_DEF(AllGather)
OP_ATTR(group, string)
OP_ATTR(rank_size, int)
OP_SCHEMA_DEF_END(AllGather)
OP_SCHEMA_DEF(ReduceScatter)
OP_ATTR(group, string)
OP_ATTR_ENUM(mode, ReduceMode)
OP_ATTR(rank_size, int)
OP_SCHEMA_DEF_END(ReduceScatter)

View File

@ -46,9 +46,10 @@ OpParameter *PopulateAllGatherParameter(const void *prim) {
return nullptr;
}
memset(param, 0, sizeof(AllGatherParameter));
memcpy(param->group_, group->c_str(), group->size());
memcpy(param->group_, value->group()->c_str(), value->group()->size());
param->rank_size_ = value->rank_size();
param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_AllGather, PopulateAllGatherParameter, SCHEMA_CUR)

View File

@ -47,10 +47,11 @@ OpParameter *PopulateReduceScatterParameter(const void *prim) {
return nullptr;
}
memset(param, 0, sizeof(ReduceScatterParameter));
memcpy(param->group_, group->c_str(), group->size());
memcpy(param->group_, value->group()->c_str(), value->group()->size());
param->rank_size_ = value->rank_size();
param->mode_ = value->mode();
param->op_parameter_.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_ReduceScatter, PopulateReduceScatterParameter, SCHEMA_CUR)
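
Both populate functions now read the group name and rank_size straight from the schema primitive instead of asking a runtime get_rank() helper. Below is a defensive sketch of that step; unlike the code above it clamps the copy to the fixed group_ buffer, and the length constant is a local stand-in for DEFAULT_GROUP_NAME_LEN.

#include <cstring>
#include <string>

constexpr int kGroupNameLen = 101;  // stand-in for DEFAULT_GROUP_NAME_LEN

struct AllGatherParameterSketch {
  char group_[kGroupNameLen];
  int rank_size_;
};

// Copy the schema's group string into the fixed-size buffer (truncating if necessary)
// and take rank_size directly from the schema attribute.
void PopulateAllGatherSketch(const std::string &group, int rank_size, AllGatherParameterSketch *param) {
  std::memset(param, 0, sizeof(*param));
  size_t limit = static_cast<size_t>(kGroupNameLen) - 1;
  size_t n = group.size() < limit ? group.size() : limit;
  std::memcpy(param->group_, group.data(), n);  // group_ stays NUL-terminated thanks to memset
  param->rank_size_ = rank_size;
}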

View File

@ -34,11 +34,7 @@ int AllGatherCPUKernel::Prepare() {
int AllGatherCPUKernel::ReSize() { return lite::RET_OK; }
int AllGatherCPUKernel::Run() {
int rank = get_rank(param_->group_);
if (param_->rank_ != rank) {
return lite::RET_ERROR;
}
int rank = param_->rank_size_;
size_t data_size = in_tensors().front()->Size();
auto out_tensor = out_tensors().front();
int8_t *out_data = reinterpret_cast<int8_t *>(out_tensor->data());
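
With rank_size_ now coming from the converter-populated parameter, the CPU kernel no longer cross-checks a runtime get_rank() stub. As a reminder of the operator's contract, the output concatenates rank_size blocks along the first dimension; the single-process stand-in below simply replicates the local input block, whereas a real multi-device run would gather a different block from each rank.

#include <cstdint>
#include <cstring>
#include <vector>

// Single-process AllGather stand-in: every output block is a copy of the local input.
std::vector<int8_t> AllGatherSketch(const std::vector<int8_t> &input, int rank_size) {
  std::vector<int8_t> output(input.size() * static_cast<size_t>(rank_size));
  for (int r = 0; r < rank_size; ++r) {
    std::memcpy(output.data() + static_cast<size_t>(r) * input.size(), input.data(), input.size());
  }
  return output;
}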

View File

@ -57,7 +57,7 @@ int ReduceScatterCPUKernel::DoReduceScatter(void *in_data, void *reduce_data, si
if (param_->mode_ == schema::ReduceMode_ReduceSum) {
for (size_t i = 0; i < data_size; i++) out[i] += in[i];
} else if (param_->mode_ == schema::ReduceMode_ReduceMean) {
for (size_t i = 0; i < data_size; i++) out[i] += (in[i] / static_cast<float>(param_->rank_));
for (size_t i = 0; i < data_size; i++) out[i] += (in[i] / static_cast<float>(param_->rank_size_));
} else if (param_->mode_ == schema::ReduceMode_ReduceMax) {
for (size_t i = 0; i < data_size; i++) out[i] = in[i] > out[i] ? in[i] : out[i];
} else if (param_->mode_ == schema::ReduceMode_ReduceMin) {
@ -74,11 +74,7 @@ int ReduceScatterCPUKernel::DoReduceScatter(void *in_data, void *reduce_data, si
}
int ReduceScatterCPUKernel::Run() {
int rank = get_rank(param_->group_);
if (param_->rank_ != rank) {
return lite::RET_ERROR;
}
int rank = param_->rank_size_;
size_t in_data_size = in_tensors().front()->Size();
size_t in_ele_size = in_tensors().front()->ElementsNum();
size_t out_data_size = out_tensors().front()->Size();
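
For reference, the four reduce modes handled in DoReduceScatter() boil down to an elementwise reduction across rank_size equally sized blocks, after which each rank keeps one slice of the result. The sketch below mirrors only the reduction step and is a simplified model, not the kernel itself.

#include <algorithm>
#include <cstddef>
#include <vector>

enum class ReduceModeSketch { kSum, kMean, kMax, kMin };

// `in` holds rank_size blocks of block_size floats laid out back to back;
// the return value is the single reduced block.
std::vector<float> ReduceBlocksSketch(const std::vector<float> &in, size_t block_size,
                                      int rank_size, ReduceModeSketch mode) {
  std::vector<float> out(in.begin(), in.begin() + static_cast<std::ptrdiff_t>(block_size));
  for (int r = 1; r < rank_size; ++r) {
    const float *block = in.data() + static_cast<size_t>(r) * block_size;
    for (size_t i = 0; i < block_size; ++i) {
      switch (mode) {
        case ReduceModeSketch::kSum:
        case ReduceModeSketch::kMean: out[i] += block[i]; break;  // mean divides once at the end
        case ReduceModeSketch::kMax:  out[i] = std::max(out[i], block[i]); break;
        case ReduceModeSketch::kMin:  out[i] = std::min(out[i], block[i]); break;
      }
    }
  }
  if (mode == ReduceModeSketch::kMean) {
    for (auto &v : out) v /= static_cast<float>(rank_size);
  }
  return out;
}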

View File

@ -23,6 +23,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_utils_secure.h"
#include "tools/optimizer/common/format_utils.h"
#include "nnacl/op_base.h"
#include "tools/anf_exporter/anf_exporter.h"
@ -259,8 +260,9 @@ int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkTy
if (tensor_info != nullptr && tensor_info->Size() != 0) {
if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
data_info->data_.resize(tensor_info->Size() - offset);
if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
if (EOK != common::huge_memcpy_s(data_info->data_.data(), data_info->data_.size(),
static_cast<uint8_t *>(tensor_info->data_c()) + offset,
tensor_info->Size() - offset)) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_ERROR;
}
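
Secure memcpy_s implementations typically reject copies above a fixed cap (on the order of 2 GB), which is why large weight tensors here go through common::huge_memcpy_s instead. The helper below is only an illustrative stand-in showing the chunking idea, not the real implementation.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Copy `count` bytes in bounded chunks so no single call exceeds the secure-copy cap.
// Returns 0 on success, -1 on bad arguments.
int HugeMemcpySketch(uint8_t *dest, size_t dest_max, const uint8_t *src, size_t count) {
  if (dest == nullptr || src == nullptr || count > dest_max) return -1;
  constexpr size_t kChunk = static_cast<size_t>(1) << 30;  // 1 GiB per chunk
  size_t copied = 0;
  while (copied < count) {
    size_t n = count - copied < kChunk ? count - copied : kChunk;
    std::memcpy(dest + copied, src + copied, n);  // the real helper would use memcpy_s here
    copied += n;
  }
  return 0;
}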

View File

@ -30,7 +30,7 @@
namespace mindspore::lite {
namespace {
constexpr size_t kModelSizeLimit = 2 * 1024 * 1024;
constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
constexpr size_t kExternalDataHeadSize = 4096;
constexpr size_t kMagicNumberSize = 4;
constexpr size_t kFlatbuffersBuilderInitSize = 1024;
@ -241,7 +241,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
size_t size = builder.GetSize();
auto save_together = (size < kModelSizeLimit) || true;
auto save_together = (size < kModelSizeLimit);
MetaGraphSerializer meta_graph_serializer;
if (!meta_graph_serializer.Init(graph, output_path, save_together)) {
MS_LOG(ERROR) << "Init MetaGraphSerializer failed";

View File

@ -19,6 +19,7 @@
#include <vector>
#include "src/tensorlist.h"
#include "tools/optimizer/common/format_utils.h"
#include "utils/ms_utils_secure.h"
#include "nnacl/op_base.h"
namespace mindspore {
@ -100,12 +101,13 @@ int ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos, std::vect
return lite::RET_ERROR;
}
} else {
auto tensor_data = reinterpret_cast<char *>(malloc(tensor_size));
auto tensor_data = malloc(tensor_size);
if (tensor_data == nullptr) {
MS_LOG(ERROR) << "tensor_data is nullptr.";
return lite::RET_ERROR;
}
if (memcpy_s(tensor_data, tensor_size, data_info.data_.data(), tensor_size) != EOK) {
if (common::huge_memcpy_s(static_cast<uint8_t *>(tensor_data), tensor_size, data_info.data_.data(),
tensor_size) != EOK) {
free(tensor_data);
MS_LOG(ERROR) << "memcpy data error.";
return lite::RET_ERROR;