!21246 [bugfix] ps cache MallocConstantMemory failed

Merge pull request !21246 from zyli2020/mindrt_debug
This commit is contained in:
i-robot 2021-08-02 15:29:38 +00:00 committed by Gitee
commit 564306c232
1 changed file with 8 additions and 3 deletions

View File

@ -237,9 +237,6 @@ void PsCacheManager::AllocMemForHashTable() {
embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) {
MS_LOG(EXCEPTION) << "MallocConstantMemory failed.";
}
}
void PsCacheManager::SetLocalIdRank() {
@ -328,6 +325,14 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, const void *context) {
MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_);
MS_ERROR_IF_NULL_WO_RET_VAL(embedding_device_cache_->cache_);
embedding_device_cache_->cache_->InitDevice(device_id, context);
// MallocConstantMemory needs a stream on Ascend devices, so it must be called after InitDevice.
if (!(embedding_device_cache_->cache_->MallocConstantMemory(vocab_cache_size_))) {
MS_LOG(ERROR) << "MallocConstantMemory failed.";
running_ = false;
return;
}
InitParameterServer();
InitDataChannel();
while (running_) {