forked from mindspore-Ecosystem/mindspore
!10649 fix ascend ps cache loss invaild
From: @limingqi107 Reviewed-by: @cristoval,@chujinjin Signed-off-by: @cristoval
This commit is contained in:
commit
6fa83590c1
|
@ -131,15 +131,18 @@ void *AscendPsCache::MallocMemory(size_t size) {
|
|||
return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size);
|
||||
}
|
||||
|
||||
bool AscendPsCache::MallocConstantMemory(size_t constant_value) {
|
||||
bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) {
|
||||
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
|
||||
MS_ERROR_IF_NULL(offset_addr_);
|
||||
rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
|
||||
cache_vocab_size_addr_ =
|
||||
reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
|
||||
MS_ERROR_IF_NULL(cache_vocab_size_addr_);
|
||||
rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int));
|
||||
return true;
|
||||
int copy_value = SizeToInt(cache_vocab_size);
|
||||
if (!CopyHostMemToDevice(cache_vocab_size_addr_, ©_value, sizeof(int))) {
|
||||
return false;
|
||||
}
|
||||
return SynchronizeStream();
|
||||
}
|
||||
|
||||
bool AscendPsCache::RecordEvent() {
|
||||
|
|
|
@ -51,7 +51,7 @@ class AscendPsCache : public PsCacheBasic {
|
|||
~AscendPsCache() override = default;
|
||||
bool InitDevice(uint32_t device_id, const void *context) override;
|
||||
void *MallocMemory(size_t size) override;
|
||||
bool MallocConstantMemory(size_t constant_value) override;
|
||||
bool MallocConstantMemory(size_t cache_vocab_size) override;
|
||||
bool RecordEvent() override;
|
||||
bool SynchronizeEvent() override;
|
||||
bool SynchronizeStream() override;
|
||||
|
|
|
@ -34,7 +34,7 @@ class PsCacheBasic {
|
|||
virtual ~PsCacheBasic() = default;
|
||||
virtual bool InitDevice(uint32_t device_id, const void *context) = 0;
|
||||
virtual void *MallocMemory(size_t size) = 0;
|
||||
virtual bool MallocConstantMemory(size_t constant_value) { return true; }
|
||||
virtual bool MallocConstantMemory(size_t cache_vocab_size) { return true; }
|
||||
virtual bool RecordEvent() = 0;
|
||||
virtual bool SynchronizeEvent() = 0;
|
||||
virtual bool SynchronizeStream() = 0;
|
||||
|
|
|
@ -674,6 +674,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
|
|||
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
|
||||
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
|
||||
hash_table_size, embedding_size, swap_indices_size));
|
||||
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -171,7 +171,10 @@ class EmbeddingLookup(Cell):
|
|||
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
||||
or None. Default: None
|
||||
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
||||
vocab_cache_size (int): Cache size of the dictionary of embeddings.
|
||||
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
|
||||
parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
|
||||
optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
|
||||
memory, so suggests setting a reasonable value to avoid insufficient memory.
|
||||
|
||||
Inputs:
|
||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||
|
|
Loading…
Reference in New Issue