From 47aaf3bad517e22d1d3ce6a4da4a97c4d0052822 Mon Sep 17 00:00:00 2001
From: lizhenyu
Date: Thu, 20 Oct 2022 14:42:58 +0800
Subject: [PATCH] Add GPU hash table impl

---
 .../ccsrc/plugin/device/gpu/CMakeLists.txt    |   4 +
 .../device/gpu/hal/device/CMakeLists.txt      |  11 +-
 .../device/gpu/hal/device/gpu_hash_table.cu   | 446 ++++++++++++++++++
 .../device/gpu/hal/device/gpu_hash_table.h    | 104 +++-
 .../gpu/hal/device/gpu_hash_table_kernel.cuh  |   2 +
 5 files changed, 546 insertions(+), 21 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table.cu

diff --git a/mindspore/ccsrc/plugin/device/gpu/CMakeLists.txt b/mindspore/ccsrc/plugin/device/gpu/CMakeLists.txt
index a77ab936489..b3e0b0ed684 100644
--- a/mindspore/ccsrc/plugin/device/gpu/CMakeLists.txt
+++ b/mindspore/ccsrc/plugin/device/gpu/CMakeLists.txt
@@ -75,6 +75,10 @@ if(ENABLE_GPU)
                 ${CUDA_PATH}/lib64/libcufft.so
                 ${CUDA_PATH}/lib64/libcusparse.so)
     endif()
+
+    if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND ${CUDA_VERSION} VERSION_GREATER "11.0")
+        target_link_libraries(mindspore_gpu PRIVATE gpu_hash_table)
+    endif()
 endif()
 
 if(ENABLE_DEBUGGER)
diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/CMakeLists.txt b/mindspore/ccsrc/plugin/device/gpu/hal/device/CMakeLists.txt
index e05eed19d63..d4cb8443bf8 100644
--- a/mindspore/ccsrc/plugin/device/gpu/hal/device/CMakeLists.txt
+++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/CMakeLists.txt
@@ -36,12 +36,19 @@ if(ENABLE_GPU)
         target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi)
     endif()
 
-    file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" "*.cu")
+    file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cu")
+    set_property(SOURCE ${CUDA_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
+
+    if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND ${CUDA_VERSION} VERSION_GREATER "11.0")
+        string(REPLACE "-arch=sm_53;" "" CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}")
+        list(APPEND CUDA_NVCC_FLAGS -std=c++17)
+        cuda_add_library(gpu_hash_table STATIC ${CUDA_SRC_LIST})
+    endif()
+
     set(GPU_COLLECTIVE_SRCS "distribution/collective_wrapper.cc"
                             "distribution/mpi_wrapper.cc"
                             "distribution/nccl_wrapper.cc")
-    list(REMOVE_ITEM CUDA_SRC_LIST "mpi/mpi_initializer.cc" ${GPU_COLLECTIVE_SRCS})
 
     if(ENABLE_MPI)
         include(ExternalProject)
diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table.cu b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table.cu
new file mode 100644
index 00000000000..5255cd53251
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table.cu
@@ -0,0 +1,446 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#if CUDA_VERSION > 11000
+#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
+
+#include
+#include
+#include
+
+#include "plugin/device/gpu/hal/device/gpu_hash_table_kernel.cuh"
+#include "utils/log_adapter.h"
+#include "utils/convert_utils_base.h"
+#include "plugin/device/gpu/hal/device/gpu_device_manager.h"
+
+namespace mindspore {
+namespace device {
+namespace gpu {
+#define CHECK_CUDA_RET(expression, message) \
+  { \
+    cudaError_t cuda_ret = (expression); \
+    if (cuda_ret != cudaSuccess) { \
+      MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << cuda_ret << " " \
+                    << cudaGetErrorString(cuda_ret); \
+    } \
+  }
+
+#define CHECK_CUDA_RET_WITH_RETURN_FALSE(expression, message) \
+  { \
+    cudaError_t cuda_ret = (expression); \
+    if (cuda_ret != cudaSuccess) { \
+      MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << cuda_ret << " " \
+                    << cudaGetErrorString(cuda_ret); \
+      return false; \
+    } \
+  }
+
+#define ASSERT_EQUAL(lhs, rhs, message) \
+  { \
+    if ((lhs) != (rhs)) { \
+      MS_LOG(ERROR) << message; \
+      return false; \
+    } \
+  }
+
+template
+using CucoDynamicMap = cuco::dynamic_map;
+
+// CudaDynamicMap is a wrapper around cuco::dynamic_map. Because gpu_hash_table.h is included by other C++ source
+// files that are compiled with g++, the cuco::dynamic_map type cannot appear in that header; gpu_hash_table.h only
+// forward-declares CudaDynamicMap, which resolves the compilation problem.
+template
+struct CudaDynamicMap {
+  CucoDynamicMap dynamic_map_;
+
+  CudaDynamicMap(const Key &empty_key, const Value &empty_value, const Key &erased_key, const Allocator &alloc,
+                 cudaStream_t stream = 0)
+      : dynamic_map_(kInitialCapacity, cuco::sentinel::empty_key{empty_key},
+                     cuco::sentinel::empty_value{empty_value}, cuco::sentinel::erased_key{erased_key},
+                     alloc, stream) {}
+
+  ~CudaDynamicMap() = default;
+};
+
+template
+GPUHashTable::GPUHashTable(int32_t value_dim, const std::string &initializer,
+                           const Allocator &alloc)
+    : value_dim_(value_dim), initializer_(initializer), default_value_(0), index_alloc_(alloc), char_alloc_(alloc) {
+  cudaStream_t stream = reinterpret_cast(GPUDeviceManager::GetInstance().default_stream());
+  cuda_dynamic_map_ = std::make_unique>(static_cast(-1), -1,
+                                         static_cast(-2), alloc, stream);
+  CHECK_CUDA_RET(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
+
+  CHECK_CUDA_RET(
+    cudaMallocManaged(&insert_success_number_, sizeof(cuda::atomic)),
+    "cudaMallocManaged");
+}
+
+template
+GPUHashTable::GPUHashTable(int32_t value_dim, const Value &default_value, const Allocator &alloc)
+    : value_dim_(value_dim), initializer_(""), default_value_(default_value), index_alloc_(alloc), char_alloc_(alloc) {
+  cudaStream_t stream = reinterpret_cast(GPUDeviceManager::GetInstance().default_stream());
+  cuda_dynamic_map_ = std::make_unique>(static_cast(-1), -1,
+                                         static_cast(-2), alloc, stream);
+  CHECK_CUDA_RET(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
+
+  CHECK_CUDA_RET(
+    cudaMallocManaged(&insert_success_number_, sizeof(cuda::atomic)),
+    "cudaMallocManaged");
+}
+
+template
+GPUHashTable::~GPUHashTable() {
+  CHECK_CUDA_RET(cudaFree(insert_success_number_), "cudaFree");
+}
+
+template
+bool GPUHashTable::Find(const Key *keys, size_t key_num, Value *outputs, void *stream) {
+  if (!initializer_.empty()) {
+    return Find(keys, key_num, initializer_, outputs, stream);
+  }
+  return Find(keys, key_num, default_value_, outputs, stream);
+}
+
+template
+bool GPUHashTable::Find(const Key *key, size_t key_num, const std::string &initializer,
+                        Value *outputs, void *stream) {
+  MS_ERROR_IF_NULL(key);
+  MS_ERROR_IF_NULL(outputs);
+  MS_ERROR_IF_NULL(stream);
+  int *indices = std::allocator_traits::allocate(index_alloc_, key_num);
+  MS_ERROR_IF_NULL(indices);
+
+  // 1. Get all indices in blocks according to the key.
+  auto cuda_stream = reinterpret_cast(stream);
+  RETURN_IF_FALSE_WITH_LOG(GetIndicesByKeys(key, key_num, true, indices, cuda_stream), "Get indices by keys failed.");
+
+  Reserve(size_ + key_num);
+  // 2. Insert default value according to the initializer, the initializer can be 'normal', 'zeros' or 'ones'.
+  RETURN_IF_FALSE_WITH_LOG(InsertDefaultValueByInitializer(key_num, initializer, indices, cuda_stream),
+                           "Insert default value for missing keys failed.");
+
+  // 3. Get all values by indices in blocks.
+  size_t total_size = value_dim_ * key_num;
+  GetValues<<>>(value_dim_, total_size, indices,
+                elements_per_block_, blocks_ptr_, outputs);
+  std::allocator_traits::deallocate(index_alloc_, indices, key_num);
+  return true;
+}
+
+template
+bool GPUHashTable::Find(const Key *key, size_t key_num, const Value &default_value,
+                        Value *outputs, void *stream) {
+  MS_ERROR_IF_NULL(key);
+  MS_ERROR_IF_NULL(outputs);
+  MS_ERROR_IF_NULL(stream);
+  int *indices = std::allocator_traits::allocate(index_alloc_, key_num);
+  MS_ERROR_IF_NULL(indices);
+
+  // 1. Get all indices in blocks according to the key.
+  auto cuda_stream = reinterpret_cast(stream);
+  RETURN_IF_FALSE_WITH_LOG(GetIndicesByKeys(key, key_num, true, indices, cuda_stream), "Get indices by keys failed.");
+
+  Reserve(size_ + key_num);
+  // 2. Insert the specified default value into map for missing keys.
+  InsertDefaultValue<<>>(
+    value_dim_, key_num, indices, elements_per_block_, default_value, idle_flags_ptr_, blocks_ptr_);
+
+  // 3. Get all values by indices in blocks.
+  size_t total_size = value_dim_ * key_num;
+  GetValues<<>>(value_dim_, total_size, indices,
+                elements_per_block_, blocks_ptr_, outputs);
+  std::allocator_traits::deallocate(index_alloc_, indices, key_num);
+  return true;
+}
+
+template
+bool GPUHashTable::Insert(const Key *key, size_t key_num, const Value *value, void *stream) {
+  MS_ERROR_IF_NULL(key);
+  MS_ERROR_IF_NULL(value);
+  MS_ERROR_IF_NULL(stream);
+  int *indices = std::allocator_traits::allocate(index_alloc_, key_num);
+  MS_ERROR_IF_NULL(indices);
+
+  // 1. Get all indices in blocks according to the key.
+  auto cuda_stream = reinterpret_cast(stream);
+  RETURN_IF_FALSE_WITH_LOG(GetIndicesByKeys(key, key_num, true, indices, cuda_stream), "Get indices by keys failed.");
+
+  Reserve(size_ + key_num);
+  // 2. Insert values into map by indices in blocks.
+  size_t total_insert_size = value_dim_ * key_num;
+  InsertValues<<>>(
+    value_dim_, total_insert_size, indices, value, elements_per_block_, idle_flags_ptr_, blocks_ptr_);
+
+  std::allocator_traits::deallocate(index_alloc_, indices, key_num);
+  return true;
+}
+
+template
+bool GPUHashTable::Erase(const Key *keys, size_t key_num, void *stream) {
+  MS_ERROR_IF_NULL(keys);
+  MS_ERROR_IF_NULL(stream);
+  MS_ERROR_IF_NULL(cuda_dynamic_map_);
+  auto &dynamic_map = cuda_dynamic_map_->dynamic_map_;
+  dynamic_map.erase(keys, keys + key_num, reinterpret_cast(stream));
+  return true;
+}
+
+template
+bool GPUHashTable::GetKeysAndValues(Key *keys, Value *values, void *stream) {
+  MS_ERROR_IF_NULL(keys);
+  MS_ERROR_IF_NULL(values);
+  MS_ERROR_IF_NULL(cuda_dynamic_map_);
+  auto &dynamic_map = cuda_dynamic_map_->dynamic_map_;
+  int *indices = std::allocator_traits::allocate(index_alloc_, size_);
+  MS_ERROR_IF_NULL(indices);
+
+  // 1. Export all keys and indices from dynamic map.
+  auto cuda_stream = reinterpret_cast(stream);
+  RETURN_IF_FALSE_WITH_LOG(dynamic_map.get_keys_values(keys, indices, cuda_stream),
+                           "Get keys and values from cuda dynamic map failed.");
+
+  // 2. Get all values by indices in blocks.
+  size_t total_size = value_dim_ * size_;
+  GetValues<<>>(value_dim_, total_size, indices,
+                elements_per_block_, blocks_ptr_, values);
+  std::allocator_traits::deallocate(index_alloc_, indices, size_);
+  return true;
+}
+
+template
+bool GPUHashTable::Import(const DataLenPair &input_data) {
+  // 1. Store the input tensor data until kImportTensorNum(3) input tensors have been received.
+  // The input data is actually imported into the hash table only after all kImportTensorNum(3) tensors arrive.
+  static std::vector input_data_list;
+  if (input_data_list.size() < kImportTensorNum) {
+    input_data_list.emplace_back(input_data);
+  }
+  if (input_data_list.size() != kImportTensorNum) {
+    return true;
+  }
+
+  const auto &input_keys = input_data_list[0];
+  const auto &input_values = input_data_list[1];
+  void *host_keys = input_keys.first;
+  void *host_values = input_values.first;
+  MS_ERROR_IF_NULL(host_keys);
+  MS_ERROR_IF_NULL(host_values);
+
+  size_t keys_len = input_keys.second;
+  size_t values_len = input_values.second;
+
+  // 2. Allocate temp buffers for the keys and values.
+  char *device_keys = std::allocator_traits::allocate(char_alloc_, keys_len);
+  char *device_values = std::allocator_traits::allocate(char_alloc_, values_len);
+  MS_ERROR_IF_NULL(device_keys);
+  MS_ERROR_IF_NULL(device_values);
+
+  // 3. Copy input keys and values to device.
+  cudaStream_t stream = reinterpret_cast(GPUDeviceManager::GetInstance().default_stream());
+  CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaMemcpyAsync(device_keys, host_keys, keys_len, cudaMemcpyHostToDevice, stream),
+                                   "cudaMemcpyAsync");
+  CHECK_CUDA_RET_WITH_RETURN_FALSE(
+    cudaMemcpyAsync(device_values, host_values, values_len, cudaMemcpyHostToDevice, stream), "cudaMemcpyAsync");
+
+  // 4. Insert the input keys and values into the hash table.
+  RETURN_IF_FALSE_WITH_LOG(Insert(reinterpret_cast(device_keys), keys_len / sizeof(Key),
+                                  reinterpret_cast(device_values), stream),
+                           "Insert keys and values failed.");
+  CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
+
+  // 5. Free the temp buffers for the keys and values.
+  std::allocator_traits::deallocate(char_alloc_, device_keys, keys_len);
+  std::allocator_traits::deallocate(char_alloc_, device_values, values_len);
+  input_data_list.clear();
+
+  return true;
+}
+
+template
+bool GPUHashTable::Export(const DataLenPair &keys, const DataLenPair &values,
+                          const DataLenPair &status) {
+  MS_ERROR_IF_NULL(keys.first);
+  MS_ERROR_IF_NULL(values.first);
+  MS_ERROR_IF_NULL(status.first);
+
+  size_t keys_len = size_ * sizeof(Key);
+  size_t values_len = size_ * value_dim_ * sizeof(Value);
+  size_t status_len = size_ * sizeof(Status);
+  // 1. Check the length of the output tensors.
+  ASSERT_EQUAL(
+    keys_len, keys.second,
+    std::string("Need keys len[") + std::to_string(keys_len) + "], but got:[" + std::to_string(keys.second) + "].");
+  ASSERT_EQUAL(values_len, values.second,
+               std::string("Need values len[") + std::to_string(values_len) + "], but got:[" +
+                 std::to_string(values.second) + "].");
+  ASSERT_EQUAL(status_len, status.second,
+               std::string("Need status len[") + std::to_string(status_len) + "], but got:[" +
+                 std::to_string(status.second) + "].");
+
+  // 2. Allocate temp buffers for the keys, values and status.
+  char *device_keys = std::allocator_traits::allocate(char_alloc_, keys_len);
+  char *device_values = std::allocator_traits::allocate(char_alloc_, values_len);
+  char *device_status = std::allocator_traits::allocate(char_alloc_, status_len);
+  MS_ERROR_IF_NULL(device_keys);
+  MS_ERROR_IF_NULL(device_values);
+  MS_ERROR_IF_NULL(device_status);
+
+  // 3. Export all keys and values and store them into the temp buffers.
+  cudaStream_t stream = reinterpret_cast(GPUDeviceManager::GetInstance().default_stream());
+  RETURN_IF_FALSE_WITH_LOG(
+    GetKeysAndValues(reinterpret_cast(device_keys), reinterpret_cast(device_values), stream),
+    "Get keys and values failed.");
+
+  // Note: Get all status.
+  // 4. Copy the keys, values and status from the device temp buffers to the host.
+  CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaMemcpyAsync(keys.first, device_keys, keys_len, cudaMemcpyDeviceToHost, stream),
+                                   "cudaMemcpyAsync");
+  CHECK_CUDA_RET_WITH_RETURN_FALSE(
+    cudaMemcpyAsync(values.first, device_values, values_len, cudaMemcpyDeviceToHost, stream), "cudaMemcpyAsync");
+  CHECK_CUDA_RET_WITH_RETURN_FALSE(
+    cudaMemcpyAsync(status.first, device_status, status_len, cudaMemcpyDeviceToHost, stream), "cudaMemcpyAsync");
+  CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaStreamSynchronize(stream), "cudaStreamSynchronize default cuda stream");
+
+  // 5. Free the temp buffers for the keys, values and status.
+  std::allocator_traits::deallocate(char_alloc_, device_keys, keys_len);
+  std::allocator_traits::deallocate(char_alloc_, device_values, values_len);
+  std::allocator_traits::deallocate(char_alloc_, device_status, status_len);
+
+  return true;
+}
+
+template
+bool GPUHashTable::GetIndicesByKeys(const Key *key, size_t key_num, bool insert_miss_key,
+                                    int32_t *indices, cudaStream_t stream) {
+  MS_ERROR_IF_NULL(key);
+  MS_ERROR_IF_NULL(indices);
+  MS_ERROR_IF_NULL(stream);
+  MS_ERROR_IF_NULL(cuda_dynamic_map_);
+  auto &dynamic_map = cuda_dynamic_map_->dynamic_map_;
+  dynamic_map.reserve(key_num + dynamic_map.get_size());
+  size_t submap_idx = 0;
+  uint32_t device_id = GET_CTX_DEVICE_ID;
+  size_t remaining_key_num = key_num;
+  MS_ERROR_IF_NULL(insert_success_number_);
+
+  while (remaining_key_num > 0) {
+    auto &submap_ptr = dynamic_map.get_submaps()[submap_idx];
+    MS_ERROR_IF_NULL(submap_ptr);
+    // 1. Get the remaining capacity of the current submap; the max load factor and min insert size need to be considered.
+    size_t submap_remaining_capacity =
+      submap_ptr->get_capacity() * dynamic_map.get_max_load_factor() - submap_ptr->get_size();
+    if (submap_remaining_capacity < dynamic_map.get_min_insert_size()) {
+      submap_idx++;
+      continue;
+    }
+
+    *(insert_success_number_) = 0;
+    CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaMemPrefetchAsync(insert_success_number_, sizeof(CudaAtomicSize), device_id),
+                                     "cudaMemPrefetchAsync");
+
+    // 2. Get the number of keys that can be handled by the current submap.
+    size_t item_num = std::min(submap_remaining_capacity, remaining_key_num);
+    const uint32_t tile_size = kTileSize;
+    const uint32_t block_size = kBlockSize;
+    if (IntToUint(GET_THREADS) < block_size) {
+      MS_LOG(ERROR) << "The max thread number per block of this GPU is less than: " << block_size;
+    }
+    const uint32_t grid_size = IntToUint(CUDA_BLOCKS_CAL(device_id, tile_size * item_num, block_size));
+
+    // 3. Transform all keys into indices in blocks. If a key already exists in the map, just return its index,
+    // otherwise find a valid position in a block.
+    LookupIndices::mutable_view_type,
+                  typename CucoDynamicMap::view_type><<>>(
+      key, item_num, insert_miss_key, dynamic_map.get_submaps().size(), submap_idx, idel_slot_, idle_index_,
+      dynamic_map.get_submap_mutable_views().data().get(), dynamic_map.get_submap_views().data().get(),
+      insert_success_number_, current_index_, indices);
+
+    // 4. Update the size of the dynamic map and the static submap.
+    CHECK_CUDA_RET_WITH_RETURN_FALSE(cudaStreamSynchronize(stream), "cudaStreamSynchronize");
+    size_t insert_success_num = insert_success_number_->load(cuda::std::memory_order_relaxed);
+    dynamic_map.update_submap_size(submap_idx, submap_ptr->get_size() + insert_success_num);
+    dynamic_map.update_size(dynamic_map.get_size() + insert_success_num);
+    size_ += insert_success_num;
+
+    indices += item_num;
+    key += item_num;
+    remaining_key_num -= item_num;
+    submap_idx++;
+  }
+  return true;
+}
+
+template
+bool GPUHashTable::InsertDefaultValueByInitializer(size_t key_num,
+                                                   const std::string &initializer,
+                                                   const int *indices, cudaStream_t stream) {
+  MS_ERROR_IF_NULL(indices);
+  MS_ERROR_IF_NULL(stream);
+  if (initializer == kNormalDistribution) {
+    // Normal distribution.
+    RETURN_IF_FALSE_WITH_LOG(InitNormalDistRandomGenerator(stream),
+                             "Initialize normal distribution random generator failed.");
+    Value mean = static_cast(0);
+    Value stddev = static_cast(0.01);
+
+    InsertNormalDistRandomValue<<>>(
+      value_dim_, key_num, indices, elements_per_block_, mean, stddev, random_gen_state_, idle_flags_ptr_, blocks_ptr_);
+  } else if (initializer == kOnesDistribution) {
+    // Ones distribution.
+    InsertDefaultValue<<>>(
+      value_dim_, key_num, indices, elements_per_block_, static_cast(1.0), idle_flags_ptr_, blocks_ptr_);
+  } else if (initializer == kZerosDistribution) {
+    // Zeros distribution.
+    InsertDefaultValue<<>>(
+      value_dim_, key_num, indices, elements_per_block_, static_cast(0), idle_flags_ptr_, blocks_ptr_);
+  } else {
+    MS_LOG(ERROR) << "Unsupported initializer: " << initializer;
+    return false;
+  }
+
+  return true;
+}
+
+template
+bool GPUHashTable::InitNormalDistRandomGenerator(cudaStream_t stream) {
+  MS_ERROR_IF_NULL(stream);
+  if (random_gen_init_.load()) {
+    return true;
+  }
+
+  // 1. Allocate memory for all random generator states.
+  auto total_random_state_num = random_gen_threads_per_block_ * random_gen_block_count_;
+  random_gen_state_ = reinterpret_cast(std::allocator_traits::allocate(
+    char_alloc_, IntToSize(total_random_state_num) * sizeof(curandStatePhilox4_32_10_t)));
+  MS_ERROR_IF_NULL(random_gen_state_);
+
+  // 2. Initialize normal distribution random generator states.
+  std::random_device rd;
+  uint32_t seed = rd();
+  InitNormalDisRandomGen<<>>(seed,
+                             random_gen_state_);
+
+  random_gen_init_ = true;
+  return true;
+}
+
+template class GPUHashTable;
+template class GPUHashTable;
+}  // namespace gpu
+}  // namespace device
+}  // namespace mindspore
+#endif
diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table.h b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table.h
index 175c4629b6a..3abb6a8301f 100644
--- a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table.h
+++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table.h
@@ -16,36 +16,62 @@
 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
+#if CUDA_VERSION > 11000
+#include
+#include
+
 #include
 #include
+#include
+#include
+
 #include "runtime/device/hash_table.h"
+#include "plugin/device/gpu/hal/device/gpu_allocator.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
 
 namespace mindspore {
 namespace device {
 namespace gpu {
-// A hash table base on GPU.
+constexpr static int kMaxThreadsPerBlockRandomGen = 2048;
+constexpr static size_t kInitialCapacity = 100000;
+constexpr static uint32_t kTileSize = 4;
+constexpr static uint32_t kBlockSize = 128;
+constexpr static size_t kImportTensorNum = 3;
+constexpr static char kNormalDistribution[] = "normal";
+constexpr static char kZerosDistribution[] = "zeros";
+constexpr static char kOnesDistribution[] = "ones";
+
 template
+class CudaDynamicMap;
+// A hash table based on GPU.
+template >
 class GPUHashTable : public HashTable {
  public:
-  explicit GPUHashTable(int32_t value_dim, const Allocator &alloc = Allocator())
-      : value_dim_(value_dim), alloc_(alloc) {}
-  ~GPUHashTable() = default;
+  using Status = typename HashTable::Status;
 
-  using key_allocator_type = typename std::allocator_traits::rebind_alloc;
-  using value_allocator_type = typename std::allocator_traits::rebind_alloc;
-  using index_allocator_type = typename std::allocator_traits::rebind_alloc;
-  using char_allocator_type = typename std::allocator_traits::rebind_alloc;
+  GPUHashTable(int32_t value_dim, const std::string &initializer, const Allocator &alloc = Allocator());
+  GPUHashTable(int32_t value_dim, const Value &default_value, const Allocator &alloc = Allocator());
+  ~GPUHashTable();
+
+  // The Allocator type used to allocate gpu memory for the 'Key' type.
+  using KeyAllocatorType = typename std::allocator_traits::template rebind_alloc;
+  // The Allocator type used to allocate gpu memory for the 'Value' type.
+  using ValueAllocatorType = typename std::allocator_traits::template rebind_alloc;
+  // The Allocator type used to allocate gpu memory for the 'index' type(int).
+  using IndexAllocatorType = typename std::allocator_traits::template rebind_alloc;
+  // The general Allocator type used to allocate gpu memory.
+  using CharAllocatorType = typename std::allocator_traits::template rebind_alloc;
 
   // Find elements with specific keys, if a key does not exist, initialize the value for the key based on the
   // initialzer and insert the key-value pair into map. The initializer can be 'normal', 'zero' or 'one', and also
-  // could be a specific Value type scalar.
+  // could be a specific 'Value' type scalar.
   bool Find(const Key *keys, size_t key_num, Value *outputs, void *stream) override;
 
   // Insert elements with specific keys. If key exists, update the value of the key.
-  bool Insert(const Key *key, size_t key_num, const Value *value) override { return true; }
+  bool Insert(const Key *keys, size_t key_num, const Value *value, void *stream) override;
 
   // Erase elements with specific keys.
-  bool Erase(const Key *key, size_t key_num) override { return true; }
+  bool Erase(const Key *keys, size_t key_num, void *stream) override;
 
   // Reserves space for at least the specified number of elements.
   bool Reserve(size_t count) override { return true; }
@@ -70,18 +96,22 @@ class GPUHashTable : public HashTable {
 
  private:
   // Find elements with specific keys, if the key does not exist, initialize the value for the key based on the
-  // initialzer and insert the key-value pair into map.The initializer can be 'normal', 'zero' or 'one'.
-  bool Find(const Key *key, size_t key_num, const std::string &initializer, Value *outputs) { return true; }
+  // initializer and insert the key-value pair into map. The initializer can be 'normal', 'zeros' or 'ones'.
+  bool Find(const Key *keys, size_t key_num, const std::string &initializer, Value *outputs, void *stream);
 
   // Find elements with specific keys, if the key does not exist, initialize the value for the key by 'default_value'
   // and insert the key-value pair into map.
-  bool Find(const Key *key, size_t key_num, const Value &default_value, Value *outputs) { return true; }
+  bool Find(const Key *keys, size_t key_num, const Value &default_value, Value *outputs, void *stream);
 
   // Get all indices in blocks according to the key.
-  bool GetIndicesByKeys(const Key *key, size_t key_num, bool insert_miss_key, int32_t *indices) const;
+  bool GetIndicesByKeys(const Key *key, size_t key_num, bool insert_miss_key, int32_t *indices, cudaStream_t stream);
 
-  // Get value in blocks according to the indices.
-  bool GetValues(const int32_t *indices, Value *value) const;
+  // Insert default value according to the initializer, the initializer can be 'normal', 'zeros' or 'ones'.
+  bool InsertDefaultValueByInitializer(size_t key_num, const std::string &initializer, const int *indices,
+                                       cudaStream_t stream);
+
+  // Initialize normal distribution random generator states on GPU.
+  bool InitNormalDistRandomGenerator(cudaStream_t stream);
 
   // Record all block memory that contain all values.
   std::vector blocks_;
@@ -93,19 +123,55 @@ class GPUHashTable : public HashTable {
   // Record whether every slot is occupied or not for all block, the buffer is on device memory.
   bool **idle_flags_ptr_{nullptr};
 
+  // The sentinel records the latest used location in blocks.
+  cuda::atomic *current_index_{nullptr};
+
+  // The counter records the number of idle slots; if the contents of a slot are erased, the slot is marked with the
+  // idle status.
+  cuda::atomic *idle_index_{nullptr};
+  // The buffer keeps all the idle slot positions (offset index to the beginning of a block).
+  int32_t *idel_slot_{nullptr};
+
   // The value dimension for each key.
   size_t value_dim_;
 
-  // The allocator used to alloacte gpu memory.
-  Allocator alloc_;
+  // The initializer used to initialize the values for missing keys, the initializer can be 'normal', 'zeros' or 'ones'.
+  std::string initializer_;
+  // The default value used to initialize the values for missing keys.
+  Value default_value_;
+
+  // The allocator used to allocate gpu memory for indices.
+  IndexAllocatorType index_alloc_;
+  // The common allocator used to allocate gpu memory.
+  CharAllocatorType char_alloc_;
+
+  // The dynamic hash table on GPU.
+  std::unique_ptr> cuda_dynamic_map_;
 
   // Record the number of elements in the map.
   size_t size_{0};
 
   // Record the number of elements that can be held in currently allocated storage
   size_t capacity_{0};
+
+  // The number of elements in one block.
+  size_t elements_per_block_{kInitialCapacity};
+
+  // Record the number of successfully inserted keys.
+  cuda::atomic *insert_success_number_{nullptr};
+
+  // The flag records whether the normal distribution random generator state is initialized.
+  std::atomic_bool random_gen_init_{false};
+  // The random generator states used to generate normal distribution random values.
+  curandStatePhilox4_32_10_t *random_gen_state_{nullptr};
+
+  // The block size used to launch cuda kernel for inserting normal distribution random values.
+  int random_gen_threads_per_block_{GET_THREADS};
+  // The grid size used to launch cuda kernel for inserting normal distribution random values.
+  int random_gen_block_count_{(kMaxThreadsPerBlockRandomGen - 1) / random_gen_threads_per_block_ + 1};
 };
 }  // namespace gpu
 }  // namespace device
 }  // namespace mindspore
+#endif
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_H_
diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table_kernel.cuh b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table_kernel.cuh
index 0f451633520..75c05b1b781 100644
--- a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table_kernel.cuh
+++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_hash_table_kernel.cuh
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_KERNEL_CUH_
 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_KERNEL_CUH_
+#if CUDA_VERSION > 11000
 
 #include
 #include
@@ -219,4 +220,5 @@ __global__ void InsertValues(size_t value_dim, size_t total_insert_size, const i
 }  // namespace gpu
 }  // namespace device
 }  // namespace mindspore
+#endif
 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_HAL_DEVICE_GPU_HASH_TABLE_KERNEL_CUH_
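
Below is a minimal usage sketch of the GPU hash table API declared in gpu_hash_table.h. It is not part of the patch: the <int, float> instantiation, the buffer names, and the direct use of the default stream from GPUDeviceManager are assumptions for illustration only; real callers sit inside the MindSpore runtime.

// Hypothetical usage sketch (not part of the patch). Assumes a GPUHashTable<int, float>
// instantiation is available and that device_keys/device_values are valid device buffers
// allocated elsewhere; error handling and allocator customization are omitted.
#include "plugin/device/gpu/hal/device/gpu_hash_table.h"
#include "plugin/device/gpu/hal/device/gpu_device_manager.h"

namespace mindspore {
namespace device {
namespace gpu {
void GpuHashTableUsageSketch(const int *device_keys, size_t key_num, float *device_values) {
  constexpr int32_t kValueDim = 8;
  // Values for missing keys are initialized from a normal distribution;
  // "zeros" and "ones" initializers, or a scalar default value, are also supported.
  GPUHashTable<int, float> table(kValueDim, "normal");

  // The default stream is assumed here to be usable as the 'stream' argument of the API.
  void *stream = GPUDeviceManager::GetInstance().default_stream();

  // Look up key_num keys; kValueDim values per key are written to device_values.
  table.Find(device_keys, key_num, device_values, stream);
  // Overwrite the values of the same keys, then erase them again.
  table.Insert(device_keys, key_num, device_values, stream);
  table.Erase(device_keys, key_num, stream);
}
}  // namespace gpu
}  // namespace device
}  // namespace mindspore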