!10649 fix ascend ps cache loss invalid

From: @limingqi107
Reviewed-by: @cristoval,@chujinjin
Signed-off-by: @cristoval
This commit is contained in:
mindspore-ci-bot 2020-12-27 18:14:16 +08:00 committed by Gitee
commit 6fa83590c1
5 changed files with 13 additions and 6 deletions

View File

@ -131,15 +131,18 @@ void *AscendPsCache::MallocMemory(size_t size) {
return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size);
}
// NOTE(review): this hunk is rendered without +/- markers, so pre-change and post-change lines
// are interleaved. Judging by the hunk header (-131,15 +131,18), the `constant_value` signature,
// the second rtMemset call and the bare `return true;` look like the removed lines, and the
// `cache_vocab_size` signature plus the copy/synchronize tail look like the added ones —
// confirm against the raw diff.
//
// Allocates two int-sized buffers from the Ascend memory pool (offset_addr_ and
// cache_vocab_size_addr_), fills the first with 0, and stores the cache vocab size in the second.
// In the copy-based version, returns false when CopyHostMemToDevice fails and otherwise returns
// the result of SynchronizeStream().
bool AscendPsCache::MallocConstantMemory(size_t constant_value) {
bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) {
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
// Presumably bails out of the function when the allocation returned null — verify the macro.
MS_ERROR_IF_NULL(offset_addr_);
// The return value of rtMemset is not checked here.
rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
cache_vocab_size_addr_ =
reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
MS_ERROR_IF_NULL(cache_vocab_size_addr_);
// NOTE(review): if rtMemset fills byte-wise like memset, this would not store the integer
// value itself — likely the reason for switching to an explicit host-to-device copy. TODO confirm.
rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int));
return true;
// Narrow the size_t to int on the host, then copy exactly sizeof(int) bytes into the buffer.
int copy_value = SizeToInt(cache_vocab_size);
if (!CopyHostMemToDevice(cache_vocab_size_addr_, &copy_value, sizeof(int))) {
return false;
}
// copy_value is a local, so the stream is synchronized before returning; presumably the copy
// may still be in flight until then — TODO confirm CopyHostMemToDevice is asynchronous.
return SynchronizeStream();
}
bool AscendPsCache::RecordEvent() {

View File

@ -51,7 +51,7 @@ class AscendPsCache : public PsCacheBasic {
~AscendPsCache() override = default;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
bool MallocConstantMemory(size_t constant_value) override;
bool MallocConstantMemory(size_t cache_vocab_size) override;
bool RecordEvent() override;
bool SynchronizeEvent() override;
bool SynchronizeStream() override;

View File

@ -34,7 +34,7 @@ class PsCacheBasic {
virtual ~PsCacheBasic() = default;
virtual bool InitDevice(uint32_t device_id, const void *context) = 0;
virtual void *MallocMemory(size_t size) = 0;
virtual bool MallocConstantMemory(size_t constant_value) { return true; }
virtual bool MallocConstantMemory(size_t cache_vocab_size) { return true; }
virtual bool RecordEvent() = 0;
virtual bool SynchronizeEvent() = 0;
virtual bool SynchronizeStream() = 0;

View File

@ -674,6 +674,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_indices_size));
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
return true;
}

View File

@ -171,7 +171,10 @@ class EmbeddingLookup(Cell):
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
vocab_cache_size (int): Cache size of the dictionary of embeddings.
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
parameter server training mode and with the 'DEVICE' target. The moment parameter of the corresponding
optimizer will also be set to the cache size. Note that the cache consumes 'DEVICE' memory, so it is
recommended to set a reasonable value to avoid running out of memory.
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.