!44275 Add GPU hash table impl

Merge pull request !44275 from zyli2020/dynamic_embedding_dev
This commit is contained in:
i-robot 2022-10-26 09:11:59 +00:00 committed by Gitee
commit 7193e1016d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 546 additions and 21 deletions

View File

@ -75,6 +75,10 @@ if(ENABLE_GPU)
${CUDA_PATH}/lib64/libcufft.so
${CUDA_PATH}/lib64/libcusparse.so)
endif()
if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND ${CUDA_VERSION} VERSION_GREATER "11.0")
target_link_libraries(mindspore_gpu PRIVATE gpu_hash_table)
endif()
endif()
if(ENABLE_DEBUGGER)

View File

@ -36,12 +36,19 @@ if(ENABLE_GPU)
target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
endif()
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" "*.cu")
file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cu")
set_property(SOURCE ${CUDA_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND ${CUDA_VERSION} VERSION_GREATER "11.0")
string(REPLACE "-arch=sm_53;" "" CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}")
list(APPEND CUDA_NVCC_FLAGS -std=c++17)
cuda_add_library(gpu_hash_table STATIC ${CUDA_SRC_LIST})
endif()
set(GPU_COLLECTIVE_SRCS "distribution/collective_wrapper.cc"
"distribution/mpi_wrapper.cc"
"distribution/nccl_wrapper.cc")
list(REMOVE_ITEM CUDA_SRC_LIST "mpi/mpi_initializer.cc" ${GPU_COLLECTIVE_SRCS})
if(ENABLE_MPI)
include(ExternalProject)

View File

@ -0,0 +1,446 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if CUDA_VERSION > 11000
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
#include <cuco/dynamic_map.cuh>
#include <random>
#include <algorithm>
#include "plugin/device/gpu/hal/device/gpu_hash_table_kernel.cuh"
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
#include "plugin/device/gpu/hal/device/gpu_device_manager.h"
namespace mindspore {
namespace device {
namespace gpu {
// Evaluates a CUDA runtime call exactly once and logs an error (without returning) when the result is not
// cudaSuccess. The body is wrapped in do { } while (0) so that the macro plus its trailing semicolon forms a single
// statement and stays valid inside an unbraced if/else.
#define CHECK_CUDA_RET(expression, message)                                                \
  do {                                                                                     \
    cudaError_t cuda_ret = (expression);                                                   \
    if (cuda_ret != cudaSuccess) {                                                         \
      MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << cuda_ret << " " \
                    << cudaGetErrorString(cuda_ret);                                       \
    }                                                                                      \
  } while (0)
// Evaluates a CUDA runtime call exactly once; when the result is not cudaSuccess it logs an error and makes the
// enclosing function return false, so it may only be used inside functions (or lambdas) returning bool.
// The body is wrapped in do { } while (0) so that the macro plus its trailing semicolon forms a single statement and
// stays valid inside an unbraced if/else.
#define CHECK_CUDA_RET_WITH_RETURN_FALSE(expression, message)                              \
  do {                                                                                     \
    cudaError_t cuda_ret = (expression);                                                   \
    if (cuda_ret != cudaSuccess) {                                                         \
      MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << cuda_ret << " " \
                    << cudaGetErrorString(cuda_ret);                                       \
      return false;                                                                        \
    }                                                                                      \
  } while (0)
// Compares two values with operator!=; on mismatch it logs `message` as an error and makes the enclosing function
// return false, so it may only be used inside functions returning bool.
// The body is wrapped in do { } while (0) so that the macro plus its trailing semicolon forms a single statement and
// stays valid inside an unbraced if/else.
#define ASSERT_EQUAL(lhs, rhs, message) \
  do {                                  \
    if ((lhs) != (rhs)) {               \
      MS_LOG(ERROR) << message;         \
      return false;                     \
    }                                   \
  } while (0)
// Shorthand for the cuco::dynamic_map specialization used throughout this file: it is instantiated with
// cuda::thread_scope_device and the caller-supplied allocator type.
template <typename Key, typename Value, typename Allocator>
using CucoDynamicMap = cuco::dynamic_map<Key, Value, cuda::thread_scope_device, Allocator>;
// CudaDynamicMap wraps a cuco::dynamic_map instance. gpu_hash_table.h is included by ordinary C++ sources that are
// compiled with g++, so the cuco::dynamic_map type must not be named in that header. The header therefore only
// forward-declares CudaDynamicMap, and the complete type is defined here.
template <typename Key, typename Value, typename Allocator>
struct CudaDynamicMap {
  // Builds the wrapped map with kInitialCapacity slots, the given empty/erased sentinels, the allocator and the
  // stream handed over by the caller.
  CudaDynamicMap(const Key &empty_key, const Value &empty_value, const Key &erased_key, const Allocator &alloc,
                 cudaStream_t stream = 0)
      : dynamic_map_(kInitialCapacity, cuco::sentinel::empty_key<Key>{empty_key},
                     cuco::sentinel::empty_value<Value>{empty_value},
                     cuco::sentinel::erased_key<Key>{erased_key}, alloc, stream) {}

  ~CudaDynamicMap() = default;

  // The wrapped map; accessed directly by GPUHashTable.
  CucoDynamicMap<Key, Value, Allocator> dynamic_map_;
};
// Creates a hash table whose values for missing keys are produced by the named `initializer`.
// The key -> index map is built on the device manager's default stream, using -1 as the empty key/value sentinel and
// -2 as the erased key sentinel. CUDA failures in here are logged only, since a constructor cannot return a status.
template <typename Key, typename Value, typename Allocator>
GPUHashTable<Key, Value, Allocator>::GPUHashTable(int32_t value_dim, const std::string &initializer,
                                                  const Allocator &alloc)
    : value_dim_(value_dim), initializer_(initializer), default_value_(0), index_alloc_(alloc), char_alloc_(alloc) {
  auto default_stream = reinterpret_cast<cudaStream_t>(GPUDeviceManager::GetInstance().default_stream());
  const Key empty_key_sentinel = static_cast<Key>(-1);
  const Key erased_key_sentinel = static_cast<Key>(-2);
  cuda_dynamic_map_ = std::make_unique<CudaDynamicMap<Key, int32_t, Allocator>>(empty_key_sentinel, -1,
                                                                                erased_key_sentinel, alloc,
                                                                                default_stream);
  // Wait until the map construction work queued on the stream has finished.
  CHECK_CUDA_RET(cudaStreamSynchronize(default_stream), "cudaStreamSynchronize default cuda stream");

  // Managed counter that GetIndicesByKeys() resets on the host and reads back after each lookup kernel.
  CHECK_CUDA_RET(cudaMallocManaged(&insert_success_number_, sizeof(*insert_success_number_)), "cudaMallocManaged");
}
// Creates a hash table whose values for missing keys are filled with the scalar `default_value`; the initializer
// name is left empty so that Find() selects the scalar path.
// The key -> index map is built on the device manager's default stream, using -1 as the empty key/value sentinel and
// -2 as the erased key sentinel. CUDA failures in here are logged only, since a constructor cannot return a status.
template <typename Key, typename Value, typename Allocator>
GPUHashTable<Key, Value, Allocator>::GPUHashTable(int32_t value_dim, const Value &default_value, const Allocator &alloc)
    : value_dim_(value_dim), initializer_(""), default_value_(default_value), index_alloc_(alloc), char_alloc_(alloc) {
  auto default_stream = reinterpret_cast<cudaStream_t>(GPUDeviceManager::GetInstance().default_stream());
  const Key empty_key_sentinel = static_cast<Key>(-1);
  const Key erased_key_sentinel = static_cast<Key>(-2);
  cuda_dynamic_map_ = std::make_unique<CudaDynamicMap<Key, int32_t, Allocator>>(empty_key_sentinel, -1,
                                                                                erased_key_sentinel, alloc,
                                                                                default_stream);
  // Wait until the map construction work queued on the stream has finished.
  CHECK_CUDA_RET(cudaStreamSynchronize(default_stream), "cudaStreamSynchronize default cuda stream");

  // Managed counter that GetIndicesByKeys() resets on the host and reads back after each lookup kernel.
  CHECK_CUDA_RET(cudaMallocManaged(&insert_success_number_, sizeof(*insert_success_number_)), "cudaMallocManaged");
}
// Releases the device resources owned directly by the table: the managed insert counter and, when the normal
// distribution generator was initialized, its curand state buffer.
template <typename Key, typename Value, typename Allocator>
GPUHashTable<Key, Value, Allocator>::~GPUHashTable() {
  CHECK_CUDA_RET(cudaFree(insert_success_number_), "cudaFree");
  insert_success_number_ = nullptr;

  // The curand states are allocated lazily by InitNormalDistRandomGenerator() and no other code in this file gives
  // them back, so return them to the allocator here. The byte count mirrors the one used for the allocation.
  if (random_gen_state_ != nullptr) {
    const size_t random_state_bytes =
      IntToSize(random_gen_threads_per_block_ * random_gen_block_count_) * sizeof(curandStatePhilox4_32_10_t);
    std::allocator_traits<CharAllocatorType>::deallocate(char_alloc_, reinterpret_cast<char *>(random_gen_state_),
                                                         random_state_bytes);
    random_gen_state_ = nullptr;
  }
}
// Public lookup entry point. Forwards to the scalar-default overload when no initializer name was configured,
// otherwise to the initializer-based overload; the result of the chosen overload is returned unchanged.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *keys, size_t key_num, Value *outputs, void *stream) {
  const bool use_scalar_default = initializer_.empty();
  return use_scalar_default ? Find(keys, key_num, default_value_, outputs, stream)
                            : Find(keys, key_num, initializer_, outputs, stream);
}
// Looks up `key_num` keys and writes their values to `outputs`. Keys that are not in the map yet are inserted and
// their values are produced by `initializer` ('normal', 'zeros' or 'ones').
// Returns false when an argument is null, the scratch index buffer cannot be allocated, or a sub-step fails.
// All kernels are queued on `stream`; this function does not wait for them.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *key, size_t key_num, const std::string &initializer,
                                               Value *outputs, void *stream) {
  MS_ERROR_IF_NULL(key);
  MS_ERROR_IF_NULL(outputs);
  MS_ERROR_IF_NULL(stream);
  int *indices = std::allocator_traits<IndexAllocatorType>::allocate(index_alloc_, key_num);
  MS_ERROR_IF_NULL(indices);

  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
  bool ret = true;
  // 1. Get all indices in blocks according to the key.
  if (!GetIndicesByKeys(key, key_num, true, indices, cuda_stream)) {
    MS_LOG(ERROR) << "Get indices by keys failed.";
    ret = false;
  } else {
    Reserve(size_ + key_num);
    // 2. Insert default value according to initializer, initializer can be 'normal', 'zeros' or 'ones'.
    if (!InsertDefaultValueByInitializer(key_num, initializer, indices, cuda_stream)) {
      MS_LOG(ERROR) << "Insert default value for miss keys failed.";
      ret = false;
    } else {
      // 3. Get all values by indices in blocks.
      size_t total_size = value_dim_ * key_num;
      GetValues<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(value_dim_, total_size, indices,
                                                                         elements_per_block_, blocks_ptr_, outputs);
    }
  }

  // Single exit: the scratch buffer is handed back on the failure paths as well, which previously leaked it.
  // NOTE(review): the buffer is released right after the asynchronous launch, exactly as before; this relies on the
  // allocator not recycling it before the kernel has run - confirm against the allocator implementation.
  std::allocator_traits<IndexAllocatorType>::deallocate(index_alloc_, indices, key_num);
  return ret;
}
// Looks up `key_num` keys and writes their values to `outputs`. Keys that are not in the map yet are inserted and
// their values are filled with `default_value`.
// Returns false when an argument is null, the scratch index buffer cannot be allocated, or the index lookup fails.
// All kernels are queued on `stream`; this function does not wait for them.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::Find(const Key *key, size_t key_num, const Value &default_value,
                                               Value *outputs, void *stream) {
  MS_ERROR_IF_NULL(key);
  MS_ERROR_IF_NULL(outputs);
  MS_ERROR_IF_NULL(stream);
  int *indices = std::allocator_traits<IndexAllocatorType>::allocate(index_alloc_, key_num);
  MS_ERROR_IF_NULL(indices);

  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
  bool ret = true;
  // 1. Get all indices in blocks according to the key.
  if (!GetIndicesByKeys(key, key_num, true, indices, cuda_stream)) {
    MS_LOG(ERROR) << "Get indices by keys failed.";
    ret = false;
  } else {
    Reserve(size_ + key_num);
    // 2. Insert default value into map by specific value.
    InsertDefaultValue<<<GET_BLOCKS(key_num), GET_THREADS, 0, cuda_stream>>>(
      value_dim_, key_num, indices, elements_per_block_, default_value, idle_flags_ptr_, blocks_ptr_);
    // 3. Get all values by indices in blocks.
    size_t total_size = value_dim_ * key_num;
    GetValues<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(value_dim_, total_size, indices,
                                                                       elements_per_block_, blocks_ptr_, outputs);
  }

  // Single exit: the scratch buffer is handed back on the failure path as well, which previously leaked it.
  // NOTE(review): the buffer is released right after the asynchronous launches, exactly as before; this relies on
  // the allocator not recycling it before the kernels have run - confirm against the allocator implementation.
  std::allocator_traits<IndexAllocatorType>::deallocate(index_alloc_, indices, key_num);
  return ret;
}
// Inserts `key_num` key/value pairs; keys missing from the map are added first and the values are then written to
// the slots the keys map to.
// Returns false when an argument is null, the scratch index buffer cannot be allocated, or the index lookup fails.
// The value-writing kernel is queued on `stream`; this function does not wait for it.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::Insert(const Key *key, size_t key_num, const Value *value, void *stream) {
  MS_ERROR_IF_NULL(key);
  MS_ERROR_IF_NULL(value);
  MS_ERROR_IF_NULL(stream);
  int *indices = std::allocator_traits<IndexAllocatorType>::allocate(index_alloc_, key_num);
  MS_ERROR_IF_NULL(indices);

  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
  bool ret = true;
  // 1. Get all indices in blocks according to the key.
  if (!GetIndicesByKeys(key, key_num, true, indices, cuda_stream)) {
    MS_LOG(ERROR) << "Get indices by keys failed.";
    ret = false;
  } else {
    Reserve(size_ + key_num);
    // 2. Insert values into map by indices in blocks.
    size_t total_insert_size = value_dim_ * key_num;
    InsertValues<<<GET_BLOCKS(total_insert_size), GET_THREADS, 0, cuda_stream>>>(
      value_dim_, total_insert_size, indices, value, elements_per_block_, idle_flags_ptr_, blocks_ptr_);
  }

  // Single exit: the scratch buffer is handed back on the failure path as well, which previously leaked it.
  // NOTE(review): the buffer is released right after the asynchronous launch, exactly as before; this relies on the
  // allocator not recycling it before the kernel has run - confirm against the allocator implementation.
  std::allocator_traits<IndexAllocatorType>::deallocate(index_alloc_, indices, key_num);
  return ret;
}
// Erases the `key_num` keys starting at `keys` from the underlying dynamic map, passing `stream` to the map's
// erase call. Returns false only when an argument or the map itself is null.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::Erase(const Key *keys, size_t key_num, void *stream) {
  MS_ERROR_IF_NULL(keys);
  MS_ERROR_IF_NULL(stream);
  MS_ERROR_IF_NULL(cuda_dynamic_map_);

  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
  const Key *keys_end = keys + key_num;
  cuda_dynamic_map_->dynamic_map_.erase(keys, keys_end, cuda_stream);
  return true;
}
// Exports every key held by the map into `keys` and gathers the matching values into `values`.
// A scratch buffer of `size_` indices is used in between. Returns false when an output pointer or the map is null,
// the scratch buffer cannot be allocated, or the map export fails.
// The gather kernel is queued on `stream`; this function does not wait for it.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::GetKeysAndValues(Key *keys, Value *values, void *stream) {
  MS_ERROR_IF_NULL(keys);
  MS_ERROR_IF_NULL(values);
  MS_ERROR_IF_NULL(cuda_dynamic_map_);
  auto &dynamic_map = cuda_dynamic_map_->dynamic_map_;
  int *indices = std::allocator_traits<IndexAllocatorType>::allocate(index_alloc_, size_);
  MS_ERROR_IF_NULL(indices);

  // 1. Export all keys and indices from dynamic map.
  auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
  if (!dynamic_map.get_keys_values(keys, indices, cuda_stream)) {
    MS_LOG(ERROR) << "Get keys and values from cuda dynamic map failed.";
    // Hand the scratch buffer back before bailing out; it previously leaked on this path.
    std::allocator_traits<IndexAllocatorType>::deallocate(index_alloc_, indices, size_);
    return false;
  }

  // 2. Get all values by indices in blocks.
  size_t total_size = value_dim_ * size_;
  GetValues<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(value_dim_, total_size, indices,
                                                                     elements_per_block_, blocks_ptr_, values);
  std::allocator_traits<IndexAllocatorType>::deallocate(index_alloc_, indices, size_);
  return true;
}
// Imports host data into the hash table. Calls are accumulated until kImportTensorNum(3) tensors have been received;
// the first is used as the key buffer and the second as the value buffer (the third is not read here). Calls that
// only stage a tensor return true. Once the batch is complete, keys and values are copied to temporary device
// buffers on the default stream, inserted, and the stream is synchronized.
// Returns false when a host pointer is null, a device buffer cannot be allocated, or a copy/insert step fails.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::Import(const DataLenPair &input_data) {
  // 1. Store input tensor data until receiving kImportTensorNum(3) input tensor.
  // Really import input data to hash table when receive kImportTensorNum(3) input tensor.
  // NOTE(review): this staging list is a function-local static, so it is shared by every table of the same template
  // instantiation and is not synchronized; it likely belongs in a data member, which needs a header change.
  static std::vector<DataLenPair> input_data_list;
  if (input_data_list.size() < kImportTensorNum) {
    input_data_list.emplace_back(input_data);
  }
  if (input_data_list.size() != kImportTensorNum) {
    return true;
  }

  // Take the complete batch out of the staging list right away. Previously a failure below left the full list behind,
  // so every later call retried the same stale batch and never accepted new tensors.
  std::vector<DataLenPair> pending_inputs;
  pending_inputs.swap(input_data_list);

  const auto &input_keys = pending_inputs[0];
  const auto &input_values = pending_inputs[1];
  void *host_keys = input_keys.first;
  void *host_values = input_values.first;
  MS_ERROR_IF_NULL(host_keys);
  MS_ERROR_IF_NULL(host_values);
  size_t keys_len = input_keys.second;
  size_t values_len = input_values.second;

  // 2. Allocate temp buffer to keys and values.
  char *device_keys = std::allocator_traits<CharAllocatorType>::allocate(char_alloc_, keys_len);
  char *device_values = std::allocator_traits<CharAllocatorType>::allocate(char_alloc_, values_len);
  auto release_buffers = [&]() {
    if (device_keys != nullptr) {
      std::allocator_traits<CharAllocatorType>::deallocate(char_alloc_, device_keys, keys_len);
    }
    if (device_values != nullptr) {
      std::allocator_traits<CharAllocatorType>::deallocate(char_alloc_, device_values, values_len);
    }
  };
  if (device_keys == nullptr || device_values == nullptr) {
    MS_LOG(ERROR) << "Allocate device buffer for import keys or values failed.";
    release_buffers();
    return false;
  }

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(GPUDeviceManager::GetInstance().default_stream());
  // The fallible steps live in a lambda so that the temp buffers are released on every outcome.
  auto copy_and_insert = [&]() -> bool {
    // 3. Copy input keys and values to device.
    CHECK_CUDA_RET_WITH_RETURN_FALSE(
      cudaMemcpyAsync(device_keys, host_keys, keys_len, cudaMemcpyHostToDevice, stream), "cudaMemcpyAsync");
    CHECK_CUDA_RET_WITH_RETURN_FALSE(
      cudaMemcpyAsync(device_values, host_values, values_len, cudaMemcpyHostToDevice, stream), "cudaMemcpyAsync");
    // 4. Insert input keys and values to hash table.
    RETURN_IF_FALSE_WITH_LOG(Insert(reinterpret_cast<Key *>(device_keys), keys_len / sizeof(Key),
                                    reinterpret_cast<Value *>(device_values), stream),
                             "Insert keys and values failed.");
    CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
    return true;
  };
  const bool ret = copy_and_insert();
  if (!ret) {
    // Work referencing the temp buffers may still be queued; drain the stream before giving them back.
    CHECK_CUDA_RET(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
  }

  // 5. Free temp buffer to keys and values.
  release_buffers();
  return ret;
}
// Exports all keys, values and status entries of the table into the host buffers described by `keys`, `values` and
// `status` (pointer + byte length). The byte lengths must match the current table size exactly.
// Data is staged in temporary device buffers and copied to the host on the default stream, which is synchronized
// before returning. Returns false on a null host pointer, a length mismatch, an allocation failure or a CUDA error.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::Export(const DataLenPair &keys, const DataLenPair &values,
                                                 const DataLenPair &status) {
  MS_ERROR_IF_NULL(keys.first);
  MS_ERROR_IF_NULL(values.first);
  MS_ERROR_IF_NULL(status.first);
  size_t keys_len = size_ * sizeof(Key);
  size_t values_len = size_ * value_dim_ * sizeof(Value);
  size_t status_len = size_ * sizeof(Status);
  // 1. Check length for output tensor.
  ASSERT_EQUAL(
    keys_len, keys.second,
    std::string("Need keys len[") + std::to_string(keys_len) + "], but got:[" + std::to_string(keys.second) + "].");
  ASSERT_EQUAL(values_len, values.second,
               std::string("Need values len[") + std::to_string(values_len) + "], but got:[" +
                 std::to_string(values.second) + "].");
  ASSERT_EQUAL(status_len, status.second,
               std::string("Need status len[") + std::to_string(status_len) + "], but got:[" +
                 std::to_string(status.second) + "].");

  // 2. Allocate temp buffer to keys, values and status.
  char *device_keys = std::allocator_traits<CharAllocatorType>::allocate(char_alloc_, keys_len);
  char *device_values = std::allocator_traits<CharAllocatorType>::allocate(char_alloc_, values_len);
  char *device_status = std::allocator_traits<CharAllocatorType>::allocate(char_alloc_, status_len);
  auto release_buffers = [&]() {
    if (device_keys != nullptr) {
      std::allocator_traits<CharAllocatorType>::deallocate(char_alloc_, device_keys, keys_len);
    }
    if (device_values != nullptr) {
      std::allocator_traits<CharAllocatorType>::deallocate(char_alloc_, device_values, values_len);
    }
    if (device_status != nullptr) {
      std::allocator_traits<CharAllocatorType>::deallocate(char_alloc_, device_status, status_len);
    }
  };
  if (device_keys == nullptr || device_values == nullptr || device_status == nullptr) {
    MS_LOG(ERROR) << "Allocate device buffer for export keys, values or status failed.";
    release_buffers();
    return false;
  }

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(GPUDeviceManager::GetInstance().default_stream());
  // The fallible steps live in a lambda so that the temp buffers are released on every outcome.
  auto export_to_host = [&]() -> bool {
    // Nothing in this function fills the status buffer yet, so zero it; otherwise uninitialized device memory would
    // be copied to the host. NOTE(review): confirm that 0 is an acceptable placeholder Status value.
    CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaMemsetAsync(device_status, 0, status_len, stream), "cudaMemsetAsync");
    // 3. Export all keys and indices and store into temp buffer.
    RETURN_IF_FALSE_WITH_LOG(
      GetKeysAndValues(reinterpret_cast<Key *>(device_keys), reinterpret_cast<Value *>(device_values), stream),
      "Get keys and values failed.");
    // Note: Get all status.
    // 4. Copy keys, values and status from device temp buffer to host.
    CHECK_CUDA_RET_WITH_RETURN_FALSE(
      cudaMemcpyAsync(keys.first, device_keys, keys_len, cudaMemcpyDeviceToHost, stream), "cudaMemcpyAsync");
    CHECK_CUDA_RET_WITH_RETURN_FALSE(
      cudaMemcpyAsync(values.first, device_values, values_len, cudaMemcpyDeviceToHost, stream), "cudaMemcpyAsync");
    CHECK_CUDA_RET_WITH_RETURN_FALSE(
      cudaMemcpyAsync(status.first, device_status, status_len, cudaMemcpyDeviceToHost, stream), "cudaMemcpyAsync");
    CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
    return true;
  };
  const bool ret = export_to_host();
  if (!ret) {
    // Work referencing the temp buffers may still be queued; drain the stream before giving them back.
    CHECK_CUDA_RET(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
  }

  // 5. Free temp buffer to keys, values and status.
  release_buffers();
  return ret;
}
// Translates `key_num` keys into slot indices written to `indices`. The keys are processed in chunks, one submap of
// the dynamic map at a time; after each chunk the stream is synchronized and the number of newly inserted keys is
// added to the submap size, the map size and `size_`.
// `insert_miss_key` is forwarded to the lookup kernel.
// Returns false on a null argument, when the submaps run out before all keys are handled, or on a CUDA error.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::GetIndicesByKeys(const Key *key, size_t key_num, bool insert_miss_key,
                                                           int32_t *indices, cudaStream_t stream) {
  MS_ERROR_IF_NULL(key);
  MS_ERROR_IF_NULL(indices);
  MS_ERROR_IF_NULL(stream);
  MS_ERROR_IF_NULL(cuda_dynamic_map_);
  auto &dynamic_map = cuda_dynamic_map_->dynamic_map_;
  dynamic_map.reserve(key_num + dynamic_map.get_size());
  size_t submap_idx = 0;
  uint32_t device_id = GET_CTX_DEVICE_ID;
  size_t remaining_key_num = key_num;
  MS_ERROR_IF_NULL(insert_success_number_);
  while (remaining_key_num > 0) {
    // Submaps that are too full are skipped below, so the index can move past the last submap even after reserve();
    // stop with an error instead of indexing out of bounds.
    if (submap_idx >= dynamic_map.get_submaps().size()) {
      MS_LOG(ERROR) << "No submap with enough remaining capacity is left, unhandled key number: "
                    << remaining_key_num;
      return false;
    }
    auto &submap_ptr = dynamic_map.get_submaps()[submap_idx];
    MS_ERROR_IF_NULL(submap_ptr);
    // 1. Get remaining capacity in current submap, max load factor and min insert size need to be considered.
    // The subtraction is done on unsigned integers and clamped at zero: converting a negative floating point
    // difference to size_t would be undefined behavior.
    const size_t submap_max_size =
      static_cast<size_t>(submap_ptr->get_capacity() * dynamic_map.get_max_load_factor());
    const size_t submap_size = submap_ptr->get_size();
    size_t submap_remaining_capacity = submap_max_size > submap_size ? submap_max_size - submap_size : 0;
    if (submap_remaining_capacity < dynamic_map.get_min_insert_size()) {
      submap_idx++;
      continue;
    }
    // Reset the managed counter on the host, then hint its migration to the device before the kernel uses it.
    *(insert_success_number_) = 0;
    CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaMemPrefetchAsync(insert_success_number_, sizeof(CudaAtomicSize), device_id),
                                     "cudaMemPrefetchAsync");
    // 2. Get the key number could be handled by current submap.
    size_t item_num = std::min(submap_remaining_capacity, remaining_key_num);
    const uint32_t tile_size = kTileSize;
    const uint32_t block_size = kBlockSize;
    if (IntToUint(GET_THREADS) < block_size) {
      MS_LOG(ERROR) << "The max thread per block is less than: " << block_size << " of this GPU";
    }
    const uint32_t grid_size = IntToUint(CUDA_BLOCKS_CAL(device_id, tile_size * item_num, block_size));
    // 3. Transform all keys into indices in blocks. If the key exist in map already, just return the index,
    // otherwise find a valid position in block.
    LookupIndices<block_size, tile_size, Key, typename CucoDynamicMap<Key, int32_t, Allocator>::mutable_view_type,
                  typename CucoDynamicMap<Key, int32_t, Allocator>::view_type><<<grid_size, block_size, 0, stream>>>(
      key, item_num, insert_miss_key, dynamic_map.get_submaps().size(), submap_idx, idel_slot_, idle_index_,
      dynamic_map.get_submap_mutable_views().data().get(), dynamic_map.get_submap_views().data().get(),
      insert_success_number_, current_index_, indices);
    // A rejected launch is only visible through cudaGetLastError(); without this check the counter would be read
    // back as zero and the indices left unwritten while the function still reported success.
    CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaGetLastError(), "Launch LookupIndices kernel");
    // 4. Update size for dynamic map and static submap.
    CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaStreamSynchronize(stream), "cudaStreamSynchronize");
    size_t insert_success_num = insert_success_number_->load(cuda::std::memory_order_relaxed);
    dynamic_map.update_submap_size(submap_idx, submap_ptr->get_size() + insert_success_num);
    dynamic_map.update_size(dynamic_map.get_size() + insert_success_num);
    size_ += insert_success_num;
    // Advance to the next chunk of keys and to the next submap.
    indices += item_num;
    key += item_num;
    remaining_key_num -= item_num;
    submap_idx++;
  }
  return true;
}
// Queues, on `stream`, the kernel that fills slots referenced by `indices` according to `initializer`:
//   kNormalDistribution -> random values drawn with mean 0 and stddev 0.01 (the generator is initialized on demand),
//   kOnesDistribution   -> the constant 1,
//   kZerosDistribution  -> the constant 0.
// Returns false on a null argument, when the random generator cannot be initialized, or for any other initializer.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::InsertDefaultValueByInitializer(size_t key_num,
                                                                          const std::string &initializer,
                                                                          const int *indices, cudaStream_t stream) {
  MS_ERROR_IF_NULL(indices);
  MS_ERROR_IF_NULL(stream);

  // Random fill: needs the per-thread generator states to exist first.
  if (initializer == kNormalDistribution) {
    RETURN_IF_FALSE_WITH_LOG(InitNormalDistRandomGenerator(stream),
                             "Initialize normal distribution random generator failed.");
    Value normal_mean = static_cast<Value>(0);
    Value normal_stddev = static_cast<Value>(0.01);
    InsertNormalDistRandomValue<<<random_gen_block_count_, random_gen_threads_per_block_, 0, stream>>>(
      value_dim_, key_num, indices, elements_per_block_, normal_mean, normal_stddev, random_gen_state_,
      idle_flags_ptr_, blocks_ptr_);
    return true;
  }

  // Constant fill: ones and zeros share the same kernel and differ only in the fill value.
  const bool fill_with_ones = (initializer == kOnesDistribution);
  if (fill_with_ones || initializer == kZerosDistribution) {
    Value fill_value = fill_with_ones ? static_cast<Value>(1.0) : static_cast<Value>(0);
    InsertDefaultValue<<<GET_BLOCKS(key_num), GET_THREADS, 0, stream>>>(
      value_dim_, key_num, indices, elements_per_block_, fill_value, idle_flags_ptr_, blocks_ptr_);
    return true;
  }

  MS_LOG(ERROR) << "Unsupported initializer: " << initializer;
  return false;
}
// Prepares the curand states used for normal distribution fills. The work is done only while the init flag is
// unset: one state per launched thread is allocated, the init kernel is queued on `stream` with a seed taken from
// std::random_device, and the flag is set. Returns false when `stream` is null or the state buffer cannot be
// allocated.
template <typename Key, typename Value, typename Allocator>
bool GPUHashTable<Key, Value, Allocator>::InitNormalDistRandomGenerator(cudaStream_t stream) {
  MS_ERROR_IF_NULL(stream);
  if (!random_gen_init_.load()) {
    // 1. Allocate memory for all random generator states.
    auto state_count = random_gen_threads_per_block_ * random_gen_block_count_;
    const size_t state_bytes = IntToSize(state_count) * sizeof(curandStatePhilox4_32_10_t);
    char *raw_states = std::allocator_traits<CharAllocatorType>::allocate(char_alloc_, state_bytes);
    random_gen_state_ = reinterpret_cast<curandStatePhilox4_32_10_t *>(raw_states);
    MS_ERROR_IF_NULL(random_gen_state_);

    // 2. Initialize normal distribution random generator states.
    std::random_device seed_source;
    uint32_t seed = seed_source();
    InitNormalDisRandomGen<<<random_gen_block_count_, random_gen_threads_per_block_, 0, stream>>>(seed,
                                                                                                  random_gen_state_);
    random_gen_init_ = true;
  }
  return true;
}
// Explicit instantiations: the member definitions above live in this source file rather than in the header, so the
// key/value combinations listed here are the ones emitted by this translation unit.
template class GPUHashTable<int32_t, float>;
template class GPUHashTable<int64_t, float>;
} // namespace gpu
} // namespace device
} // namespace mindspore
#endif

View File

@ -16,36 +16,62 @@
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
#if CUDA_VERSION > 11000
#include <curand_kernel.h>
#include <cuda/std/atomic>
#include <string>
#include <vector>
#include <memory>
#include <atomic>
#include "runtime/device/hash_table.h"
#include "plugin/device/gpu/hal/device/gpu_allocator.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
namespace mindspore {
namespace device {
namespace gpu {
// A hash table base on GPU.
constexpr static int kMaxThreadsPerBlockRandomGen = 2048;
constexpr static size_t kInitialCapacity = 100000;
constexpr static uint32_t kTileSize = 4;
constexpr static uint32_t kBlockSize = 128;
constexpr static size_t kImportTensorNum = 3;
constexpr static char kNormalDistribution[] = "normal";
constexpr static char kZerosDistribution[] = "zeros";
constexpr static char kOnesDistribution[] = "ones";
template <typename Key, typename Value, typename Allocator>
class CudaDynamicMap;
// A hash table base on GPU.
template <typename Key, typename Value, typename Allocator = GPUAllocator<char>>
class GPUHashTable : public HashTable<Key, Value> {
public:
explicit GPUHashTable(int32_t value_dim, const Allocator &alloc = Allocator())
: value_dim_(value_dim), alloc_(alloc) {}
~GPUHashTable() = default;
using Status = typename HashTable<Key, Value>::Status;
using key_allocator_type = typename std::allocator_traits<Allocator>::rebind_alloc<key>;
using value_allocator_type = typename std::allocator_traits<Allocator>::rebind_alloc<Value>;
using index_allocator_type = typename std::allocator_traits<Allocator>::rebind_alloc<int>;
using char_allocator_type = typename std::allocator_traits<Allocator>::rebind_alloc<char>;
GPUHashTable(int32_t value_dim, const std::string &initializer, const Allocator &alloc = Allocator());
GPUHashTable(int32_t value_dim, const Value &default_value, const Allocator &alloc = Allocator());
~GPUHashTable();
// The Allocator type used allocate gpu memory for 'Key' type.
using KeyAllocatorType = typename std::allocator_traits<Allocator>::template rebind_alloc<Key>;
// The Allocator type used allocate gpu memory for 'Value' type.
using ValueAllocatorType = typename std::allocator_traits<Allocator>::template rebind_alloc<Value>;
// The Allocator type used allocate gpu memory for 'index' type(int).
using IndexAllocatorType = typename std::allocator_traits<Allocator>::template rebind_alloc<int>;
// The general Allocator type used allocate gpu memory.
using CharAllocatorType = typename std::allocator_traits<Allocator>::template rebind_alloc<char>;
// Find elements with specific keys, if a key does not exist, initialize the value for the key based on the
// initialzer and insert the key-value pair into map. The initializer can be 'normal', 'zero' or 'one', and also
// could be a specific Value type scalar.
// could be a specific 'Value' type scalar.
bool Find(const Key *keys, size_t key_num, Value *outputs, void *stream) override;
// Insert elements with specific keys. If key exists, update the value of the key.
bool Insert(const Key *key, size_t key_num, const Value *value) override { return true; }
bool Insert(const Key *keys, size_t key_num, const Value *value, void *stream) override;
// Erase elements with specific keys.
bool Erase(const Key *key, size_t key_num) override { return true; }
bool Erase(const Key *keys, size_t key_num, void *stream) override;
// Reserves space for at least the specified number of elements.
bool Reserve(size_t count) override { return true; }
@ -70,18 +96,22 @@ class GPUHashTable : public HashTable<Key, Value> {
private:
// Find elements with specific keys, if the key does not exist, initialize the value for the key based on the
// initialzer and insert the key-value pair into map.The initializer can be 'normal', 'zero' or 'one'.
bool Find(const Key *key, size_t key_num, const std::string &initializer, Value *outputs) { return true; }
// initializer and insert the key-value pair into map. The initializer can be 'normal', 'zeros' or 'ones'.
bool Find(const Key *keys, size_t key_num, const std::string &initializer, Value *outputs, void *stream);
// Find elements with specific keys, if the key does not exist, initialize the value for the key by 'default_value'
// and insert the key-value pair into map.
bool Find(const Key *key, size_t key_num, const Value &default_value, Value *outputs) { return true; }
bool Find(const Key *keys, size_t key_num, const Value &default_value, Value *outputs, void *stream);
// Get all indices in blocks according to the key.
bool GetIndicesByKeys(const Key *key, size_t key_num, bool insert_miss_key, int32_t *indices) const;
bool GetIndicesByKeys(const Key *key, size_t key_num, bool insert_miss_key, int32_t *indices, cudaStream_t stream);
// Get value in blocks according to the indices.
bool GetValues(const int32_t *indices, Value *value) const;
// Insert default value according to initializer, initializer can be 'normal', 'zeros' or 'ones'.
bool InsertDefaultValueByInitializer(size_t key_num, const std::string &initializer, const int *indices,
cudaStream_t stream);
// initialize normal distribution random generator states on GPU.
bool InitNormalDistRandomGenerator(cudaStream_t stream);
// Record all block memory that contain all values.
std::vector<Value *> blocks_;
@ -93,19 +123,55 @@ class GPUHashTable : public HashTable<Key, Value> {
// Record whether every slot is occupied or not for all block, the buffer is on device memory.
bool **idle_flags_ptr_{nullptr};
// The sentinel record the latest used location in blocks.
cuda::atomic<std::size_t, cuda::thread_scope_device> *current_index_{nullptr};
// The counter record the idle slot number, if the contents of a slot are erased, the slot is marked with the idle
// status.
cuda::atomic<std::size_t, cuda::thread_scope_device> *idle_index_{nullptr};
// The buffer keep all the idle slot position(offset index to the beginning of block).
int32_t *idel_slot_{nullptr};
// The value dimension for each key.
size_t value_dim_;
// The allocator used to alloacte gpu memory.
Allocator alloc_;
// The initializer used to initialize the values for missing keys, the initializer could be 'normal', 'zero' or 'one'.
std::string initializer_;
// The default value used to initialize the values for missing keys.
Value default_value_;
// The allocator used to allocate gpu memory for index.
IndexAllocatorType index_alloc_;
// The common allocator used to allocate gpu memory.
CharAllocatorType char_alloc_;
// The dynamic hash table on GPU.
std::unique_ptr<CudaDynamicMap<Key, int32_t, Allocator>> cuda_dynamic_map_;
// Record the number of elements in the map.
size_t size_{0};
// Record the number of elements that can be held in currently allocated storage
size_t capacity_{0};
// The number of elements of one block.
size_t elements_per_block_{kInitialCapacity};
// Record the number of successfully inserted keys.
cuda::atomic<std::size_t, cuda::thread_scope_device> *insert_success_number_{nullptr};
// The flag record whether normal distribution random generator state is initialized.
std::atomic_bool random_gen_init_{false};
// The random generator states used to generate normal distribution.
curandStatePhilox4_32_10_t *random_gen_state_{nullptr};
// The block size used to launch cuda kernel for inserting normal distribution random values.
int random_gen_threads_per_block_{GET_THREADS};
// The grid size used to launch cuda kernel for inserting normal distribution random values.
int random_gen_block_count_{(kMaxThreadsPerBlockRandomGen - 1) / random_gen_threads_per_block_ + 1};
};
} // namespace gpu
} // namespace device
} // namespace mindspore
#endif
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_KERNEL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_KERNEL_CUH_
#if CUDA_VERSION > 11000
#include <cuco/dynamic_map.cuh>
#include <curand_kernel.h>
@ -219,4 +220,5 @@ __global__ void InsertValues(size_t value_dim, size_t total_insert_size, const i
} // namespace gpu
} // namespace device
} // namespace mindspore
#endif
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_KERNEL_CUH_