forked from mindspore-Ecosystem/mindspore
ps cache code review
This commit is contained in:
parent
f4ee467d77
commit
798cca42e5
|
@ -206,7 +206,7 @@ bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
|
bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
|
||||||
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) {
|
size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) {
|
||||||
MS_ERROR_IF_NULL(hash_table_addr);
|
MS_ERROR_IF_NULL(hash_table_addr);
|
||||||
MS_ERROR_IF_NULL(swap_out_value_addr);
|
MS_ERROR_IF_NULL(swap_out_value_addr);
|
||||||
MS_ERROR_IF_NULL(swap_out_index_addr);
|
MS_ERROR_IF_NULL(swap_out_index_addr);
|
||||||
|
@ -217,7 +217,7 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
|
||||||
std::vector<std::vector<size_t>> output_shape;
|
std::vector<std::vector<size_t>> output_shape;
|
||||||
std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32};
|
std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32};
|
||||||
std::vector<TypeId> output_type = {TypeId::kNumberTypeFloat32};
|
std::vector<TypeId> output_type = {TypeId::kNumberTypeFloat32};
|
||||||
input_shape.push_back({hash_table_size, embedding_size});
|
input_shape.push_back({cache_vocab_size, embedding_size});
|
||||||
input_shape.push_back({swap_out_size});
|
input_shape.push_back({swap_out_size});
|
||||||
input_shape.push_back({1});
|
input_shape.push_back({1});
|
||||||
output_shape.push_back({swap_out_size, embedding_size});
|
output_shape.push_back({swap_out_size, embedding_size});
|
||||||
|
@ -229,7 +229,8 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
|
||||||
AddressPtrList kernel_outputs = {
|
AddressPtrList kernel_outputs = {
|
||||||
std::make_shared<Address>(swap_out_value_addr, swap_out_size * embedding_size * sizeof(float))};
|
std::make_shared<Address>(swap_out_value_addr, swap_out_size * embedding_size * sizeof(float))};
|
||||||
AddressPtrList kernel_workspaces;
|
AddressPtrList kernel_workspaces;
|
||||||
kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
|
kernel_inputs.push_back(
|
||||||
|
std::make_shared<Address>(hash_table_addr, cache_vocab_size * embedding_size * sizeof(float)));
|
||||||
kernel_inputs.push_back(std::make_shared<Address>(swap_out_index_addr, swap_out_size * sizeof(int)));
|
kernel_inputs.push_back(std::make_shared<Address>(swap_out_index_addr, swap_out_size * sizeof(int)));
|
||||||
kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
|
kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
|
||||||
auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||||
|
@ -241,7 +242,7 @@ bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
|
bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
|
||||||
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) {
|
size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) {
|
||||||
MS_ERROR_IF_NULL(hash_table_addr);
|
MS_ERROR_IF_NULL(hash_table_addr);
|
||||||
MS_ERROR_IF_NULL(swap_in_value_addr);
|
MS_ERROR_IF_NULL(swap_in_value_addr);
|
||||||
MS_ERROR_IF_NULL(swap_in_index_addr);
|
MS_ERROR_IF_NULL(swap_in_index_addr);
|
||||||
|
@ -253,7 +254,7 @@ bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr,
|
||||||
std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32,
|
std::vector<TypeId> input_type = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32,
|
||||||
TypeId::kNumberTypeInt32};
|
TypeId::kNumberTypeInt32};
|
||||||
std::vector<TypeId> output_type = {TypeId::kNumberTypeInt32};
|
std::vector<TypeId> output_type = {TypeId::kNumberTypeInt32};
|
||||||
input_shape.push_back({hash_table_size, embedding_size});
|
input_shape.push_back({cache_vocab_size, embedding_size});
|
||||||
input_shape.push_back({swap_in_size});
|
input_shape.push_back({swap_in_size});
|
||||||
input_shape.push_back({swap_in_size, embedding_size});
|
input_shape.push_back({swap_in_size, embedding_size});
|
||||||
input_shape.push_back({1});
|
input_shape.push_back({1});
|
||||||
|
@ -265,7 +266,8 @@ bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr,
|
||||||
AddressPtrList kernel_inputs;
|
AddressPtrList kernel_inputs;
|
||||||
AddressPtrList kernel_outputs;
|
AddressPtrList kernel_outputs;
|
||||||
AddressPtrList kernel_workspaces;
|
AddressPtrList kernel_workspaces;
|
||||||
kernel_inputs.push_back(std::make_shared<Address>(hash_table_addr, hash_table_size * embedding_size * sizeof(float)));
|
kernel_inputs.push_back(
|
||||||
|
std::make_shared<Address>(hash_table_addr, cache_vocab_size * embedding_size * sizeof(float)));
|
||||||
kernel_inputs.push_back(std::make_shared<Address>(swap_in_index_addr, swap_in_size * sizeof(int)));
|
kernel_inputs.push_back(std::make_shared<Address>(swap_in_index_addr, swap_in_size * sizeof(int)));
|
||||||
kernel_inputs.push_back(std::make_shared<Address>(swap_in_value_addr, swap_in_size * embedding_size * sizeof(float)));
|
kernel_inputs.push_back(std::make_shared<Address>(swap_in_value_addr, swap_in_size * embedding_size * sizeof(float)));
|
||||||
kernel_inputs.push_back(std::make_shared<Address>(cache_vocab_size_addr_, sizeof(int)));
|
kernel_inputs.push_back(std::make_shared<Address>(cache_vocab_size_addr_, sizeof(int)));
|
||||||
|
|
|
@ -57,9 +57,9 @@ class AscendPsCache : public PsCacheBasic {
|
||||||
bool SynchronizeStream() override;
|
bool SynchronizeStream() override;
|
||||||
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
|
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
|
||||||
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
|
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
|
||||||
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
|
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size,
|
||||||
size_t embedding_size, size_t swap_out_size) override;
|
size_t embedding_size, size_t swap_out_size) override;
|
||||||
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
|
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size,
|
||||||
size_t embedding_size, size_t swap_in_size) override;
|
size_t embedding_size, size_t swap_in_size) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -35,9 +35,9 @@ class GPUPsCache : public PsCacheBasic {
|
||||||
bool SynchronizeStream() override;
|
bool SynchronizeStream() override;
|
||||||
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
|
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
|
||||||
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
|
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
|
||||||
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
|
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size,
|
||||||
size_t embedding_size, size_t swap_out_size) override;
|
size_t embedding_size, size_t swap_out_size) override;
|
||||||
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
|
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size,
|
||||||
size_t embedding_size, size_t swap_in_size) override;
|
size_t embedding_size, size_t swap_in_size) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -41,9 +41,9 @@ class PsCacheBasic {
|
||||||
virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
|
virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
|
||||||
virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
|
virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
|
||||||
virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
|
virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
|
||||||
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0;
|
size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) = 0;
|
||||||
virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
|
virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
|
||||||
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0;
|
size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void *stream_;
|
void *stream_;
|
||||||
|
|
|
@ -660,7 +660,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
|
||||||
}
|
}
|
||||||
auto embedding_size = hash_info.embedding_size;
|
auto embedding_size = hash_info.embedding_size;
|
||||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||||
auto hash_table_size = hash_info.device_address.size;
|
auto cache_vocab_size = hash_info.cache_vocab_size;
|
||||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||||
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
|
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
|
||||||
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr,
|
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr,
|
||||||
|
@ -673,7 +673,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
|
||||||
swap_indices_size * sizeof(int)));
|
swap_indices_size * sizeof(int)));
|
||||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
|
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
|
||||||
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
||||||
hash_table_size, embedding_size, swap_indices_size));
|
cache_vocab_size, embedding_size, swap_indices_size));
|
||||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
|
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -689,7 +689,7 @@ bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||||
auto hash_table_size = hash_info.device_address.size;
|
auto cache_vocab_size = hash_info.cache_vocab_size;
|
||||||
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
|
||||||
auto embedding_size = hash_info.embedding_size;
|
auto embedding_size = hash_info.embedding_size;
|
||||||
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
|
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
|
||||||
|
@ -698,7 +698,7 @@ bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
|
||||||
swap_indices_size * sizeof(int)));
|
swap_indices_size * sizeof(int)));
|
||||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
|
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
|
||||||
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
||||||
hash_table_size, embedding_size, swap_indices_size));
|
cache_vocab_size, embedding_size, swap_indices_size));
|
||||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
|
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
|
||||||
swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
|
swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
|
||||||
swap_indices_size * embedding_size * sizeof(float)));
|
swap_indices_size * embedding_size * sizeof(float)));
|
||||||
|
@ -770,14 +770,14 @@ bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float>
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||||
auto hash_table_size = hash_info.device_address.size;
|
auto cache_vocab_size = hash_info.cache_vocab_size;
|
||||||
auto embedding_size = hash_info.embedding_size;
|
auto embedding_size = hash_info.embedding_size;
|
||||||
swap_out_data->resize(swap_out_index_size * embedding_size);
|
swap_out_data->resize(swap_out_index_size * embedding_size);
|
||||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
|
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
|
||||||
embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int)));
|
embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int)));
|
||||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
|
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
|
||||||
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
||||||
hash_table_size, embedding_size, swap_out_index_size));
|
cache_vocab_size, embedding_size, swap_out_index_size));
|
||||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
|
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
|
||||||
swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_,
|
swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_,
|
||||||
swap_out_index_size * embedding_size * sizeof(float)));
|
swap_out_index_size * embedding_size * sizeof(float)));
|
||||||
|
@ -796,7 +796,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
|
||||||
auto hash_table_size = hash_info.device_address.size;
|
auto cache_vocab_size = hash_info.cache_vocab_size;
|
||||||
auto embedding_size = hash_info.embedding_size;
|
auto embedding_size = hash_info.embedding_size;
|
||||||
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
|
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
|
||||||
::ps::SArray<int> lengths{swap_in_ids_size};
|
::ps::SArray<int> lengths{swap_in_ids_size};
|
||||||
|
@ -817,7 +817,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
|
||||||
swap_in_index, swap_in_ids_size * sizeof(int)));
|
swap_in_index, swap_in_ids_size * sizeof(int)));
|
||||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
|
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
|
||||||
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
||||||
hash_table_size, embedding_size, swap_in_ids_size));
|
cache_vocab_size, embedding_size, swap_in_ids_size));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -846,6 +846,9 @@ void PsCacheManager::SyncEmbeddingTable() {
|
||||||
if (finish_embedding_table_sync_) {
|
if (finish_embedding_table_sync_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (!initialized_ps_cache_) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (!SyncHostEmbeddingTable()) {
|
if (!SyncHostEmbeddingTable()) {
|
||||||
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
|
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue