forked from mindspore-Ecosystem/mindspore
!26736 [MS][LITE] support parameter cache and distributed prediction
Merge pull request !26736 from 张学同/distribution_cache_dev
This commit is contained in:
commit d0792507bd
@@ -18,7 +18,6 @@
 #define MINDSPORE_NNACL_ALL_GATHER_PARAMETER_H_
 
 #include "nnacl/op_base.h"
-#include "nnacl/communication_func.h"
 
 typedef struct AllGatherParameter {
   // primitive parameter
@@ -26,6 +25,6 @@ typedef struct AllGatherParameter {
   char group_[DEFAULT_GROUP_NAME_LEN];
 
   // other parameter
-  int rank_;
+  int rank_size_;
 } AllGatherParameter;
 #endif  // MINDSPORE_NNACL_ALL_GATHER_PARAMETER_H_
@@ -1,33 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_NNACL_COMMUNICATION_FUNC_H_
-#define MINDSPORE_NNACL_COMMUNICATION_FUNC_H_
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#define DEFAULT_GROUP_NAME_LEN 101
-#define DEFAULT_GROUP_SIZE 2
-
-static inline int get_rank(char *group) { return DEFAULT_GROUP_SIZE; }
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif  // MINDSPORE_NNACL_COMMUNICATION_FUNC_H_
@@ -18,7 +18,6 @@
 #include "nnacl/infer/infer_register.h"
 #include "nnacl/infer/common_infer.h"
 #include "nnacl/all_gather_parameter.h"
-#include "nnacl/communication_func.h"
 
 int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                         OpParameter *parameter) {
@@ -34,8 +33,7 @@ int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
   }
 
   AllGatherParameter *param = (AllGatherParameter *)parameter;
-  param->rank_ = get_rank(param->group_);
-  if (param->rank_ <= 0) {
+  if (param->rank_size_ <= 0) {
     return NNACL_INFER_INVALID;
   }
 
@@ -45,7 +43,7 @@ int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
 
   int out_shape[MAX_SHAPE_SIZE];
   size_t out_shape_size = 0;
-  out_shape[0] = in_shape[0] * param->rank_;
+  out_shape[0] = in_shape[0] * param->rank_size_;
   out_shape_size++;
   for (int i = 1; i < input_tensor->shape_size_; i++) {
     out_shape[i] = in_shape[i];
@@ -13,11 +13,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
 #include "nnacl/errorcode.h"
 #include "nnacl/infer/infer_register.h"
 #include "nnacl/infer/common_infer.h"
-#include "nnacl/communication_func.h"
 #include "nnacl/reduce_scatter_parameter.h"
 
 int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
@@ -31,8 +29,7 @@ int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, Te
   }
 
   ReduceScatterParameter *param = (ReduceScatterParameter *)parameter;
-  param->rank_ = get_rank(param->group_);
-  if (param->rank_ <= 0) {
+  if (param->rank_size_ <= 0) {
     return NNACL_INFER_INVALID;
   }
 
@@ -40,19 +37,20 @@ int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, Te
   const int *in_shape = input_tensor->shape_;
   TensorC *out_tensor = outputs[0];
 
-  if (in_shape[0] % param->rank_ != 0) {
+  if (in_shape[0] % param->rank_size_ != 0) {
     return NNACL_INFER_INVALID;
   }
 
   int out_shape[MAX_SHAPE_SIZE];
   size_t out_shape_size = 0;
-  out_shape[0] = in_shape[0] / param->rank_;
+  out_shape[0] = in_shape[0] / param->rank_size_;
   out_shape_size++;
   for (int i = 1; i < input_tensor->shape_size_; i++) {
     out_shape[i] = in_shape[i];
     out_shape_size++;
   }
   SetShapeArray(out_tensor, out_shape, out_shape_size);
 
   return NNACL_OK;
 }
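In both infer functions above the functional change is only the rank source: the leading output dimension is now derived from the statically stored rank_size_ instead of a runtime get_rank() query. A standalone sketch of the resulting shape arithmetic (helper names are illustrative, not part of the patch):

#include <cassert>
#include <vector>

// AllGather concatenates every rank's tensor along dim 0; ReduceScatter reduces
// and then splits dim 0 evenly across ranks. Mirrors the infer code above.
std::vector<int> AllGatherOutShape(const std::vector<int> &in, int rank_size) {
  std::vector<int> out = in;
  out[0] = in[0] * rank_size;  // every rank contributes one slice
  return out;
}

std::vector<int> ReduceScatterOutShape(const std::vector<int> &in, int rank_size) {
  assert(in[0] % rank_size == 0);  // mirrors the NNACL_INFER_INVALID check
  std::vector<int> out = in;
  out[0] = in[0] / rank_size;  // each rank keeps one reduced slice
  return out;
}

// Example: with rank_size = 2, AllGather maps [8, 128] -> [16, 128] and
// ReduceScatter maps [8, 128] -> [4, 128].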
@@ -110,6 +110,7 @@
 #define kDefaulLiteMaxSpinCount 300000
 #define kDefaulLiteMinSpinCount 1
 #define kDefaulLiteIosSpinCount 1
+#define DEFAULT_GROUP_NAME_LEN 101
 
 #if ENABLE_HIGH_PERFORMANCE
 #define MS_CHECK_TRUE_RET(value, errcode)
@@ -18,7 +18,6 @@
 #define MINDSPORE_NNACL_REDUCE_SCATTER_PARAMETER_H_
 
 #include "nnacl/op_base.h"
-#include "nnacl/communication_func.h"
 
 typedef struct ReduceScatterParameter {
   // primitive parameter
@@ -27,6 +26,6 @@ typedef struct ReduceScatterParameter {
   int mode_;
 
   // other parameter
-  int rank_;
+  int rank_size_;
 } ReduceScatterParameter;
 #endif  // MINDSPORE_NNACL_REDUCE_SCATTER_PARAMETER_H_
@ -137,6 +137,10 @@ void *AscendPsCache::MallocMemory(size_t size) {
|
|||
return device_addr;
|
||||
}
|
||||
|
||||
void AscendPsCache::FreeMemory(void *device_addr) {
|
||||
device::ascend::AscendMemoryPool::GetInstance().FreeTensorMem(device_addr);
|
||||
}
|
||||
|
||||
bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) {
|
||||
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
|
||||
if (offset_addr_ == nullptr) {
|
||||
|
|
|
@ -52,6 +52,7 @@ class AscendPsCache : public PsCacheBasic {
|
|||
~AscendPsCache() override = default;
|
||||
bool InitDevice(uint32_t device_id, const void *context) override;
|
||||
void *MallocMemory(size_t size) override;
|
||||
void FreeMemory(void *device_addr) override;
|
||||
bool MallocConstantMemory(size_t cache_vocab_size) override;
|
||||
bool RecordEvent() override;
|
||||
bool SynchronizeEvent() override;
|
||||
|
|
|
@ -41,6 +41,10 @@ void *GPUPsCache::MallocMemory(size_t size) {
|
|||
return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size);
|
||||
}
|
||||
|
||||
void GPUPsCache::FreeMemory(void *device_addr) {
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(device_addr);
|
||||
}
|
||||
|
||||
bool GPUPsCache::RecordEvent() {
|
||||
event_.reset(new cudaEvent_t());
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
|
||||
|
|
|
@ -30,6 +30,7 @@ class GPUPsCache : public PsCacheBasic {
|
|||
~GPUPsCache() override = default;
|
||||
bool InitDevice(uint32_t device_id, const void *context) override;
|
||||
void *MallocMemory(size_t size) override;
|
||||
void FreeMemory(void *device_addr) override;
|
||||
bool RecordEvent() override;
|
||||
bool SynchronizeEvent() override;
|
||||
bool SynchronizeStream() override;
|
||||
|
|
|
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@ -19,25 +19,8 @@
|
|||
#include <utility>
|
||||
#include <memory>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
#define RETURN_IF_FALSE(condition) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
return false; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define RETURN_IF_FALSE_WITH_LOG(condition, message) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
MS_LOG(ERROR) << message; \
|
||||
return false; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
class PsCacheBasic {
|
||||
public:
|
||||
PsCacheBasic() = default;
|
||||
|
@ -45,6 +28,7 @@ class PsCacheBasic {
|
|||
virtual bool InitDevice(uint32_t device_id, const void *context) = 0;
|
||||
virtual void *MallocMemory(size_t size) = 0;
|
||||
virtual bool MallocConstantMemory(size_t cache_vocab_size) { return true; }
|
||||
virtual void FreeMemory(void *buf) = 0;
|
||||
virtual bool RecordEvent() = 0;
|
||||
virtual bool SynchronizeEvent() = 0;
|
||||
virtual bool SynchronizeStream() = 0;
|
||||
|
|
|
@ -28,6 +28,15 @@ std::string AllGather::get_group() const {
|
|||
auto value_ptr = GetAttr(kGroup);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
void AllGather::set_rank_size(int rank_size) {
|
||||
(void)this->AddAttr(kRankSize, MakeValue(static_cast<int64_t>(rank_size)));
|
||||
}
|
||||
int AllGather::get_rank_size() const {
|
||||
auto value_ptr = GetAttr(kRankSize);
|
||||
return static_cast<int>(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameAllGather, AllGather);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
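The accessors added above simply round-trip an int64 attribute keyed by kRankSize. A small usage sketch; the include path and group name are assumptions, not taken from this diff:

#include "ops/all_gather.h"  // assumed header path for the AllGather primitive

// Sketch: set the existing group attribute plus the new rank_size attribute,
// then read rank_size back through the accessor added in this patch.
void ConfigureAllGather(mindspore::ops::AllGather *all_gather) {
  all_gather->set_group("nccl_world_group");  // illustrative group name
  all_gather->set_rank_size(2);
  int rank_size = all_gather->get_rank_size();
  (void)rank_size;  // later consumed by shape inference via AllGatherParameter::rank_size_
}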
@ -35,6 +35,8 @@ class MS_CORE_API AllGather : public PrimitiveC {
|
|||
void Init() {}
|
||||
void set_group(const std::string &format);
|
||||
std::string get_group() const;
|
||||
void set_rank_size(int rank_size);
|
||||
int get_rank_size() const;
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -166,6 +166,7 @@ constexpr auto kDivisorOverride = "divisor_override";
|
|||
constexpr auto kPostNmsTopn = "post_nms_topn";
|
||||
constexpr auto kPower = "power";
|
||||
constexpr auto kPreNmsTopn = "pre_nms_topn";
|
||||
constexpr auto kRankSize = "rank_size";
|
||||
constexpr auto kRatio = "ratio";
|
||||
constexpr auto kReduction = "reduction";
|
||||
constexpr auto kRootRank = "root_rank";
|
||||
|
|
|
@ -39,6 +39,14 @@ ReduceMode ReduceScatter::get_mode() const {
|
|||
return ReduceMode(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
void ReduceScatter::set_rank_size(int rank_size) {
|
||||
(void)this->AddAttr(kRankSize, MakeValue(static_cast<int64_t>(rank_size)));
|
||||
}
|
||||
int ReduceScatter::get_rank_size() const {
|
||||
auto value_ptr = GetAttr(kRankSize);
|
||||
return static_cast<int>(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameReduceScatter, ReduceScatter);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,6 +37,8 @@ class MS_CORE_API ReduceScatter : public PrimitiveC {
|
|||
std::string get_group() const;
|
||||
void set_mode(const ReduceMode &mode);
|
||||
ReduceMode get_mode() const;
|
||||
void set_rank_size(int rank_size);
|
||||
int get_rank_size() const;
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -291,6 +291,21 @@ class LogWriter {
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
#define RETURN_IF_FALSE(condition) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
return false; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#define RETURN_IF_FALSE_WITH_LOG(condition, message) \
|
||||
do { \
|
||||
if (!(condition)) { \
|
||||
MS_LOG(ERROR) << message; \
|
||||
return false; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
#ifdef DEBUG
|
||||
#include <cassert>
|
||||
#define MS_ASSERT(f) assert(f)
|
||||
|
|
|
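The two macros are lifted unchanged from ps_cache_basic.h (see the removal above) so that code which does not include the PS cache header can still use them. A minimal usage sketch, assuming utils/log_adapter.h is included:

// Both macros expand to an early `return false;` when the condition fails;
// the second one additionally logs the message through MS_LOG(ERROR).
bool CopyAndSync(bool copy_ok, bool sync_ok) {
  RETURN_IF_FALSE(copy_ok);
  RETURN_IF_FALSE_WITH_LOG(sync_ok, "stream synchronize failed");
  return true;
}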
@@ -1207,9 +1207,11 @@ table ScatterNdUpdate {
 
 table AllGather {
     group: string;
+    rank_size: int;
 }
 
 table ReduceScatter {
     group: string;
     mode: ReduceMode;
+    rank_size: int;
 }
@@ -30,7 +30,7 @@ std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeI
   size_t element_size = lite::DataTypeSize(type);
   for (size_t i = 0; i < shape.size(); i++) {
     auto dim = shape[i];
-    if (dim < 0 || dim > INT_MAX || (dim != 0 && element_size > INT_MAX / static_cast<size_t>(dim))) {
+    if (dim < 0 || dim > INT_MAX) {
      MS_LOG(ERROR) << "Invalid shape.";
      return empty;
    } else {
@ -0,0 +1,169 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/delegate/parameter_cache/embedding_cache.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include "src/delegate/parameter_cache/gpu/gpu_cache_mem.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cache {
|
||||
void LookUpTableTask(size_t indices_lens, size_t first_dim_size, const char *input_addr, const int *indices_addr,
|
||||
char *output_addr, size_t embedding_len) {
|
||||
for (size_t i = 0; i < indices_lens; ++i) {
|
||||
int index = indices_addr[i];
|
||||
if (index >= 0 && index < static_cast<int>(first_dim_size)) {
|
||||
size_t pos = index * embedding_len;
|
||||
std::memcpy(output_addr, input_addr + pos, embedding_len);
|
||||
} else {
|
||||
memset(output_addr, 0, embedding_len);
|
||||
}
|
||||
output_addr += embedding_len;
|
||||
}
|
||||
}
|
||||
|
||||
EmbeddingCache::~EmbeddingCache() {
|
||||
if (hash_swap_value_device_addr_ != nullptr) {
|
||||
device_cache_->FreeMemory(hash_swap_value_device_addr_);
|
||||
hash_swap_value_device_addr_ = nullptr;
|
||||
}
|
||||
if (hash_swap_value_addr_ != nullptr) {
|
||||
free(hash_swap_value_addr_);
|
||||
hash_swap_value_addr_ = nullptr;
|
||||
}
|
||||
if (hash_swap_index_addr_ != nullptr) {
|
||||
device_cache_->FreeMemory(hash_swap_index_addr_);
|
||||
hash_swap_index_addr_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Status EmbeddingCache::Init() {
|
||||
cache_ = std::make_shared<LFUCacheAlgorithm>(device_cache_size_, mix_host_index_, max_host_index_);
|
||||
if (cache_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc LFUCacheAlgorithm failed";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
device_cache_ = std::make_shared<gpu::GPUCacheMem>();
|
||||
if (device_cache_ == nullptr) {
|
||||
MS_LOG(ERROR) << "get cache failed";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
|
||||
auto hash_swap_value_size = embedding_size_ * batch_elements_ * sizeof_data_type_;
|
||||
hash_swap_value_device_addr_ = device_cache_->MallocMemory(hash_swap_value_size);
|
||||
if (hash_swap_value_device_addr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc hash_swap_value_device failed, malloc size " << hash_swap_value_size;
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
|
||||
hash_swap_value_addr_ = malloc(hash_swap_value_size);
|
||||
if (hash_swap_value_addr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc hash_swap_value failed, malloc size " << hash_swap_value_size;
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
|
||||
// data type of index
|
||||
hash_swap_index_addr_ = static_cast<int *>(device_cache_->MallocMemory(batch_elements_ * sizeof(int)));
|
||||
if (hash_swap_index_addr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc hash_swap_index failed, malloc size " << batch_elements_ * sizeof(int);
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "init succ, rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_
|
||||
<< ", index begin:" << mix_host_index_ << ", index end:" << max_host_index_;
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status EmbeddingCache::SetHostCacheAddr(void *addr, size_t size) {
|
||||
if (sizeof_data_type_ * host_cache_size_ * embedding_size_ != size) {
|
||||
return kLiteParamInvalid;
|
||||
}
|
||||
host_addr_ = addr;
|
||||
|
||||
// copy part of host mem to device
|
||||
auto ret =
|
||||
device_cache_->CopyHostMemToDevice(device_addr_, addr, sizeof_data_type_ * device_cache_size_ * embedding_size_);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "CopyHostMemToDevice failed, copy size "
|
||||
<< sizeof_data_type_ * device_cache_size_ * embedding_size_;
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
|
||||
// init cache
|
||||
auto index_num = device_cache_size_;
|
||||
for (size_t i = 0; i < index_num; i++) {
|
||||
cache_->Put(mix_host_index_ + i, i);
|
||||
}
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status EmbeddingCache::SetDeviceCacheAddr(void *device_mem_addr, size_t size) {
|
||||
if (sizeof_data_type_ * device_cache_size_ * embedding_size_ != size) {
|
||||
return kLiteParamInvalid;
|
||||
}
|
||||
|
||||
device_addr_ = device_mem_addr;
|
||||
SetHostCacheAddr(host_addr_, sizeof_data_type_ * host_cache_size_ * embedding_size_);
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Status EmbeddingCache::CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index) {
|
||||
std::vector<int> need_swap_indies;
|
||||
std::vector<int> need_swap_indies_cache_index;
|
||||
auto ret =
|
||||
cache_->CheckCacheHit(batch_ids, batch_ids_len, cache_index, &need_swap_indies, &need_swap_indies_cache_index);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "CheckCacheHit failed";
|
||||
return ret;
|
||||
}
|
||||
auto swap_indices_size = need_swap_indies.size();
|
||||
if (swap_indices_size > 0) {
|
||||
LookUpTableTask(swap_indices_size, host_cache_size_, static_cast<char *>(host_addr_), need_swap_indies.data(),
|
||||
static_cast<char *>(hash_swap_value_addr_), embedding_size_ * sizeof_data_type_);
|
||||
|
||||
// copy data
|
||||
auto device_cache_ret = device_cache_->CopyHostMemToDevice(hash_swap_value_device_addr_, hash_swap_value_addr_,
|
||||
swap_indices_size * embedding_size_ * sizeof_data_type_);
|
||||
if (!device_cache_ret) {
|
||||
MS_LOG(ERROR) << "copy swap value to device failed";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
|
||||
device_cache_ret = device_cache_->CopyHostMemToDevice(hash_swap_index_addr_, need_swap_indies_cache_index.data(),
|
||||
swap_indices_size * sizeof(int));
|
||||
if (!device_cache_ret) {
|
||||
MS_LOG(ERROR) << "copy swap indies to device failed";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
|
||||
device_cache_ret = device_cache_->HashSwapIn(device_addr_, hash_swap_value_device_addr_, hash_swap_index_addr_,
|
||||
device_cache_size_, embedding_size_, swap_indices_size);
|
||||
if (!device_cache_ret) {
|
||||
MS_LOG(ERROR) << "HashSwapIn failed";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
}
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace mindspore
|
|
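For orientation, the miss-handling path implemented by CheckCacheHit above can be summarized step by step (names as in this file):

// 1. cache_->CheckCacheHit(...)       writes a device slot into cache_index for every hit id
//                                     and returns the missed host ids plus the slots they evict.
// 2. LookUpTableTask(...)             gathers the missed rows from the host shard into
//                                     hash_swap_value_addr_ (out-of-range ids become zeros).
// 3. CopyHostMemToDevice(values)      stages the gathered rows in hash_swap_value_device_addr_.
// 4. CopyHostMemToDevice(indices)     stages the target slots in hash_swap_index_addr_.
// 5. HashSwapIn(...)                  scatters the staged rows into the device embedding table,
//                                     after which every id in the batch is resident on device.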
@ -0,0 +1,95 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_EMBEDDING_CACHE_H_
|
||||
#define MINDSPORE_LITE_EMBEDDING_CACHE_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "src/delegate/parameter_cache/lfu_cache.h"
|
||||
#include "ps/ps_cache/ps_cache_basic.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cache {
|
||||
class EmbeddingCache {
|
||||
public:
|
||||
EmbeddingCache(size_t vocab_size, size_t host_cache_size, size_t device_cache_size, size_t embedding_size,
|
||||
size_t batch_elements, DataType data_type, void *host_addr, int rank_id, int rank_group_size)
|
||||
: vocab_size_(vocab_size),
|
||||
host_cache_size_(host_cache_size),
|
||||
device_cache_size_(device_cache_size),
|
||||
embedding_size_(embedding_size),
|
||||
batch_elements_(batch_elements),
|
||||
data_type_(data_type),
|
||||
host_addr_(host_addr),
|
||||
rank_id_(rank_id),
|
||||
rank_group_size_(rank_group_size) {
|
||||
auto local_shard_size = static_cast<int>(std::ceil(static_cast<float>(vocab_size_) / rank_group_size_));
|
||||
mix_host_index_ = local_shard_size * rank_id_;
|
||||
max_host_index_ = std::min(mix_host_index_ + local_shard_size, static_cast<int>(vocab_size_));
|
||||
host_cache_size_ = max_host_index_ - mix_host_index_;
|
||||
switch (data_type_) {
|
||||
case DataType::kNumberTypeFloat16:
|
||||
sizeof_data_type_ = sizeof(int16_t);
|
||||
break;
|
||||
default:
|
||||
sizeof_data_type_ = sizeof(float);
|
||||
}
|
||||
host_addr_ = static_cast<char *>(host_addr) + mix_host_index_ * embedding_size_ * sizeof_data_type_;
|
||||
MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_
|
||||
<< ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_
|
||||
<< ", device_cache_size_:" << device_cache_size_ << ", embedding_size_:" << embedding_size_
|
||||
<< ", batch_elements_:" << batch_elements_ << ", index begin:" << mix_host_index_
|
||||
<< ", index end:" << max_host_index_;
|
||||
}
|
||||
~EmbeddingCache();
|
||||
Status Init();
|
||||
Status SetHostCacheAddr(void *addr, size_t size);
|
||||
Status SetDeviceCacheAddr(void *host_mem_addr, size_t size);
|
||||
Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
|
||||
|
||||
private:
|
||||
std::shared_ptr<ps::PsCacheBasic> device_cache_{nullptr};
|
||||
std::shared_ptr<CacheAlgorithm> cache_{nullptr};
|
||||
|
||||
size_t vocab_size_{0}; // total size
|
||||
size_t host_cache_size_{0}; // local host size
|
||||
size_t device_cache_size_{0}; // local device cache size
|
||||
size_t embedding_size_{0};
|
||||
size_t batch_elements_{0};
|
||||
|
||||
DataType data_type_{DataType::kNumberTypeFloat32};
|
||||
size_t sizeof_data_type_{0};
|
||||
|
||||
void *device_addr_{nullptr}; // hash_info.device_address.addr
|
||||
void *host_addr_{nullptr};
|
||||
|
||||
int *hash_swap_index_addr_; // embedding_device_cache_->hash_swap_index_addr_
|
||||
void *hash_swap_value_addr_;
|
||||
void *hash_swap_value_device_addr_;
|
||||
|
||||
int rank_id_;
|
||||
int rank_group_size_;
|
||||
int mix_host_index_{0};
|
||||
int max_host_index_{0};
|
||||
};
|
||||
} // namespace cache
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_EMBEDDING_CACHE_H_
|
|
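The constructor shards the vocabulary evenly across ranks (the member spelled mix_host_index_ holds the minimum host index of the local shard). A standalone sketch of that arithmetic with illustrative numbers:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Mirrors the shard computation in the EmbeddingCache constructor above.
int main() {
  int vocab_size = 10000;
  int rank_group_size = 4;
  int rank_id = 2;
  int local_shard_size =
      static_cast<int>(std::ceil(static_cast<float>(vocab_size) / rank_group_size));  // 2500
  int min_host_index = local_shard_size * rank_id;                                     // 5000
  int max_host_index = std::min(min_host_index + local_shard_size, vocab_size);        // 7500
  std::printf("rank %d owns host indices [%d, %d)\n", rank_id, min_host_index, max_host_index);
  return 0;
}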
@ -0,0 +1,146 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/delegate/parameter_cache/embedding_cache_manager.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cache {
|
||||
Status EmbeddingCacheManager::Init(const std::string &cache_model_path) {
|
||||
if (cache_model_path.empty()) {
|
||||
MS_LOG(INFO) << "no cache model ";
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
host_cache_model_ = std::make_shared<HostCacheModel>();
|
||||
if (host_cache_model_ == nullptr) {
|
||||
MS_LOG(ERROR) << "HostCacheModel malloc failed";
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
auto ret = host_cache_model_->LoadCache(cache_model_path);
|
||||
MS_LOG(INFO) << "cache manager init end, ret " << ret.ToString();
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool EmbeddingCacheManager::CheckIsCacheKernel(kernel::Kernel *kernel) {
|
||||
if (host_cache_model_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return host_cache_model_->CheckIsCacheKernel(kernel);
|
||||
}
|
||||
|
||||
Status EmbeddingCacheManager::InitCacheKernel(kernel::Kernel *kernel) {
|
||||
if (host_cache_model_ == nullptr) {
|
||||
MS_LOG(ERROR) << "cache model is nullptr, kernel " << kernel->name() << " init cache failed";
|
||||
return kLiteError;
|
||||
}
|
||||
auto host_cache_tensor = host_cache_model_->GetHostCacheTensor(kernel);
|
||||
if (host_cache_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << kernel->name() << ": invalid cache kernel";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
// only support embedding cache
|
||||
if (kernel->type() != schema::PrimitiveType_Gather) {
|
||||
MS_LOG(ERROR) << kernel->name() << " is not embedding kernel";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
auto tensor = kernel->inputs()[0];
|
||||
if (tensor.Shape()[1] != host_cache_tensor.Shape()[1]) {
|
||||
MS_LOG(ERROR) << kernel->name() << " embedding_size is invalid, device size is " << tensor.Shape()[1]
|
||||
<< " host size is " << host_cache_tensor.Shape()[1];
|
||||
return kLiteError;
|
||||
}
|
||||
size_t vocab_size = host_cache_tensor.Shape()[0];  // TODO: needs to be handled together with host_cache_size
|
||||
size_t host_cache_size = vocab_size / rank_group_size_;
|
||||
size_t device_cache_size = tensor.Shape()[0];
|
||||
size_t embedding_size_ = tensor.Shape()[1];
|
||||
DataType data_type = tensor.DataType();
|
||||
size_t batch_elements = kernel->inputs()[1].ElementNum();
|
||||
auto cache =
|
||||
std::make_shared<EmbeddingCache>(vocab_size, host_cache_size, device_cache_size, embedding_size_, batch_elements,
|
||||
data_type, host_cache_tensor.MutableData(), rank_id_, rank_group_size_);
|
||||
if (cache == nullptr) {
|
||||
MS_LOG(ERROR) << kernel->name() << ": malloc EmbeddingCache failed";
|
||||
return kLiteError;
|
||||
}
|
||||
|
||||
auto ret = cache->Init();
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << kernel->name() << ": EmbeddingCache init failed";
|
||||
return kLiteError;
|
||||
}
|
||||
caches_[tensor.Name()] = cache;
|
||||
|
||||
MS_LOG(INFO) << kernel->name() << " is cache kernel, input tensor " << kernel->inputs()[1].Name() << ", cache tensor "
|
||||
<< tensor.Name();
|
||||
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
bool EmbeddingCacheManager::IsCacheTensor(mindspore::MSTensor tensor) {
|
||||
if (host_cache_model_ == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto cache = caches_.find(tensor.Name());
|
||||
if (cache != caches_.end()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status EmbeddingCacheManager::SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size) {
|
||||
auto cache_iter = caches_.find(tensor_name);
|
||||
if (cache_iter == caches_.end() || cache_iter->second == nullptr) {
|
||||
MS_LOG(ERROR) << "not find cache, " << tensor_name;
|
||||
return kLiteError;
|
||||
}
|
||||
auto cache = cache_iter->second;
|
||||
return cache->SetDeviceCacheAddr(device_mem_addr, size);
|
||||
}
|
||||
|
||||
// device_addr is model input device addr
|
||||
int EmbeddingCacheManager::CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor,
|
||||
void *model_input_device_addr) {
|
||||
auto cache_iter = caches_.find(tensor_name);
|
||||
if (cache_iter == caches_.end()) {
|
||||
MS_LOG(ERROR) << "not find cache, " << tensor_name;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto cache = cache_iter->second;
|
||||
hash_indices_.resize(model_input_tensor.ElementNum());
|
||||
auto ret = cache->CheckCacheHit(static_cast<int *>(model_input_tensor.MutableData()), hash_indices_.size(),
|
||||
hash_indices_.data());
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "CheckCacheHit failed, " << model_input_tensor.Name();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
auto cuda_ret = cudaMemcpy(model_input_device_addr, model_input_tensor.MutableData(),
|
||||
hash_indices_.size() * sizeof(int), cudaMemcpyHostToDevice);
|
||||
if (cuda_ret != cudaSuccess) {
|
||||
MS_LOG(ERROR) << "copy mem failed, " << model_input_tensor.Name();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "cache handle succ, " << model_input_tensor.Name() << "," << tensor_name;
|
||||
|
||||
return lite::RET_OK;
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace mindspore
|
|
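Putting the pieces together, the manager is meant to be driven in roughly this order. The sketch below is hypothetical glue code; the real caller is the TensorRT delegate, which is not part of this diff, so the kernel, tensor and buffer arguments are placeholders:

#include <string>
#include "src/delegate/parameter_cache/embedding_cache_manager.h"

void RunWithEmbeddingCache(mindspore::kernel::Kernel *kernel, const std::string &tensor_name,
                           mindspore::MSTensor model_input, void *model_input_device_addr,
                           void *device_mem, size_t device_mem_size) {
  mindspore::cache::EmbeddingCacheManager manager;
  manager.Init("/path/to/host_cache_model.ms");  // illustrative path
  if (manager.CheckIsCacheKernel(kernel)) {      // Gather backed by a cached embedding table?
    manager.InitCacheKernel(kernel);             // build an EmbeddingCache for its weight tensor
  }
  // once device memory for the cached tensor has been allocated:
  manager.SetDeviceCacheAddr(tensor_name, device_mem, device_mem_size);
  // per inference step, before launching the model:
  manager.CacheHandle(tensor_name, model_input, model_input_device_addr);
}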
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_EMBEDDING_CACHE_MANAGER_H_
|
||||
#define MINDSPORE_LITE_EMBEDDING_CACHE_MANAGER_H_
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/api/kernel.h"
|
||||
#include "src/delegate/parameter_cache/embedding_cache.h"
|
||||
#include "src/delegate/parameter_cache/lfu_cache.h"
|
||||
#include "ps/ps_cache/ps_cache_basic.h"
|
||||
#include "src/delegate/parameter_cache/load_host_cache_model.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "src/delegate/tensorrt/distribution/distribution_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cache {
|
||||
class EmbeddingCacheManager {
|
||||
public:
|
||||
EmbeddingCacheManager() {
|
||||
rank_id_ = lite::GetRankID();
|
||||
rank_group_size_ = lite::GetGPUGroupSize();
|
||||
}
|
||||
Status Init(const std::string &cache_model_path);
|
||||
bool CheckIsCacheKernel(kernel::Kernel *kernel);
|
||||
Status InitCacheKernel(kernel::Kernel *kernel);
|
||||
bool IsCacheTensor(mindspore::MSTensor tensor);
|
||||
int CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, void *device_addr);
|
||||
Status SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size);
|
||||
|
||||
private:
|
||||
std::map<std::string, std::shared_ptr<EmbeddingCache>> caches_;
|
||||
std::vector<int> hash_indices_;
|
||||
int rank_id_{0};
|
||||
int rank_group_size_{1};
|
||||
|
||||
std::shared_ptr<HostCacheModel> host_cache_model_;
|
||||
};
|
||||
} // namespace cache
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_EMBEDDING_CACHE_MANAGER_H_
|
|
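rank_id_ and rank_group_size_ are captured from the TensorRT distribution helpers at construction time; a sketch of the equivalent standalone query (assuming GetRankID and GetGPUGroupSize from distribution_base.h behave as they are used above):

#include "src/delegate/tensorrt/distribution/distribution_base.h"

// Every EmbeddingCache built by the manager shards its vocabulary by these two values.
struct ShardIdentity {
  int rank_id;
  int rank_group_size;
};

ShardIdentity QueryShardIdentity() {
  return {mindspore::lite::GetRankID(), mindspore::lite::GetGPUGroupSize()};
}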
@ -0,0 +1,143 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/delegate/parameter_cache/gpu/gpu_cache_mem.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include "src/delegate/tensorrt/cuda_impl/hash.cuh"
|
||||
#include "runtime/device/gpu/cuda_driver.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace gpu {
|
||||
#define CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(expression, message) \
|
||||
do { \
|
||||
cudaError_t status = (expression); \
|
||||
if (status != cudaSuccess) { \
|
||||
MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \
|
||||
<< cudaGetErrorString(status); \
|
||||
return false; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define MS_ERROR_IF_NULL_W_RET_VAL(ptr, val) \
|
||||
do { \
|
||||
if ((ptr) == nullptr) { \
|
||||
MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
|
||||
return val; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define MS_ERROR_IF_NULL(ptr) \
|
||||
do { \
|
||||
if ((ptr) == nullptr) { \
|
||||
MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
|
||||
return false; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
bool GPUCacheMem::InitDevice(uint32_t device_id, const void *) {
|
||||
auto ret = cudaSetDevice(static_cast<int>(device_id));
|
||||
if (ret != cudaSuccess) {
|
||||
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
|
||||
return false;
|
||||
}
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
|
||||
"Cuda create stream failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
void *GPUCacheMem::MallocMemory(size_t size) {
|
||||
void *device_ptr = nullptr;
|
||||
auto cuda_ret = cudaMalloc(&device_ptr, size);
|
||||
if (cuda_ret != cudaSuccess) {
|
||||
MS_LOG(ERROR) << "Cuda Malloc failed for size:" << size;
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "cudaMalloc size: " << size;
|
||||
return device_ptr;
|
||||
}
|
||||
|
||||
void GPUCacheMem::FreeMemory(void *device_addr) {
|
||||
auto cuda_ret = cudaFree(device_addr);
|
||||
if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) {
|
||||
MS_LOG(WARNING) << "free cuda failed for " << cudaGetErrorName(cuda_ret);
|
||||
}
|
||||
}
|
||||
|
||||
bool GPUCacheMem::RecordEvent() {
|
||||
event_.reset(new cudaEvent_t());
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda record event failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUCacheMem::SynchronizeEvent() {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(event_, false);
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUCacheMem::SynchronizeStream() {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(stream_, false);
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda sync stream failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUCacheMem::CopyHostMemToDevice(void *dst, const void *src, size_t size) {
|
||||
MS_ERROR_IF_NULL(dst);
|
||||
MS_ERROR_IF_NULL(src);
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(
|
||||
cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda memcpy failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUCacheMem::CopyDeviceMemToHost(void *dst, const void *src, size_t size) {
|
||||
MS_ERROR_IF_NULL(dst);
|
||||
MS_ERROR_IF_NULL(src);
|
||||
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(
|
||||
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)),
|
||||
"Cuda memcpy failed");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUCacheMem::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
|
||||
size_t embedding_size, size_t swap_out_size) {
|
||||
MS_ERROR_IF_NULL(hash_table_addr);
|
||||
MS_ERROR_IF_NULL(swap_out_value_addr);
|
||||
MS_ERROR_IF_NULL(swap_out_index_addr);
|
||||
DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr),
|
||||
reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size,
|
||||
reinterpret_cast<cudaStream_t>(stream_));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUCacheMem::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
|
||||
size_t embedding_size, size_t swap_in_size) {
|
||||
MS_ERROR_IF_NULL(hash_table_addr);
|
||||
MS_ERROR_IF_NULL(swap_in_value_addr);
|
||||
MS_ERROR_IF_NULL(swap_in_index_addr);
|
||||
DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr),
|
||||
reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size,
|
||||
reinterpret_cast<cudaStream_t>(stream_));
|
||||
return true;
|
||||
}
|
||||
} // namespace gpu
|
||||
} // namespace mindspore
|
|
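Note that the copy helpers above enqueue cudaMemcpyAsync on the object's stream, so a caller has to synchronize before touching host memory it copied into. A minimal sketch (hypothetical caller, assuming InitDevice has already been called):

#include "src/delegate/parameter_cache/gpu/gpu_cache_mem.h"

bool ReadBack(mindspore::gpu::GPUCacheMem *mem, void *host_dst, const void *device_src, size_t size) {
  if (!mem->CopyDeviceMemToHost(host_dst, device_src, size)) {  // asynchronous copy
    return false;
  }
  return mem->SynchronizeStream();  // wait before reading host_dst
}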
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_CACHE_MEM_H_
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_CACHE_MEM_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <memory>
|
||||
#include "ps/ps_cache/ps_cache_basic.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace gpu {
|
||||
class GPUCacheMem : public ps::PsCacheBasic {
|
||||
public:
|
||||
GPUCacheMem() = default;
|
||||
~GPUCacheMem() override = default;
|
||||
bool InitDevice(uint32_t device_id, const void *context) override;
|
||||
void *MallocMemory(size_t size) override;
|
||||
void FreeMemory(void *buf) override;
|
||||
bool RecordEvent() override;
|
||||
bool SynchronizeEvent() override;
|
||||
bool SynchronizeStream() override;
|
||||
bool CopyHostMemToDevice(void *dst, const void *src, size_t size) override;
|
||||
bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) override;
|
||||
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size,
|
||||
size_t embedding_size, size_t swap_out_size) override;
|
||||
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size,
|
||||
size_t embedding_size, size_t swap_in_size) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<cudaEvent_t> event_{nullptr};
|
||||
};
|
||||
} // namespace gpu
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_GPU_GPU_PS_CACHE_H_
|
|
@ -0,0 +1,230 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/delegate/parameter_cache/lfu_cache.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cache {
|
||||
LFUCacheAlgorithm::~LFUCacheAlgorithm() {
|
||||
for (auto iter : key_table_) {
|
||||
delete *(iter.second);
|
||||
}
|
||||
key_table_.clear();
|
||||
frequency_table_.clear();
|
||||
}
|
||||
|
||||
CacheNoe *LFUCacheAlgorithm::GetNode(int key) {
|
||||
auto key_table_iter = key_table_.find(key);
|
||||
if (key_table_iter == key_table_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto node_iter = key_table_iter->second;
|
||||
auto node = *node_iter;
|
||||
|
||||
auto node_list_iter = frequency_table_.find(key);
|
||||
if (node_list_iter == frequency_table_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto &node_list = node_list_iter->second;
|
||||
node_list.erase(node_iter);
|
||||
|
||||
if (node_list.empty()) {
|
||||
frequency_table_.erase(node_list_iter);
|
||||
}
|
||||
|
||||
node->frequency += 1;
|
||||
frequency_table_[node->frequency].emplace_front(node);
|
||||
key_table_[key] = frequency_table_[node->frequency].begin();
|
||||
return node;
|
||||
}
|
||||
|
||||
int LFUCacheAlgorithm::Get(int key) {
|
||||
auto node = GetNode(key);
|
||||
if (node != nullptr) {
|
||||
return node->value;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
void LFUCacheAlgorithm::Put(int key, int value) {
|
||||
auto node = GetNode(key);
|
||||
if (node != nullptr) {
|
||||
node->value = value;
|
||||
return;
|
||||
}
|
||||
|
||||
if (cache_size_ == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
CacheNoe *add_node = nullptr;
|
||||
if (key_table_.size() == cache_size_) {
|
||||
add_node = frequency_table_.begin()->second.back();
|
||||
key_table_.erase(add_node->key);
|
||||
frequency_table_.begin()->second.pop_back();
|
||||
if (frequency_table_.begin()->second.size() == 0) {
|
||||
frequency_table_.erase(frequency_table_.begin()->first);
|
||||
}
|
||||
add_node->value = value;
|
||||
add_node->key = key;
|
||||
add_node->frequency = 1;
|
||||
} else {
|
||||
add_node = new CacheNoe(key, 1, value);
|
||||
if (add_node == nullptr) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
frequency_table_[1].emplace_front(add_node);
|
||||
key_table_[key] = frequency_table_[1].begin();
|
||||
}
|
||||
|
||||
void LFUCacheAlgorithm::GetHitNodesAndSwapIndex(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
|
||||
std::unordered_map<int, CacheNoe *> *hit_index_nodes,
|
||||
std::unordered_map<int, std::vector<int>> *need_swap_map) {
|
||||
// split the batch into hit indices and missed indices
|
||||
for (size_t i = 0; i < batch_ids_len; i++) {
|
||||
auto key = batch_ids[i];
|
||||
if (key < min_host_index_ || key >= max_host_index_) {
|
||||
cache_index[i] = -1;
|
||||
// out range
|
||||
continue;
|
||||
}
|
||||
|
||||
auto hit_iter = hit_index_nodes->find(key);
|
||||
if (hit_iter != hit_index_nodes->end()) {
|
||||
auto node = hit_iter->second;
|
||||
node->frequency += 1;
|
||||
cache_index[i] = node->value;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto swap_iter = need_swap_map->find(key);
|
||||
if (swap_iter != need_swap_map->end()) {
|
||||
swap_iter->second.push_back(i);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto node_iter_iter = key_table_.find(key);
|
||||
if (node_iter_iter == key_table_.end()) {
|
||||
(*need_swap_map)[key].push_back(i);
|
||||
continue;
|
||||
}
|
||||
auto node_iter = node_iter_iter->second;
|
||||
auto node = *node_iter;
|
||||
|
||||
auto node_list_iter = frequency_table_.find(node->frequency);
|
||||
if (node_list_iter == frequency_table_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto &node_list = node_list_iter->second;
|
||||
node_list.erase(node_iter);
|
||||
|
||||
if (node_list.empty()) {
|
||||
frequency_table_.erase(node_list_iter);
|
||||
}
|
||||
// hit
|
||||
node->frequency += 1;
|
||||
cache_index[i] = node->value;
|
||||
(*hit_index_nodes)[key] = node;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std::list<CacheNoe *> LFUCacheAlgorithm::GetSwapNodes(const std::unordered_map<int, std::vector<int>> &need_swap_map) {
|
||||
std::list<CacheNoe *> need_swap_nodes;
|
||||
auto swap_size = need_swap_map.size();
|
||||
|
||||
while (swap_size > 0 && !frequency_table_.empty()) {
|
||||
auto node_list_iter = frequency_table_.begin();
|
||||
if (node_list_iter->second.size() > swap_size) {
|
||||
auto iter = node_list_iter->second.begin();
|
||||
std::advance(iter, swap_size);
|
||||
need_swap_nodes.splice(need_swap_nodes.end(), node_list_iter->second, node_list_iter->second.begin(), iter);
|
||||
swap_size = 0;
|
||||
} else {
|
||||
swap_size -= node_list_iter->second.size();
|
||||
need_swap_nodes.splice(need_swap_nodes.end(), node_list_iter->second);
|
||||
frequency_table_.erase(node_list_iter);
|
||||
}
|
||||
}
|
||||
return need_swap_nodes;
|
||||
}
|
||||
|
||||
Status LFUCacheAlgorithm::CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
|
||||
std::vector<int> *need_swap_indies,
|
||||
std::vector<int> *need_swap_indies_cache_index) {
|
||||
if (batch_ids == nullptr) {
|
||||
MS_LOG(ERROR) << "batch_ids is nullptr";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
if (cache_index == nullptr) {
|
||||
MS_LOG(ERROR) << "cache_index is nullptr";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
std::unordered_map<int, std::vector<int>> need_swap_map;
|
||||
std::unordered_map<int, CacheNoe *> hit_index_nodes;
|
||||
GetHitNodesAndSwapIndex(batch_ids, batch_ids_len, cache_index, &hit_index_nodes, &need_swap_map);
|
||||
|
||||
// get need_swap_indies.size() least recently used node
|
||||
std::list<CacheNoe *> need_swap_nodes = GetSwapNodes(need_swap_map);
|
||||
|
||||
// reassign the evicted (old) nodes to the newly missed keys
|
||||
{
|
||||
if (need_swap_map.size() != need_swap_nodes.size()) {
|
||||
MS_LOG(ERROR) << " need_swap_map.size() " << need_swap_map.size() << " != need_swap_nodes.size() "
|
||||
<< need_swap_nodes.size();
|
||||
return kLiteError;
|
||||
}
|
||||
need_swap_indies_cache_index->reserve(need_swap_map.size());
|
||||
auto need_swap_map_iter = need_swap_map.begin();
|
||||
for (auto iter = need_swap_nodes.begin();
|
||||
iter != need_swap_nodes.end() && need_swap_map_iter != need_swap_map.end(); iter++, need_swap_map_iter++) {
|
||||
auto node = *iter;
|
||||
key_table_.erase(node->key);
|
||||
node->key = need_swap_map_iter->first;
|
||||
node->frequency = 1;
|
||||
for (auto index : need_swap_map_iter->second) {
|
||||
cache_index[index] = node->value;
|
||||
}
|
||||
need_swap_indies->push_back(need_swap_map_iter->first);
|
||||
need_swap_indies_cache_index->push_back(node->value);
|
||||
MS_LOG(DEBUG) << "device index " << node->value << ",for host index " << need_swap_map_iter->first;
|
||||
key_table_[(*iter)->key] = iter;
|
||||
}
|
||||
|
||||
auto node_list_iter = frequency_table_.begin();
|
||||
if (node_list_iter->second.size() > 0) {
|
||||
auto iter = node_list_iter->second.begin();
|
||||
if ((*iter)->frequency == 1) {
|
||||
node_list_iter->second.splice(node_list_iter->second.begin(), need_swap_nodes);
|
||||
} else {
|
||||
frequency_table_[1] = need_swap_nodes;
|
||||
}
|
||||
} else {
|
||||
frequency_table_[1] = need_swap_nodes;
|
||||
}
|
||||
}
|
||||
for (auto node_iter : hit_index_nodes) {
|
||||
auto node = node_iter.second;
|
||||
frequency_table_[node->frequency].emplace_front(node);
|
||||
key_table_[node->key] = frequency_table_[node->frequency].begin();
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace mindspore
|
|
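A standalone sketch of the LFU algorithm in isolation, pre-filled the same way EmbeddingCache::SetHostCacheAddr does (device slot i initially caches host index i):

#include <vector>
#include "src/delegate/parameter_cache/lfu_cache.h"

void LfuExample() {
  // 4 device slots caching host indices in [0, 100); Put() values are device slot numbers.
  mindspore::cache::LFUCacheAlgorithm lfu(4, 0, 100);
  for (int slot = 0; slot < 4; ++slot) {
    lfu.Put(slot, slot);
  }
  std::vector<int> batch_ids = {1, 2, 7, 1};       // 7 misses, the others hit
  std::vector<int> cache_index(batch_ids.size());  // device slot chosen for each id
  std::vector<int> swap_ids;                       // host indices that must be loaded
  std::vector<int> swap_slots;                     // device slots they will overwrite
  lfu.CheckCacheHit(batch_ids.data(), batch_ids.size(), cache_index.data(), &swap_ids, &swap_slots);
  // Afterwards cache_index maps every id to a device slot, and (swap_ids, swap_slots)
  // tells the caller which embedding rows to copy into which slots.
}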
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_LRU_CACHE_H_
|
||||
#define MINDSPORE_LITE_LRU_CACHE_H_
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cache {
|
||||
struct CacheNoe {
|
||||
CacheNoe(int _index, int _frequency, int _value) : key(_index), frequency(_frequency), value(_value) {}
|
||||
int key; // host input index
|
||||
int frequency;
|
||||
int value; // cache index
|
||||
};
|
||||
|
||||
class CacheAlgorithm {
|
||||
public:
|
||||
virtual ~CacheAlgorithm() {}
|
||||
virtual int Get(int key) = 0;
|
||||
virtual void Put(int key, int value) = 0;
|
||||
virtual Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
|
||||
std::vector<int> *need_swap_indies, std::vector<int> *need_swap_indies_cache_index) = 0;
|
||||
};
|
||||
|
||||
class LFUCacheAlgorithm : public CacheAlgorithm {
|
||||
public:
|
||||
LFUCacheAlgorithm(size_t cache_size, int min_host_index, int max_host_index)
|
||||
: cache_size_(cache_size), min_host_index_(min_host_index), max_host_index_(max_host_index) {}
|
||||
~LFUCacheAlgorithm() override;
|
||||
|
||||
int Get(int key) override;
|
||||
|
||||
void Put(int key, int value) override;
|
||||
Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
|
||||
std::vector<int> *need_swap_indies, std::vector<int> *need_swap_indies_cache_index) override;
|
||||
|
||||
private:
|
||||
CacheNoe *GetNode(int key);
|
||||
void GetHitNodesAndSwapIndex(const int *batch_ids, const size_t batch_ids_len, int *cache_index,
|
||||
std::unordered_map<int, CacheNoe *> *hit_index_nodes,
|
||||
std::unordered_map<int, std::vector<int>> *need_swap_map);
|
||||
std::list<CacheNoe *> GetSwapNodes(const std::unordered_map<int, std::vector<int>> &need_swap_map);
|
||||
|
||||
std::unordered_map<int, std::list<CacheNoe *>::iterator> key_table_;
|
||||
std::map<int, std::list<CacheNoe *>> frequency_table_;
|
||||
size_t cache_size_;
|
||||
|
||||
int min_host_index_{0};
|
||||
int max_host_index_{1};
|
||||
};
|
||||
} // namespace cache
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_LRU_CACHE_H_
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "src/delegate/parameter_cache/load_host_cache_model.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cache {
|
||||
MSTensor *SchemaTensorToMSTensor(lite::SchemaTensorWrapper *schema_tensor_wrapper,
|
||||
mindspore::schema::Tensor *schema_tensor) {
|
||||
std::vector<int64_t> shape;
|
||||
for (size_t j = 0; j < schema_tensor->dims()->size(); j++) {
|
||||
shape.push_back(schema_tensor->dims()->data()[j]);
|
||||
}
|
||||
std::string tensor_name;
|
||||
if (schema_tensor->name() != nullptr) {
|
||||
tensor_name = schema_tensor->name()->str();
|
||||
}
|
||||
return MSTensor::CreateRefTensor(tensor_name, (DataType)schema_tensor->dataType(), shape,
|
||||
schema_tensor_wrapper->data(), schema_tensor_wrapper->length());
|
||||
}
|
||||
|
||||
Status HostCacheModel::LoadCache(const std::string &model_path) {
|
||||
cache_model_ = lite::LiteImportFromPath(model_path.c_str());
|
||||
if (cache_model_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Import model failed";
|
||||
return kLiteGraphFileError;
|
||||
}
|
||||
|
||||
auto allTensors = cache_model_->all_tensors_;
|
||||
for (auto node : cache_model_->all_nodes_) {
|
||||
// only support embedding cache
|
||||
if (node == nullptr || node->node_type_ != schema::PrimitiveType_Gather) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto input_index = node->input_indices_[0];
|
||||
if (input_index > allTensors.size() - 1) {
|
||||
MS_LOG(ERROR) << "invalid kernel input, input_index " << input_index << ",allTensors.size() "
|
||||
<< allTensors.size();
|
||||
return kLiteOutOfTensorRange;
|
||||
}
|
||||
auto schema_tensor_wrapper = cache_model_->GetSchemaTensor(input_index);
|
||||
if (schema_tensor_wrapper == nullptr) {
|
||||
MS_LOG(ERROR) << "invalid kernel input, input_index " << input_index;
|
||||
return kLiteOutOfTensorRange;
|
||||
}
|
||||
|
||||
auto schema_tensor = allTensors[input_index];
|
||||
if (schema_tensor != nullptr && schema_tensor_wrapper->data() != nullptr) {
|
||||
auto tensor = SchemaTensorToMSTensor(schema_tensor_wrapper, schema_tensor);
|
||||
if (tensor == nullptr) {
|
||||
return kLiteMemoryFailed;
|
||||
}
|
||||
cache_tensor_[tensor->Name()] = *tensor;
|
||||
MS_LOG(INFO) << tensor->Name() << " is cache tensor, and the node is [" << node->name_ << "]";
|
||||
delete tensor;
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
bool HostCacheModel::CheckIsCacheKernel(kernel::Kernel *kernel) {
|
||||
if (GetHostCacheTensor(kernel) == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MSTensor HostCacheModel::GetHostCacheTensor(kernel::Kernel *kernel) {
|
||||
if (kernel != nullptr && kernel->inputs().size() > 0) {
|
||||
auto iter = cache_tensor_.find(kernel->inputs()[0].Name());
|
||||
if (iter != cache_tensor_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
}
|
||||
return MSTensor(nullptr);
|
||||
}
|
||||
} // namespace cache
|
||||
} // namespace mindspore
|
|
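In short, the host cache model is a second exported model file that still carries the full embedding tables; a kernel counts as a cache kernel when its first input name matches a Gather weight tensor found there. A small hypothetical helper built on the API above (path and names are illustrative):

#include <string>
#include "src/delegate/parameter_cache/load_host_cache_model.h"

mindspore::MSTensor FindHostTable(mindspore::cache::HostCacheModel *host_model,
                                  mindspore::kernel::Kernel *kernel,
                                  const std::string &cache_model_path) {
  if (host_model->LoadCache(cache_model_path) != mindspore::kSuccess) {
    return mindspore::MSTensor(nullptr);
  }
  return host_model->GetHostCacheTensor(kernel);  // MSTensor(nullptr) when the kernel is not cached
}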
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_LOAD_HOST_CACHE_DATA_H_
|
||||
#define MINDSPORE_LITE_LOAD_HOST_CACHE_DATA_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "src/delegate/parameter_cache/lfu_cache.h"
|
||||
#include "ps/ps_cache/ps_cache_basic.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/types.h"
|
||||
#include "src/lite_model.h"
|
||||
#include "include/api/kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cache {
|
||||
class HostCacheModel {
|
||||
public:
|
||||
HostCacheModel() {}
|
||||
Status LoadCache(const std::string &model_path);
|
||||
bool CheckIsCacheKernel(kernel::Kernel *kernel);
|
||||
MSTensor GetHostCacheTensor(kernel::Kernel *kernel);
|
||||
|
||||
private:
|
||||
std::map<std::string, MSTensor> cache_tensor_;
|
||||
mindspore::lite::LiteModel *cache_model_{nullptr};
|
||||
char *model_buf_{nullptr};
|
||||
size_t model_size_;
|
||||
};
|
||||
} // namespace cache
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_EMBEDDING_CACHE_H_
@ -37,6 +37,7 @@ else()
endif()
add_dependencies(gpu_distribution_collective fbs_src)


file(GLOB TENSORRT_RUNTIME_SRC LIST_DIRECTORIES false
    ${CMAKE_CURRENT_SOURCE_DIR}/*.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/op/*.cc

@ -44,6 +45,17 @@ file(GLOB TENSORRT_RUNTIME_SRC LIST_DIRECTORIES false
    ${CMAKE_CURRENT_SOURCE_DIR}/../delegate_utils.cc
)

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache)

set(TENSORRT_RUNTIME_SRC
    ${TENSORRT_RUNTIME_SRC}
    ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/embedding_cache_manager.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/load_host_cache_model.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/lfu_cache.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/embedding_cache.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/gpu/gpu_cache_mem.cc
)

link_libraries(${CUDA_LIB_PATH}/libcudnn.so)
link_libraries(${CUDA_LIB_PATH}/libnvrtc.so)
link_libraries(${CUDA_LIB_PATH}/libcublas.so)

@ -73,3 +85,4 @@ set_source_files_properties(${CUDA_KERNEL_SRC} PROPERTIES CUDA_SOURCE_PROPERTY_F
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -fPIC")
SET(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-std=c++14;)
cuda_add_library(cuda_kernel_mid STATIC ${CUDA_KERNEL_SRC})

@ -0,0 +1,64 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/delegate/tensorrt/cuda_impl/hash.cuh"
#include "src/delegate/tensorrt/cuda_impl/cuda_helper.h"

// Gather the rows addressed by swap_out_index from the device hash table into swap_out_value
// (grid-stride loop over the index list, hash_dim elements per row).
template <typename T>
__global__ void HashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
                            const int hash_dim) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
    int hash_index = swap_out_index[i];
    for (int j = 0; j < hash_dim; j++) {
      swap_out_value[i * hash_dim + j] = hash_table[hash_index * hash_dim + j];
    }
  }
  return;
}

// Scatter the rows of swap_in_value into the device hash table at the slots given by swap_in_index.
template <typename T>
__global__ void HashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
                           const int hash_dim) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) {
    int hash_index = swap_in_index[i];
    for (int j = 0; j < hash_dim; j++) {
      hash_table[hash_index * hash_dim + j] = swap_in_value[i * hash_dim + j];
    }
  }
  return;
}

template <typename T>
void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
                   const int hash_dim, cudaStream_t cuda_stream) {
  HashSwapOut<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_out_value, swap_out_index,
                                                                       index_size, hash_dim);
  return;
}

template <typename T>
void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
                  const int hash_dim, cudaStream_t cuda_stream) {
  HashSwapIn<<<GET_BLOCKS(index_size), GET_THREADS, 0, cuda_stream>>>(hash_table, swap_in_value, swap_in_index,
                                                                      index_size, hash_dim);
  return;
}

template void DoHashSwapOut<float>(const float *hash_table, float *swap_out_value, const int *swap_out_index,
                                   const int index_size, const int hash_dim, cudaStream_t cuda_stream);

template void DoHashSwapIn<float>(float *hash_table, const float *swap_in_value, const int *swap_in_index,
                                  const int index_size, const int hash_dim, cudaStream_t cuda_stream);
@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_HASH_H_
#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_HASH_H_

template <typename T>
void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size,
                   const int hash_dim, cudaStream_t cuda_stream);

template <typename T>
void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size,
                  const int hash_dim, cudaStream_t cuda_stream);
#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CUDA_IMPL_HASH_H_
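To see how the two entry points combine, here is a host-side driver sketch. It is not part of the patch: SwapRowsExample and every buffer name are invented for illustration, and error checking on the CUDA calls is omitted.

#include <cuda_runtime.h>
#include "src/delegate/tensorrt/cuda_impl/hash.cuh"

// Evict `out_count` rows into a scratch buffer, then overwrite the freed slots with new rows.
void SwapRowsExample(float *device_hash_table, int hash_dim, const int *device_out_index, int out_count,
                     const int *device_in_index, int in_count, const float *device_new_rows, cudaStream_t stream) {
  float *evicted = nullptr;
  cudaMalloc(&evicted, static_cast<size_t>(out_count) * hash_dim * sizeof(float));
  DoHashSwapOut<float>(device_hash_table, evicted, device_out_index, out_count, hash_dim, stream);
  DoHashSwapIn<float>(device_hash_table, device_new_rows, device_in_index, in_count, hash_dim, stream);
  cudaStreamSynchronize(stream);  // evicted rows are ready to copy back to the host after this point
  cudaFree(evicted);
}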
@ -132,6 +132,18 @@ Status TensorRTDelegate::Init() {
    MS_LOG(ERROR) << "TensorRTRuntime init failed.";
    return mindspore::kLiteError;
  }

  cache_mgr_ = std::make_shared<cache::EmbeddingCacheManager>();
  if (cache_mgr_ == nullptr) {
    MS_LOG(ERROR) << "malloc EmbeddingCacheManager failed.";
    return kLiteMemoryFailed;
  }
  auto cache_ret = cache_mgr_->Init("");
  if (cache_ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "cache_mgr_ init failed.";
    return cache_ret;
  }

  return mindspore::kSuccess;
}

@ -165,6 +177,14 @@ Status TensorRTDelegate::Build(DelegateModel *model) {
    }
    auto tensorrt_op = FindTensorRTOp(kernel, model->GetPrimitive(kernel));
    if (tensorrt_op != nullptr) {
      if (cache_mgr_->CheckIsCacheKernel(kernel)) {
        auto cache_ret = cache_mgr_->InitCacheKernel(kernel);
        if (cache_ret != kSuccess) {
          MS_LOG(ERROR) << "InitCacheKernel failed " << kernel->name();
          return cache_ret;
        }
      }

      // If tensorrt_ops does not equal nullptr, this kernel can be supported by delegate
      if (tensorrt_ops.size() == 0) {
        from = iter;
@ -219,6 +239,8 @@ TensorRTSubGraph *TensorRTDelegate::CreateTensorRTGraph(const std::vector<Tensor
    MS_LOG(ERROR) << "new tensorrt_graph failed.";
    return nullptr;
  }
  tensorrt_graph->SetCacheManager(cache_mgr_);

  // 1. For every op, find pre and next ops
  FindPreNextOps<TensorRTOp>(ops);

@ -22,6 +22,7 @@
#include <memory>
#include "include/api/delegate.h"
#include "src/delegate/tensorrt/tensorrt_subgraph.h"
#include "src/delegate/parameter_cache/embedding_cache_manager.h"
#include "include/api/kernel.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
@ -63,6 +64,7 @@ class TensorRTDelegate : public Delegate {
  bool support_hw_resize_{true};

  bool support_resize_{true};
  std::shared_ptr<cache::EmbeddingCacheManager> cache_mgr_{nullptr};
};
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_DELEGATE_
@ -352,12 +352,18 @@ int TensorRTSubGraph::Prepare() {
  }

  // malloc for cache weight tensor
  for (auto cache_tensor : cache_inputs_) {
  for (auto cache_tensor : cache_const_inputs_) {
    auto data_size = cache_tensor.DataSize();
    auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(cache_tensor, cache_tensor.DataSize());
    runtime_->GetAllocator()->SyncMemInHostAndDevice(cache_tensor, cache_tensor.Name().c_str(), true);
    runtime_->GetAllocator()->MarkMemValid(cache_tensor.Name().c_str(), true);
    int index = this->engine_->getBindingIndex(cache_tensor.Name().c_str());
    tensor_bindings_[index] = device_ptr;
    auto cache_ret = cache_mgr_->SetDeviceCacheAddr(cache_tensor.Name(), device_ptr, data_size);
    if (cache_ret != kSuccess) {
      MS_LOG(ERROR) << "SetDeviceCacheAddr failed, cache tensor: " << cache_tensor.Name();
      return RET_ERROR;
    }
  }

  if (!this->trt_context_->allInputDimensionsSpecified()) {
@ -466,6 +472,22 @@ int TensorRTSubGraph::Execute() {
      MS_LOG(INFO) << "no need memcpy to cuda for input tensor: " << trt_in_tensor_name_[i];
      continue;
    }

    auto iter = model_input_to_cache_tensors_.find(trt_in_tensor_name_[i]);
    if (iter != model_input_to_cache_tensors_.end()) {
      for (auto &cache_tensor : iter->second) {
        ret = cache_mgr_->CacheHandle(cache_tensor.Name(), inputs_[i],
                                      runtime_->GetAllocator()->GetDevicePtr(trt_in_tensor_name_[i]));
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "handle cache failed " << trt_in_tensor_name_[i];
          return RET_ERROR;
        }
        runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], true);
        MS_LOG(DEBUG) << cache_tensor.Name() << " CacheHandle succ " << trt_in_tensor_name_[i];
      }
      continue;
    }

    ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], trt_in_tensor_name_[i], true);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_in_tensor_name_[i];
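The diff only shows this call site; cache::EmbeddingCacheManager::CacheHandle itself is not in the excerpt. As a rough, hypothetical mental model of the kind of work such a handler performs (id-to-slot resolution, with misses triggering a row swap-in), not the actual implementation:

#include <unordered_map>
#include <vector>

// Illustrative only. Maps lookup ids to device cache slots; ids that miss are given a new slot,
// where a real cache would also evict (e.g. via LFU) and call DoHashSwapIn to stage the host row.
std::vector<int> ResolveIdsToSlots(const std::vector<int> &ids, std::unordered_map<int, int> *id_to_slot) {
  std::vector<int> slots;
  slots.reserve(ids.size());
  for (int id : ids) {
    auto it = id_to_slot->find(id);
    if (it == id_to_slot->end()) {
      int slot = static_cast<int>(id_to_slot->size());  // naive allocation, no eviction
      it = id_to_slot->emplace(id, slot).first;
    }
    slots.push_back(it->second);
  }
  return slots;
}

Presumably the resolved slot indices end up in the device-side input buffer in place of the raw ids, which is why the normal host-to-device sync is skipped for these inputs and the buffer is simply marked valid.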
@ -525,9 +547,11 @@ ITensorHelper TensorRTSubGraph::FindTensorRTInputs(TensorRTOp *cur_op, const min
  }
  return ITensorHelper{};
}
bool TensorRTSubGraph::IsCached(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) { return false; }
bool TensorRTSubGraph::IsCached(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) {
  return cache_mgr_ != nullptr && cache_mgr_->IsCacheTensor(in_tensor);
}

void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op) {
void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op, mindspore::MSTensor device_cache_tensor) {
  auto iter = network_cache_tensor_info_.find(cur_op->GetOpName());
  if (iter != network_cache_tensor_info_.end()) {
    return;
@ -542,6 +566,7 @@ void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op) {
    for (auto in_tensor : front_op->inputs()) {
      if (IsSubGraphInputTensor(this->inputs(), in_tensor)) {
        iter->second.network_input_tensor_.push_back(in_tensor);
        model_input_to_cache_tensors_[in_tensor.Name()].push_back(device_cache_tensor);
        MS_LOG(DEBUG) << cur_op->GetOpName() << "'s network input tensor name is " << in_tensor.Name()
                      << ", can cache: " << iter->second.front_op_can_cache_;
      }
@ -556,9 +581,9 @@ void TensorRTSubGraph::FindCacheTensorInfo(TensorRTOp *cur_op) {
bool TensorRTSubGraph::CanOpCache(TensorRTOp *cur_op) { return true; }

int TensorRTSubGraph::HandleCacheTensor(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) {
  FindCacheTensorInfo(cur_op);
  FindCacheTensorInfo(cur_op, in_tensor);
  // cache kernel weight tensor
  cache_inputs_.push_back(in_tensor);
  cache_const_inputs_.push_back(in_tensor);
  MS_LOG(INFO) << "auto add cache constant tensor for: " << in_tensor.Name();
  auto cuda_dtype = ConvertDataType(in_tensor.DataType());
  nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape());
@ -24,6 +24,7 @@
#include "include/api/kernel.h"
#include "src/delegate/tensorrt/tensorrt_runtime.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
#include "src/delegate/parameter_cache/embedding_cache_manager.h"
#include "include/api/context.h"

namespace mindspore::lite {
@ -74,6 +75,8 @@ class TensorRTSubGraph : public kernel::Kernel {

  int Init();

  void SetCacheManager(const std::shared_ptr<cache::EmbeddingCacheManager> &cache_mgr) { cache_mgr_ = cache_mgr; }

 private:
  int BuildEngine();


@ -89,7 +92,7 @@ class TensorRTSubGraph : public kernel::Kernel {

  bool IsCached(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor);

  void FindCacheTensorInfo(TensorRTOp *cur_op);
  void FindCacheTensorInfo(TensorRTOp *cur_op, mindspore::MSTensor device_cache_tensor);

  bool CanOpCache(TensorRTOp *cur_op);


@ -113,7 +116,7 @@ class TensorRTSubGraph : public kernel::Kernel {
  std::vector<std::string> trt_in_tensor_name_;
  std::vector<std::string> trt_out_tensor_name_;

  std::vector<mindspore::MSTensor> cache_inputs_;
  std::vector<mindspore::MSTensor> cache_const_inputs_;
  std::map<std::string, CacheTensorInfo> network_cache_tensor_info_;

  nvinfer1::INetworkDefinition *network_{nullptr};

@ -126,6 +129,10 @@ class TensorRTSubGraph : public kernel::Kernel {
  int input_batchsize_index_{0};
  int output_batchsize_index_{0};
  int input_hw_index_{0};

  std::map<std::string, std::vector<mindspore::MSTensor>> model_input_to_cache_tensors_;

  std::shared_ptr<cache::EmbeddingCacheManager> cache_mgr_{nullptr};
};
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_
@ -483,7 +483,7 @@ SchemaTensorWrapper *LiteModel::GetSchemaTensor(const size_t &tensor_index) cons
  return this->inner_all_tensors_.at(tensor_index);
}

Model *ImportFromPath(const char *model_path) {
LiteModel *LiteImportFromPath(const char *model_path) {
  if (model_path == nullptr) {
    MS_LOG(ERROR) << "The model path is nullptr";
    return nullptr;

@ -508,6 +508,8 @@ Model *ImportFromPath(const char *model_path) {
  return model;
}

Model *ImportFromPath(const char *model_path) { return LiteImportFromPath(model_path); }

Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
  auto *model = new (std::nothrow) LiteModel();
  if (model == nullptr) {
@ -29,6 +29,7 @@
#include "src/common/version_manager.h"
#include "src/schema_tensor_wrapper.h"
#include "nnacl/op_base.h"
#include "src/common/prim_util.h"
#ifdef ENABLE_V0
#include "schema/model_v0_generated.h"
#endif

@ -79,6 +80,7 @@ class LiteModel : public Model {
      MS_CHECK_TRUE_MSG(node != nullptr, false, "new node fail!");
      auto c_node = meta_graph.nodes()->template GetAs<U>(i);
      MS_CHECK_TRUE_MSG(c_node != nullptr, false, "get as cnode fail!");
      node->node_type_ = GetPrimitiveType(c_node->primitive(), schema_version_);
#ifdef ENABLE_MODEL_OBF
      auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
      auto src_prim_type = src_prim->value_type();

@ -286,6 +288,8 @@ class LiteModel : public Model {
};

Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
LiteModel *LiteImportFromPath(const char *model_path);
Model *ImportFromPath(const char *model_path);
}  // namespace lite
}  // namespace mindspore

@ -242,6 +242,7 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
  uint32_t tensor_count = model->all_tensors_.size();
  auto model_input_indices = model->input_indices_;
  auto model_output_indices = model->output_indices_;

  for (uint32_t i = 0; i < tensor_count; ++i) {
    auto *src_tensor = model->all_tensors_[i];
    if (src_tensor == nullptr) {

@ -279,6 +280,7 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
        dst_tensor->set_category(Category::GRAPH_OUTPUT);
      }
    }

    this->tensors_.emplace_back(dst_tensor);
  }
  return RET_OK;

@ -1483,13 +1485,7 @@ int lite::LiteSession::CreateSessionByBuf(const char *model_buf, mindspore::Mode

int lite::LiteSession::CreateSessionByPath(const std::string &model_path, mindspore::ModelType model_type,
                                           session::LiteSession *session) {
  size_t model_size;
  auto model_buf = LoadModelByPath(model_path, model_type, &model_size);
  if (model_buf == nullptr) {
    MS_LOG(ERROR) << "Read model file failed";
    return RET_ERROR;
  }
  auto *model = lite::ImportFromBuffer(model_buf, model_size, true);
  auto *model = lite::ImportFromPath(model_path.c_str());
  if (model == nullptr) {
    MS_LOG(ERROR) << "Import model failed";
    return RET_ERROR;
@ -1207,9 +1207,11 @@ OP_SCHEMA_DEF_END(ScatterNdUpdate)

OP_SCHEMA_DEF(AllGather)
OP_ATTR(group, string)
OP_ATTR(rank_size, int)
OP_SCHEMA_DEF_END(AllGather)

OP_SCHEMA_DEF(ReduceScatter)
OP_ATTR(group, string)
OP_ATTR_ENUM(mode, ReduceMode)
OP_ATTR(rank_size, int)
OP_SCHEMA_DEF_END(ReduceScatter)
@ -46,9 +46,10 @@ OpParameter *PopulateAllGatherParameter(const void *prim) {
    return nullptr;
  }
  memset(param, 0, sizeof(AllGatherParameter));

  memcpy(param->group_, group->c_str(), group->size());
  memcpy(param->group_, value->group()->c_str(), value->group()->size());
  param->rank_size_ = value->rank_size();
  param->op_parameter_.type_ = primitive->value_type();

  return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_AllGather, PopulateAllGatherParameter, SCHEMA_CUR)
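Both populate functions copy the group name into a fixed-size group_ buffer inside the parameter struct, and the copy length comes straight from the schema string. A defensive helper of the kind one might use here (illustrative; CopyGroupName is not part of the patch, and the buffer capacity is whatever the parameter struct defines) clamps the length and keeps the terminator:

#include <algorithm>
#include <cstring>

// Copy a group name into a fixed-size char array, truncating if needed and always NUL-terminating.
static void CopyGroupName(char *dst, size_t dst_cap, const char *src, size_t src_len) {
  size_t n = std::min(src_len, dst_cap - 1);
  memcpy(dst, src, n);
  dst[n] = '\0';
}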
@ -47,10 +47,11 @@ OpParameter *PopulateReduceScatterParameter(const void *prim) {
    return nullptr;
  }
  memset(param, 0, sizeof(ReduceScatterParameter));

  memcpy(param->group_, group->c_str(), group->size());
  memcpy(param->group_, value->group()->c_str(), value->group()->size());
  param->rank_size_ = value->rank_size();
  param->mode_ = value->mode();
  param->op_parameter_.type_ = primitive->value_type();

  return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_ReduceScatter, PopulateReduceScatterParameter, SCHEMA_CUR)
@ -34,11 +34,7 @@ int AllGatherCPUKernel::Prepare() {
int AllGatherCPUKernel::ReSize() { return lite::RET_OK; }

int AllGatherCPUKernel::Run() {
  int rank = get_rank(param_->group_);
  if (param_->rank_ != rank) {
    return lite::RET_ERROR;
  }

  int rank = param_->rank_size_;
  size_t data_size = in_tensors().front()->Size();
  auto out_tensor = out_tensors().front();
  int8_t *out_data = reinterpret_cast<int8_t *>(out_tensor->data());
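The rest of the Run body is not shown in this hunk. For a single-process build, the gather amounts to replicating the rank-local input block once per rank, since the output holds rank_size copies of the input shape. A plausible stand-in for that copy (hypothetical, not the kernel's actual code):

#include <cstdint>
#include <cstring>

// Replicate one rank's block `rank` times into the gathered output buffer
// (out must hold rank * data_size bytes).
void NaiveAllGatherCopy(int8_t *out, const int8_t *in, size_t data_size, int rank) {
  for (int r = 0; r < rank; r++) {
    memcpy(out + static_cast<size_t>(r) * data_size, in, data_size);
  }
}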
@ -57,7 +57,7 @@ int ReduceScatterCPUKernel::DoReduceScatter(void *in_data, void *reduce_data, si
  if (param_->mode_ == schema::ReduceMode_ReduceSum) {
    for (size_t i = 0; i < data_size; i++) out[i] += in[i];
  } else if (param_->mode_ == schema::ReduceMode_ReduceMean) {
    for (size_t i = 0; i < data_size; i++) out[i] += (in[i] / static_cast<float>(param_->rank_));
    for (size_t i = 0; i < data_size; i++) out[i] += (in[i] / static_cast<float>(param_->rank_size_));
  } else if (param_->mode_ == schema::ReduceMode_ReduceMax) {
    for (size_t i = 0; i < data_size; i++) out[i] = in[i] > out[i] ? in[i] : out[i];
  } else if (param_->mode_ == schema::ReduceMode_ReduceMin) {
@ -74,11 +74,7 @@ int ReduceScatterCPUKernel::DoReduceScatter(void *in_data, void *reduce_data, si
}

int ReduceScatterCPUKernel::Run() {
  int rank = get_rank(param_->group_);
  if (param_->rank_ != rank) {
    return lite::RET_ERROR;
  }

  int rank = param_->rank_size_;
  size_t in_data_size = in_tensors().front()->Size();
  size_t in_ele_size = in_tensors().front()->ElementsNum();
  size_t out_data_size = out_tensors().front()->Size();
@ -23,6 +23,7 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_utils_secure.h"
#include "tools/optimizer/common/format_utils.h"
#include "nnacl/op_base.h"
#include "tools/anf_exporter/anf_exporter.h"
@ -259,8 +260,9 @@ int FetchFromDefaultParam(const ParameterPtr &param_node, const converter::FmkTy
  if (tensor_info != nullptr && tensor_info->Size() != 0) {
    if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
      data_info->data_.resize(tensor_info->Size() - offset);
      if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
                          static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
      if (EOK != common::huge_memcpy_s(data_info->data_.data(), data_info->data_.size(),
                                       static_cast<uint8_t *>(tensor_info->data_c()) + offset,
                                       tensor_info->Size() - offset)) {
        MS_LOG(ERROR) << "memcpy_s failed.";
        return RET_ERROR;
      }
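common::huge_memcpy_s (pulled in via utils/ms_utils_secure.h) replaces memcpy_s here because the securec memcpy_s rejects very large copy counts, and these weight blobs can now exceed that limit. Its implementation is not part of this diff; the general shape of such a wrapper is a chunked loop, roughly as follows (illustrative only; the function name, include path, and chunk size are assumptions):

#include <cstddef>
#include <cstdint>
#include "securec.h"  // assumption: provides memcpy_s and EOK, as used elsewhere in the code base

// Split a large copy into chunks small enough for each individual memcpy_s call.
inline int ChunkedMemcpy(uint8_t *dst, size_t dst_size, const uint8_t *src, size_t count) {
  constexpr size_t kChunk = 64UL * 1024 * 1024;  // 64 MiB per call, comfortably under securec's per-call limit
  if (count > dst_size) {
    return -1;
  }
  while (count > 0) {
    size_t n = count < kChunk ? count : kChunk;
    if (memcpy_s(dst, dst_size, src, n) != EOK) {
      return -1;
    }
    dst += n;
    src += n;
    dst_size -= n;
    count -= n;
  }
  return 0;
}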
@ -30,7 +30,7 @@

namespace mindspore::lite {
namespace {
constexpr size_t kModelSizeLimit = 2 * 1024 * 1024;
constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
constexpr size_t kExternalDataHeadSize = 4096;
constexpr size_t kMagicNumberSize = 4;
constexpr size_t kFlatbuffersBuilderInitSize = 1024;
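The cast in the new limit matters for more than style: without it the multiplication is done in int, and 2 * 1024 * 1024 * 1024 overflows (INT_MAX is 2147483647). A quick compile-time check of the intended value:

#include <cstddef>

constexpr size_t kTwoGiB = static_cast<size_t>(2) * 1024 * 1024 * 1024;
static_assert(kTwoGiB == 2147483648ULL, "limit is 2 GiB, one past INT_MAX");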
@ -241,7 +241,7 @@ int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string
  builder.Finish(offset);
  schema::FinishMetaGraphBuffer(builder, offset);
  size_t size = builder.GetSize();
  auto save_together = (size < kModelSizeLimit) || true;
  auto save_together = (size < kModelSizeLimit);
  MetaGraphSerializer meta_graph_serializer;
  if (!meta_graph_serializer.Init(graph, output_path, save_together)) {
    MS_LOG(ERROR) << "Init MetaGraphSerializer failed";
@ -19,6 +19,7 @@
#include <vector>
#include "src/tensorlist.h"
#include "tools/optimizer/common/format_utils.h"
#include "utils/ms_utils_secure.h"
#include "nnacl/op_base.h"

namespace mindspore {
@ -100,12 +101,13 @@ int ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos, std::vect
        return lite::RET_ERROR;
      }
    } else {
      auto tensor_data = reinterpret_cast<char *>(malloc(tensor_size));
      auto tensor_data = malloc(tensor_size);
      if (tensor_data == nullptr) {
        MS_LOG(ERROR) << "tensor_data is nullptr.";
        return lite::RET_ERROR;
      }
      if (memcpy_s(tensor_data, tensor_size, data_info.data_.data(), tensor_size) != EOK) {
      if (common::huge_memcpy_s(static_cast<uint8_t *>(tensor_data), tensor_size, data_info.data_.data(),
                                tensor_size) != EOK) {
        free(tensor_data);
        MS_LOG(ERROR) << "memcpy data error.";
        return lite::RET_ERROR;