ps cache code review

This commit is contained in:
limingqi107 2020-12-30 20:56:51 +08:00
parent f4ee467d77
commit 798cca42e5
5 changed files with 25 additions and 20 deletions

View File

@@ -206,7 +206,7 @@ bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
}
bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) {
size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) {
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_out_value_addr);
MS_ERROR_IF_NULL(swap_out_index_addr);
@@ -217,7 +217,7 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
std::vector<std::vector<size_t>> output_shape;
std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32};
std::vector<TypeId> output_type = {TypeId::kNumberTypeFloat32};
input_shape.push_back({hash_table_size, embedding_size});
input_shape.push_back({cache_vocab_size, embedding_size});
input_shape.push_back({swap_out_size});
input_shape.push_back({1});
output_shape.push_back({swap_out_size, embedding_size});
@@ -229,7 +229,8 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
AddressPtrList kernel_outputs = {
std::make_shared<Address>(swap_out_value_addr, swap_out_size * embedding_size * sizeof(float))};
AddressPtrList kernel_workspaces;
kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
kernel_inputs.push_back(
std::make_shared<Address>(hash_table_addr, cache_vocab_size * embedding_size * sizeof(float)));
kernel_inputs.push_back(std::make_shared<Address>(swap_out_index_addr, swap_out_size * sizeof(int)));
kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
@@ -241,7 +242,7 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
}
bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) {
size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) {
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_in_value_addr);
MS_ERROR_IF_NULL(swap_in_index_addr);
@@ -253,7 +254,7 @@ bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr,
std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32,
TypeId::kNumberTypeInt32};
std::vector<TypeId> output_type = {TypeId::kNumberTypeInt32};
input_shape.push_back({hash_table_size, embedding_size});
input_shape.push_back({cache_vocab_size, embedding_size});
input_shape.push_back({swap_in_size});
input_shape.push_back({swap_in_size, embedding_size});
input_shape.push_back({1});
@@ -265,7 +266,8 @@ bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr,
AddressPtrList kernel_inputs;
AddressPtrList kernel_outputs;
AddressPtrList kernel_workspaces;
kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
kernel_inputs.push_back(
std::make_shared<Address>(hash_table_addr, cache_vocab_size * embedding_size * sizeof(float)));
kernel_inputs.push_back(std::make_shared<Address>(swap_in_index_addr, swap_in_size * sizeof(int)));
kernel_inputs.push_back(std::make_shared<Address>(swap_in_value_addr, swap_in_size * embedding_size * sizeof(float)));
kernel_inputs.push_back(std::make_shared<Address>(cache_vocab_size_addr_, sizeof(int)));

View File

@@ -57,9 +57,9 @@ class AscendPsCache : public PsCacheBasic {
bool SynchronizeStream() override;
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size,
size_t embedding_size, size_t swap_out_size) override;
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size,
size_t embedding_size, size_t swap_in_size) override;
private:

View File

@@ -35,9 +35,9 @@ class GPUPsCache : public PsCacheBasic {
bool SynchronizeStream() override;
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size,
size_t embedding_size, size_t swap_out_size) override;
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size,
size_t embedding_size, size_t swap_in_size) override;
private:

View File

@@ -41,9 +41,9 @@ class PsCacheBasic {
virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0;
size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) = 0;
virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0;
size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) = 0;
protected:
void *stream_;

View File

@@ -660,7 +660,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
}
auto embedding_size = hash_info.embedding_size;
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto cache_vocab_size = hash_info.cache_vocab_size;
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr,
@@ -673,7 +673,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
swap_indices_size * sizeof(int)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_indices_size));
cache_vocab_size, embedding_size, swap_indices_size));
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
return true;
}
@@ -689,7 +689,7 @@ bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
return true;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto cache_vocab_size = hash_info.cache_vocab_size;
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto embedding_size = hash_info.embedding_size;
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
@@ -698,7 +698,7 @@ bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
swap_indices_size * sizeof(int)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_indices_size));
cache_vocab_size, embedding_size, swap_indices_size));
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
swap_indices_size * embedding_size * sizeof(float)));
@@ -770,14 +770,14 @@ bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float>
return true;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto cache_vocab_size = hash_info.cache_vocab_size;
auto embedding_size = hash_info.embedding_size;
swap_out_data->resize(swap_out_index_size * embedding_size);
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_out_index_size));
cache_vocab_size, embedding_size, swap_out_index_size));
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_,
swap_out_index_size * embedding_size * sizeof(float)));
@@ -796,7 +796,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
return true;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto cache_vocab_size = hash_info.cache_vocab_size;
auto embedding_size = hash_info.embedding_size;
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
::ps::SArray<int> lengths{swap_in_ids_size};
@@ -817,7 +817,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
swap_in_index, swap_in_ids_size * sizeof(int)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_in_ids_size));
cache_vocab_size, embedding_size, swap_in_ids_size));
return true;
}
@@ -846,6 +846,9 @@ void PsCacheManager::SyncEmbeddingTable() {
if (finish_embedding_table_sync_) {
return;
}
if (!initialized_ps_cache_) {
return;
}
if (!SyncHostEmbeddingTable()) {
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
}