forked from mindspore-Ecosystem/mindspore
!46094 bugfix for feature permit and evict
Merge pull request !46094 from zyli2020/dynamic_embedding_compile
commit e71c403bb7
@@ -58,14 +58,27 @@ std::vector<size_t> GPUHashTable<Key, Value, Allocator>::lookup_counter_initiali
 
 template <typename Key, typename Value, typename Allocator>
 GPUHashTable<Key, Value, Allocator>::GPUHashTable(int32_t value_dim, const std::string &initializer,
+                                                  uint64_t permit_threshold, uint64_t evict_threshold,
                                                   const Allocator &alloc)
-  : value_dim_(value_dim), initializer_(initializer), default_value_(0), char_alloc_(alloc) {
+  : value_dim_(value_dim),
+    initializer_(initializer),
+    default_value_(0),
+    char_alloc_(alloc),
+    permit_threshold_(permit_threshold),
+    evict_threshold_(evict_threshold) {
   Initialize(alloc);
 }
 
 template <typename Key, typename Value, typename Allocator>
-GPUHashTable<Key, Value, Allocator>::GPUHashTable(int32_t value_dim, const Value &default_value, const Allocator &alloc)
-  : value_dim_(value_dim), initializer_(""), default_value_(default_value), char_alloc_(alloc) {
+GPUHashTable<Key, Value, Allocator>::GPUHashTable(int32_t value_dim, const Value &default_value,
+                                                  uint64_t permit_threshold, uint64_t evict_threshold,
+                                                  const Allocator &alloc)
+  : value_dim_(value_dim),
+    initializer_(""),
+    default_value_(default_value),
+    char_alloc_(alloc),
+    permit_threshold_(permit_threshold),
+    evict_threshold_(evict_threshold) {
   Initialize(alloc);
 }
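Note on the new parameters: the defaults introduced in the header (`permit_threshold = kMinPermitThreshold`, `evict_threshold = kMaxEvictThreshold`) keep both features switched off, replacing the old always-on members `min_lookup_cnt_before_permit_{1}` and `max_time_interval_to_evict_{SIZE_MAX}`. A minimal standalone sketch of that behaviour, assuming the constant values from this change and a made-up `ToyHashTableConfig` type rather than the real `GPUHashTable`:

```cpp
#include <cstdint>
#include <iostream>

// Assumed values, mirroring the constants introduced by this change.
constexpr uint64_t kMinPermitThreshold = 1;
constexpr uint64_t kMaxEvictThreshold = INT64_MAX;

// Hypothetical config type used only for this sketch (not the real GPUHashTable).
struct ToyHashTableConfig {
  uint64_t permit_threshold = kMinPermitThreshold;  // <= kMinPermitThreshold -> permission disabled
  uint64_t evict_threshold = kMaxEvictThreshold;    // >= kMaxEvictThreshold  -> eviction disabled

  bool permission_enabled() const { return permit_threshold > kMinPermitThreshold; }
  bool eviction_enabled() const { return evict_threshold < kMaxEvictThreshold; }
};

int main() {
  ToyHashTableConfig defaults;  // default thresholds: both features off
  ToyHashTableConfig tuned;
  tuned.permit_threshold = 3;   // element becomes "real" after its 3rd lookup
  tuned.evict_threshold = 600;  // element expires after 600 idle update steps
  std::cout << defaults.permission_enabled() << defaults.eviction_enabled() << '\n';  // prints 00
  std::cout << tuned.permission_enabled() << tuned.eviction_enabled() << '\n';        // prints 11
  return 0;
}
```

Encoding "disabled" in the threshold values themselves avoids carrying a separate pair of boolean flags through the table.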
@@ -203,7 +216,7 @@ bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num,
 
   // 2. Insert default value into map by specific value.
   InsertDefaultValue<<<GET_BLOCKS(key_num), GET_THREADS, 0, cuda_stream>>>(
-    value_dim_, key_num, indices, elements_per_block_, lookup_cnts_ptr_, min_lookup_cnt_before_permit_, default_value,
+    value_dim_, key_num, indices, elements_per_block_, lookup_cnts_ptr_, permit_threshold_, default_value,
     idle_flags_ptr_, blocks_ptr_);
 
   // 3. Get all values by indices in blocks.
@@ -238,7 +251,7 @@ bool GPUHashTable<Key, Value, Allocator>::Insert(const Key *keys, size_t key_num
   // 2. Insert values into map by indices in blocks.
   size_t total_insert_size = value_dim_ * key_num;
   InsertValues<<<GET_BLOCKS(total_insert_size), GET_THREADS, 0, cuda_stream>>>(
-    value_dim_, total_insert_size, indices, value, elements_per_block_, lookup_cnts_ptr_, min_lookup_cnt_before_permit_,
+    value_dim_, total_insert_size, indices, value, elements_per_block_, lookup_cnts_ptr_, permit_threshold_,
     global_timestamp_, update_timestamps_ptr_, statuses_ptr_, idle_flags_ptr_, blocks_ptr_);
 
   is_dirty_ = true;
@@ -502,7 +515,8 @@ bool GPUHashTable<Key, Value, Allocator>::GetKeysAndValues(Key *keys, Value *val
 
 template <typename Key, typename Value, typename Allocator>
 bool GPUHashTable<Key, Value, Allocator>::EvictExpiredElements(cudaStream_t stream) {
-  if (max_time_interval_to_evict_ == SIZE_MAX) {
+  // If evict_threshold_ is greater than or equal to kMaxEvictThreshold, eviction is disable.
+  if (evict_threshold_ >= kMaxEvictThreshold) {
     return true;
   }
@@ -558,8 +572,8 @@ bool GPUHashTable<Key, Value, Allocator>::CountExpiredElements(cudaStream_t stre
 
   // 2. Count the number for expired elements.
   CountExpiredNum<block_size><<<grid_size, block_size, 0, stream>>>(
-    blocks_.size(), min_lookup_cnt_before_permit_, global_timestamp_, max_time_interval_to_evict_, elements_per_block_,
-    idle_flags_ptr_, lookup_cnts_ptr_, update_timestamps_ptr_, device_expired_counter);
+    blocks_.size(), permit_threshold_, global_timestamp_, evict_threshold_, elements_per_block_, idle_flags_ptr_,
+    lookup_cnts_ptr_, update_timestamps_ptr_, device_expired_counter);
 
   CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaMemcpyAsync(&host_expired_counter, device_expired_counter,
                                                    sizeof(CudaAtomicSize), cudaMemcpyDeviceToHost, stream),
@@ -605,9 +619,9 @@ bool GPUHashTable<Key, Value, Allocator>::FindExpiredElements(Key *expired_keys,
 
   // 3. Find all keys and indices for expired elememts in hash table.
   FindExpiredKeysAndIndices<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(
-    size, elements_per_block_, min_lookup_cnt_before_permit_, global_timestamp_, max_time_interval_to_evict_,
-    idle_flags_ptr_, lookup_cnts_ptr_, update_timestamps_ptr_, all_keys, all_indices, device_expired_counter,
-    expired_keys, expired_indices);
+    size, elements_per_block_, permit_threshold_, global_timestamp_, evict_threshold_, idle_flags_ptr_,
+    lookup_cnts_ptr_, update_timestamps_ptr_, all_keys, all_indices, device_expired_counter, expired_keys,
+    expired_indices);
 
   CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
   FreeMemory(all_keys);
@@ -1046,8 +1060,8 @@ template <typename Key, typename Value, typename Allocator>
 bool GPUHashTable<Key, Value, Allocator>::UpdateSize(size_t key_num, const int *indices, cudaStream_t stream,
                                                      bool update_lookup_count) {
   size_t old_size = size_;
-  if (min_lookup_cnt_before_permit_ == 1) {
-    // Elements permission is disable.
+  // If permit_threshold_ is less than or equal to kMinPermitThreshold, permission is disable.
+  if (permit_threshold_ <= kMinPermitThreshold) {
     MS_EXCEPTION_IF_NULL(cuda_dynamic_map_);
     auto &dynamic_map = cuda_dynamic_map_->dynamic_map_;
     size_ = dynamic_map.get_size();
@@ -1058,6 +1072,8 @@ bool GPUHashTable<Key, Value, Allocator>::UpdateSize(size_t key_num, const int *
   }
 
   if (!update_lookup_count) {
+    // Note: if need not to update lookup count and permission is enable, the size of gpu hash table could not change.
+    // For example, all keys for `Insert` should be contained in gpu hash table already.
     return true;
   }
@@ -1078,7 +1094,7 @@ bool GPUHashTable<Key, Value, Allocator>::UpdateSize(size_t key_num, const int *
 
   // Launch kernel to count new permitted elements number.
   CountPermissionNum<block_size><<<grid_size, block_size, 0, stream>>>(
-    elements_per_block_, key_num, indices, lookup_cnts_ptr_, min_lookup_cnt_before_permit_, insert_success_number_);
+    elements_per_block_, key_num, indices, lookup_cnts_ptr_, permit_threshold_, insert_success_number_);
   CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
 
   // Update hash table size.
@@ -1105,18 +1121,18 @@ bool GPUHashTable<Key, Value, Allocator>::InsertDefaultValueByInitializer(size_t
     Value stddev = static_cast<Value>(0.01);
 
     InsertNormalDistRandomValue<<<random_gen_block_count_, random_gen_threads_per_block_, 0, stream>>>(
-      value_dim_, key_num, indices, elements_per_block_, lookup_cnts_ptr_, min_lookup_cnt_before_permit_, mean, stddev,
+      value_dim_, key_num, indices, elements_per_block_, lookup_cnts_ptr_, permit_threshold_, mean, stddev,
       random_gen_state_, idle_flags_ptr_, blocks_ptr_);
   } else if (initializer == kOnesDistribution) {
     // One distribution.
     InsertDefaultValue<<<GET_BLOCKS(key_num), GET_THREADS, 0, stream>>>(
-      value_dim_, key_num, indices, elements_per_block_, lookup_cnts_ptr_, min_lookup_cnt_before_permit_,
-      static_cast<Value>(1.0), idle_flags_ptr_, blocks_ptr_);
+      value_dim_, key_num, indices, elements_per_block_, lookup_cnts_ptr_, permit_threshold_, static_cast<Value>(1.0),
+      idle_flags_ptr_, blocks_ptr_);
   } else if (initializer == kZerosDistribution) {
     // Zero distribution.
     InsertDefaultValue<<<GET_BLOCKS(key_num), GET_THREADS, 0, stream>>>(
-      value_dim_, key_num, indices, elements_per_block_, lookup_cnts_ptr_, min_lookup_cnt_before_permit_,
-      static_cast<Value>(0), idle_flags_ptr_, blocks_ptr_);
+      value_dim_, key_num, indices, elements_per_block_, lookup_cnts_ptr_, permit_threshold_, static_cast<Value>(0),
+      idle_flags_ptr_, blocks_ptr_);
   } else {
     MS_LOG(ERROR) << "Unsupported initializer: " << initializer;
     return false;
@@ -49,8 +49,10 @@ class CudaDynamicMap;
 template <typename Key, typename Value, typename Allocator = GPUAllocator<char>>
 class GPUHashTable : public HashTable<Key, Value> {
  public:
-  GPUHashTable(int32_t value_dim, const std::string &initializer, const Allocator &alloc = Allocator());
-  GPUHashTable(int32_t value_dim, const Value &default_value, const Allocator &alloc = Allocator());
+  GPUHashTable(int32_t value_dim, const std::string &initializer, uint64_t permit_threshold = kMinPermitThreshold,
+               uint64_t evict_threshold = kMaxEvictThreshold, const Allocator &alloc = Allocator());
+  GPUHashTable(int32_t value_dim, const Value &default_value, uint64_t permit_threshold = kMinPermitThreshold,
+               uint64_t evict_threshold = kMaxEvictThreshold, const Allocator &alloc = Allocator());
   ~GPUHashTable();
 
   // The Allocator type used allocate gpu memory for 'Key' type.
@@ -68,6 +70,7 @@ class GPUHashTable : public HashTable<Key, Value> {
   bool Find(const Key *keys, size_t key_num, bool insert_default_value, Value *outputs, void *stream) override;
 
   // Insert elements with specific keys. If key exists, update the value of the key.
+  // If permission is enable, all keys for `Insert` should be contained in gpu hash table already.
   bool Insert(const Key *keys, size_t key_num, const Value *value, void *stream) override;
 
   // Erase elements with specific keys.
@@ -152,12 +155,12 @@ class GPUHashTable : public HashTable<Key, Value> {
 
   // Count the number for expired elememts in hash table.
   // If the elements in the hash table are not used or updated within the time interval indicated by the
-  // `max_time_interval_to_evict_`, these elements will be expired.
+  // `evict_threshold_`, these elements will be expired.
   bool CountExpiredElements(cudaStream_t stream, size_t *expired_num);
 
   // Find all keys and indices for expired elememts in hash table.
   // If the elements in the hash table are not used or updated within the time interval indicated by the
-  // `max_time_interval_to_evict_`, these elements will be expired.
+  // `evict_threshold_`, these elements will be expired.
   bool FindExpiredElements(Key *expired_keys, int *expired_indices, cudaStream_t stream);
 
   // Erase expired elememts in hash table.
@@ -261,12 +264,14 @@ class GPUHashTable : public HashTable<Key, Value> {
 
   // Permission threshold: When an element is accessed more than this threshold, it will be actually inserted into
   // the hash table.
-  size_t min_lookup_cnt_before_permit_{1};
+  // If permit_threshold_ is less than or equal to kMinPermitThreshold, feature permission is disable.
+  uint64_t permit_threshold_;
 
   // Element eviction time interval threshold: evict elements that are not frequently accessed automatically,
   // if the elements in the hash table are not used or updated within the time interval indicated by the threshold,
   // these elements will be removed from the hash table.
-  size_t max_time_interval_to_evict_{SIZE_MAX};
+  // If evict_threshold_ is greater than kMaxEvictThreshold, feature eviction is disable.
+  uint64_t evict_threshold_;
 
   // Global timestamp, which is currently recorded when the hash table is updated.
   size_t global_timestamp_{0};
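For readers of the member comments above, a small host-side sketch of the expiry rule as it is applied by `CountExpiredNum` and `FindExpiredKeysAndIndices` in this change (illustration only; `IsExpired` is not a function in this codebase):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// A slot is expired when it is occupied, has been permitted (its lookup count
// reached permit_threshold), and has not been updated for more than
// evict_threshold global update steps.
bool IsExpired(bool idle, size_t lookup_cnt, size_t slot_timestamp, size_t global_timestamp,
               uint64_t permit_threshold, uint64_t evict_threshold) {
  if (idle || lookup_cnt < permit_threshold) {
    return false;  // idle slots and not-yet-permitted slots are ignored
  }
  return global_timestamp > slot_timestamp && (global_timestamp - slot_timestamp) > evict_threshold;
}

int main() {
  // permit_threshold = 3, evict_threshold = 10 update steps.
  assert(!IsExpired(true, 5, 0, 100, 3, 10));    // idle slot: skipped
  assert(!IsExpired(false, 2, 0, 100, 3, 10));   // never permitted: skipped
  assert(IsExpired(false, 5, 0, 100, 3, 10));    // stale for 100 > 10 steps: expired
  assert(!IsExpired(false, 5, 95, 100, 3, 10));  // updated 5 steps ago: kept
  return 0;
}
```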
@@ -60,6 +60,9 @@ using CudaAtomicInt = cuda::atomic<int32_t, cuda::thread_scope_device>;
 constexpr static int kEmptyKey = -1;
 constexpr static int kEmptyValue = -1;
 constexpr static int kErasedKey = -2;
+
+constexpr static uint64_t kMinPermitThreshold = 1;
+constexpr static uint64_t kMaxEvictThreshold = INT64_MAX;
 }  // namespace gpu
 }  // namespace device
 }  // namespace mindspore
@@ -121,7 +121,7 @@ __global__ void InitNormalDisRandomGen(uint32_t seed, curandStatePhilox4_32_10_t
 template <typename Value>
 __global__ void InsertNormalDistRandomValue(size_t value_dim, size_t total_insert_elements_num, const int *indices,
                                             size_t elements_per_block, size_t *const *lookup_cnts_ptr,
-                                            size_t min_lookup_cnt, const Value mean, const Value stddev,
+                                            size_t permit_threshold, const Value mean, const Value stddev,
                                             curandStatePhilox4_32_10_t *state, bool **idle_flags_ptr,
                                             Value *const *blocks_ptr) {
   size_t total_thread_num = blockDim.x * gridDim.x;
@@ -129,8 +129,8 @@ __global__ void InsertNormalDistRandomValue(size_t value_dim, size_t total_inser
     const size_t block_idx = indices[pos] / elements_per_block;
     const size_t offset_in_block = indices[pos] % elements_per_block;
 
-    bool enable_permission = min_lookup_cnt > 1;
-    bool meet_permit_cond = lookup_cnts_ptr[block_idx][offset_in_block] == min_lookup_cnt;
+    bool enable_permission = permit_threshold > kMinPermitThreshold;
+    bool meet_permit_cond = lookup_cnts_ptr[block_idx][offset_in_block] == permit_threshold;
     // Need not to initialize if the current slot is not idle and does not meet the permission condition.
     if (!idle_flags_ptr[block_idx][offset_in_block] && (!(enable_permission && meet_permit_cond))) {
       continue;
@@ -175,15 +175,15 @@ __global__ void InsertNormalDistRandomValue(size_t value_dim, size_t total_inser
 // Insert default value into map by specific value.
 template <typename Value>
 __global__ void InsertDefaultValue(size_t value_dim, size_t total_insert_elements_num, const int *indices,
-                                   size_t elements_per_block, size_t *const *lookup_cnts_ptr, size_t min_lookup_cnt,
+                                   size_t elements_per_block, size_t *const *lookup_cnts_ptr, size_t permit_threshold,
                                    const Value default_value, bool **idle_flags_ptr, Value *const *blocks_ptr) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total_insert_elements_num;
        pos += blockDim.x * gridDim.x) {
     const size_t block_idx = indices[pos] / elements_per_block;
     const size_t offset_in_block = indices[pos] % elements_per_block;
 
-    bool enable_permission = min_lookup_cnt > 1;
-    bool meet_permit_cond = lookup_cnts_ptr[block_idx][offset_in_block] == min_lookup_cnt;
+    bool enable_permission = permit_threshold > kMinPermitThreshold;
+    bool meet_permit_cond = lookup_cnts_ptr[block_idx][offset_in_block] == permit_threshold;
     // Need not to initialize if the current slot is not idle and does not meet the permission condition.
     if (!idle_flags_ptr[block_idx][offset_in_block] && (!(enable_permission && meet_permit_cond))) {
       continue;
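The double-negated guard in `InsertNormalDistRandomValue` and `InsertDefaultValue` is easier to read in its positive form; a tiny sketch (the helper name below is mine, not kernel code) showing that a slot is initialized exactly when it is idle, or when permission is enabled and the slot has just reached the permit threshold:

```cpp
#include <cassert>

// Positive form of the kernel guard
//   if (!idle && !(enable_permission && meet_permit_cond)) continue;
// i.e. initialize the slot when it is idle, or when permission is enabled and
// the lookup count has just reached permit_threshold.
bool ShouldInitializeSlot(bool idle, bool enable_permission, bool meet_permit_cond) {
  return idle || (enable_permission && meet_permit_cond);
}

int main() {
  assert(ShouldInitializeSlot(true, false, false));    // idle slot: always initialized
  assert(ShouldInitializeSlot(false, true, true));     // just crossed the permit threshold
  assert(!ShouldInitializeSlot(false, true, false));   // occupied, threshold not reached yet
  assert(!ShouldInitializeSlot(false, false, false));  // occupied, permission disabled
  return 0;
}
```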
@@ -231,7 +231,7 @@ __global__ void GetValues(size_t value_dim, size_t total_size, const int *indice
 template <typename Value>
 __global__ void InsertValues(size_t value_dim, size_t total_insert_size, const int *indices, const Value *insert_values,
                              const size_t elements_per_block, const size_t *const *lookup_cnts_ptr,
-                             size_t min_lookup_cnt, size_t global_timestamp, size_t *const *update_timestamps_ptr,
+                             size_t permit_threshold, size_t global_timestamp, size_t *const *update_timestamps_ptr,
                              Status *const *statuses_ptr, bool *const *idle_flags_ptr, Value *const *blocks_ptr) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total_insert_size; pos += blockDim.x * gridDim.x) {
     const size_t index = pos / value_dim;
@@ -243,7 +243,7 @@ __global__ void InsertValues(size_t value_dim, size_t total_insert_size, const i
     const size_t block_idx = indices[index] / elements_per_block;
     const size_t offset_in_block = indices[index] % elements_per_block;
 
-    if (min_lookup_cnt > 1 && lookup_cnts_ptr[block_idx][offset_in_block] < min_lookup_cnt) {
+    if (permit_threshold > kMinPermitThreshold && lookup_cnts_ptr[block_idx][offset_in_block] < permit_threshold) {
       continue;
     }
@@ -269,7 +269,7 @@ __global__ void InsertValues(size_t value_dim, size_t total_insert_size, const i
 
 template <uint32_t block_size>
 __global__ void CountPermissionNum(size_t elements_per_block, size_t key_num, const int *indices,
-                                   size_t *const *lookup_cnts_ptr, size_t min_lookup_cnt,
+                                   size_t *const *lookup_cnts_ptr, size_t permit_threshold,
                                    CudaAtomicSize *insert_counter) {
   typedef cub::BlockReduce<size_t, block_size> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -280,7 +280,7 @@ __global__ void CountPermissionNum(size_t elements_per_block, size_t key_num, co
     const size_t offset_in_block = indices[pos] % elements_per_block;
     lookup_cnts_ptr[block_idx][offset_in_block] += 1;
 
-    if (lookup_cnts_ptr[block_idx][offset_in_block] == min_lookup_cnt) {
+    if (lookup_cnts_ptr[block_idx][offset_in_block] == permit_threshold) {
       // Insert a new element.
      local_counter++;
     }
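For context on `CountPermissionNum`, here is a self-contained, simplified CUDA sketch of the same counting pattern: each thread bumps per-slot lookup counters, flags slots whose counter just reached the threshold, reduces the per-thread flags with `cub::BlockReduce`, and thread 0 of each block adds the block total to a global counter. It is my own simplified version, not the MindSpore kernel (flat arrays instead of the table's per-block storage, a plain `unsigned long long` counter instead of `CudaAtomicSize`), and assumes only the cub library that ships with the CUDA Toolkit:

```cuda
#include <cstdio>
#include <cub/cub.cuh>

constexpr int kBlockSize = 128;

__global__ void CountNewlyPermitted(size_t key_num, size_t *lookup_cnts, size_t permit_threshold,
                                    unsigned long long *newly_permitted) {
  typedef cub::BlockReduce<size_t, kBlockSize> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  size_t local = 0;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < key_num; pos += blockDim.x * gridDim.x) {
    lookup_cnts[pos] += 1;
    if (lookup_cnts[pos] == permit_threshold) {
      local += 1;  // This slot crossed the permission threshold in this call.
    }
  }
  // All threads of the block participate in the reduction; the sum is valid in thread 0.
  size_t block_total = BlockReduce(temp_storage).Sum(local);
  if (threadIdx.x == 0) {
    atomicAdd(newly_permitted, static_cast<unsigned long long>(block_total));
  }
}

int main() {
  const size_t key_num = 1000;
  size_t *lookup_cnts = nullptr;
  unsigned long long *counter = nullptr;
  cudaMallocManaged(&lookup_cnts, key_num * sizeof(size_t));
  cudaMallocManaged(&counter, sizeof(unsigned long long));
  for (size_t i = 0; i < key_num; ++i) lookup_cnts[i] = i % 3;  // counts 0, 1, 2 repeating
  *counter = 0;

  // With permit_threshold = 3, only slots whose count goes 2 -> 3 are newly permitted (expect 333).
  CountNewlyPermitted<<<4, kBlockSize>>>(key_num, lookup_cnts, 3, counter);
  cudaDeviceSynchronize();
  printf("newly permitted: %llu\n", *counter);
  return 0;
}
```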
@@ -293,8 +293,8 @@ __global__ void CountPermissionNum(size_t elements_per_block, size_t key_num, co
 }
 
 template <uint32_t block_size>
-__global__ void CountExpiredNum(size_t blocks_num, size_t min_lookup_cnt, size_t global_timestamp,
-                                size_t max_time_interval, size_t elements_per_block, bool *const *idle_flags_ptr,
+__global__ void CountExpiredNum(size_t blocks_num, size_t permit_threshold, size_t global_timestamp,
+                                uint64_t evict_threshold, size_t elements_per_block, bool *const *idle_flags_ptr,
                                 size_t *const *lookup_cnts_ptr, size_t *const *update_timestamps_ptr,
                                 CudaAtomicSize *expired_counter) {
   typedef cub::BlockReduce<size_t, block_size> BlockReduce;
@@ -304,13 +304,13 @@ __global__ void CountExpiredNum(size_t blocks_num, size_t min_lookup_cnt, size_t
   for (size_t i = 0; i < blocks_num; i++) {
     for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < elements_per_block; pos += blockDim.x * gridDim.x) {
       // Ignore elements at idle slot or ones that have not been really inserted into hash table yet.
-      if (idle_flags_ptr[i][pos] || lookup_cnts_ptr[i][pos] < min_lookup_cnt) {
+      if (idle_flags_ptr[i][pos] || lookup_cnts_ptr[i][pos] < permit_threshold) {
         continue;
       }
 
       // Counter expired element.
       size_t timestamp = update_timestamps_ptr[i][pos];
-      if (global_timestamp > timestamp && (global_timestamp - timestamp) > max_time_interval) {
+      if (global_timestamp > timestamp && (global_timestamp - timestamp) > evict_threshold) {
         local_expired_counter++;
       }
     }
@@ -323,8 +323,8 @@ __global__ void CountExpiredNum(size_t blocks_num, size_t min_lookup_cnt, size_t
 }
 
 template <typename Key>
-__global__ void FindExpiredKeysAndIndices(size_t key_num, size_t elements_per_block, size_t min_lookup_cnt,
-                                          size_t global_timestamp, size_t max_time_interval,
+__global__ void FindExpiredKeysAndIndices(size_t key_num, size_t elements_per_block, size_t permit_threshold,
+                                          size_t global_timestamp, uint64_t evict_threshold,
                                           bool *const *idle_flags_ptr, size_t *const *lookup_cnts_ptr,
                                           size_t *const *update_timestamps_ptr, const Key *all_keys,
                                           const int *all_indices, CudaAtomicSize *expired_counter, Key *expired_keys,
@@ -337,13 +337,13 @@ __global__ void FindExpiredKeysAndIndices(size_t key_num, size_t elements_per_bl
     const size_t offset_in_block = index % elements_per_block;
 
     // Ignore elements at idle slot or ones that have not been really inserted into hash table yet.
-    if (idle_flags_ptr[block_idx][offset_in_block] || lookup_cnts_ptr[block_idx][offset_in_block] < min_lookup_cnt) {
+    if (idle_flags_ptr[block_idx][offset_in_block] || lookup_cnts_ptr[block_idx][offset_in_block] < permit_threshold) {
       continue;
     }
 
     // Record expired keys and indices.
     size_t timestamp = update_timestamps_ptr[block_idx][offset_in_block];
-    if (global_timestamp > timestamp && (global_timestamp - timestamp) > max_time_interval) {
+    if (global_timestamp > timestamp && (global_timestamp - timestamp) > evict_threshold) {
       size_t expired_index = expired_counter->fetch_add(1, cuda::std::memory_order_relaxed);
       expired_keys[expired_index] = key;
       expired_indices[expired_index] = index;
@@ -36,8 +36,22 @@ void SetHashTable(const UserDataPtr &user_data) {
   MS_EXCEPTION_IF_NULL(user_data);
   auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
   auto default_value = user_data->get<Value>(kHashTableDefaultValue);
+  auto permit_filter_value = user_data->get<Value>(kHashTablePermitFilter);
+  auto evict_filter_value = user_data->get<Value>(kHashTableEvictFilter);
   MS_EXCEPTION_IF_NULL(shape_vector);
   MS_EXCEPTION_IF_NULL(default_value);
+  MS_EXCEPTION_IF_NULL(permit_filter_value);
+  MS_EXCEPTION_IF_NULL(evict_filter_value);
+  if (!permit_filter_value->isa<Int64Imm>()) {
+    MS_LOG(EXCEPTION) << "Invalid type for permit filter value: "
+                      << TypeIdLabel(permit_filter_value->type()->type_id());
+  }
+  if (!evict_filter_value->isa<Int64Imm>()) {
+    MS_LOG(EXCEPTION) << "Invalid type for evict filter value: " << TypeIdLabel(evict_filter_value->type()->type_id());
+  }
+  auto permit_threshold = LongToUlong(GetValue<int64_t>(permit_filter_value));
+  auto evict_threshold = LongToUlong(GetValue<int64_t>(evict_filter_value));
+
   int32_t value_size = 1;
   for (size_t i = 0; i < (*shape_vector).size(); ++i) {
     value_size *= (*shape_vector)[i];
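The permit and evict filter values arrive from the front end as `Int64Imm` scalars and are converted with `LongToUlong` before reaching the table. A standalone sketch of that conversion with an explicit range check (the `ToUnsignedThreshold` helper below is illustrative, not the MindSpore utility):

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper: convert an int64 filter value into a uint64 threshold,
// rejecting negatives instead of letting them wrap around.
uint64_t ToUnsignedThreshold(int64_t filter_value) {
  if (filter_value < 0) {
    throw std::invalid_argument("filter value must be non-negative");
  }
  return static_cast<uint64_t>(filter_value);
}

int main() {
  assert(ToUnsignedThreshold(3) == 3U);
  assert(ToUnsignedThreshold(0) == 0U);
  bool threw = false;
  try {
    (void)ToUnsignedThreshold(-1);
  } catch (const std::invalid_argument &) {
    threw = true;
  }
  assert(threw);
  return 0;
}
```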
@@ -47,11 +61,12 @@ void SetHashTable(const UserDataPtr &user_data) {
   }
   if (default_value->isa<StringImm>()) {
     user_data->set<GPUHashTable<KeyType, ValueType>>(
-      kUserDataData,
-      std::make_shared<GPUHashTable<KeyType, ValueType>>(value_size, GetValue<std::string>(default_value)));
+      kUserDataData, std::make_shared<GPUHashTable<KeyType, ValueType>>(
+                       value_size, GetValue<std::string>(default_value), permit_threshold, evict_threshold));
   } else if (default_value->isa<FloatImm>()) {
     user_data->set<GPUHashTable<KeyType, ValueType>>(
-      kUserDataData, std::make_shared<GPUHashTable<KeyType, float>>(value_size, GetValue<float>(default_value)));
+      kUserDataData, std::make_shared<GPUHashTable<KeyType, float>>(value_size, GetValue<float>(default_value),
+                                                                    permit_threshold, evict_threshold));
   } else {
     MS_LOG(EXCEPTION) << "Invalid default value:" << default_value;
   }
@@ -69,6 +69,8 @@ void DeviceAddressUtils::CreateDeviceAddressByMapTensorNode(const DeviceContext
   user_data->set(kHashTableValueType, std::make_shared<TypeId>(map_tensor_type->value_dtype()->type_id()));
   user_data->set(kHashTableShapeVector, std::make_shared<ShapeVector>(shape_vector));
   user_data->set(kHashTableDefaultValue, abstract->default_value());
+  user_data->set(kHashTablePermitFilter, abstract->permit_filter_value());
+  user_data->set(kHashTableEvictFilter, abstract->evict_filter_value());
   // Create device for map tensor node and the ptr size is 1 byte.
   auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
     nullptr, 1, kOpFormat_DEFAULT, TypeId::kNumberTypeInt8, ShapeVector(), user_data);
@@ -82,10 +82,13 @@ class UserData {
 // User data key name.
 constexpr auto kUserDataData = "user_data_data";
 constexpr auto kUserDataType = "user_data_type";
 // User data key for hash table.
 constexpr auto kHashTableKeyType = "hash_table_key_type";
 constexpr auto kHashTableValueType = "hash_table_value_type";
 constexpr auto kHashTableShapeVector = "hash_table_shape_vector";
 constexpr auto kHashTableDefaultValue = "hash_table_default_value";
+constexpr auto kHashTablePermitFilter = "hash_table_permit_filter";
+constexpr auto kHashTableEvictFilter = "hash_table_evict_filter";
 
 enum class UserDataType { kUserDataTypeUnknown = 0, kUserTypeHashTable };
 }  // namespace mindspore