add host mem stackt

Use correct syn api in Ascend.
This commit is contained in:
liangzelang 2022-02-10 14:22:58 +08:00
parent fe8588f32f
commit 7ba0e2ed8f
2 changed files with 38 additions and 12 deletions

View File

@ -143,15 +143,15 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_
<< ", output size:" << dst_device_tensor->GetSize();
}
// Exist the size alignment in some device, so get the min device size.
size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize());
if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// CPU device tensor copy to other device tensor.
return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr());
return dst_device_tensor->SyncHostToDevice(src_device_tensor->host_shape(), src_device_tensor->GetSize(),
src_device_tensor->type_id(), src_device_tensor->GetPtr(),
src_device_tensor->format());
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
// Other device tensor copy to CPU device tensor.
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
return src_device_tensor->SyncDeviceToHost(dst_device_tensor->host_shape(), dst_device_tensor->GetSize(),
dst_device_tensor->type_id(), dst_device_tensor->GetMutablePtr());
} else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
return dst_device_tensor->SyncDeviceToDevice(src_device_tensor);
} else {

View File

@ -16,6 +16,7 @@
#include "runtime/framework/actor/control_flow/exit_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/hardware/device_context_manager.h"
namespace mindspore {
namespace runtime {
@ -152,8 +153,19 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
}
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
// Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress(
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
device::DeviceAddressPtr new_device_tensor;
if (common::GetEnv("ENABLE_HOST_MEM_STACK") == "1") {
// Create new device tensor in host.
const auto &host_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{"CPU", device_contexts_[i]->device_context_key().device_id_});
MS_EXCEPTION_IF_NULL(host_device_context);
host_device_context->Initialize();
new_device_tensor = host_device_context->CreateDeviceAddress(
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
} else {
new_device_tensor = device_contexts_[i]->CreateDeviceAddress(
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
}
MS_EXCEPTION_IF_NULL(new_device_tensor);
(void)created_device_tensors_.emplace_back(new_device_tensor);
(void)new_device_tensors.emplace_back(new_device_tensor.get());
@ -175,11 +187,25 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
}
} else {
// Move the device ptr from input_device_tensor to new_device_tensor.
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
input_device_tensor->set_ptr(nullptr);
input_device_tensor->set_from_mem_pool(false);
if (new_device_tensor->DeviceType() == input_device_tensor->DeviceType()) {
// Move the device ptr from input_device_tensor to new_device_tensor.
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
input_device_tensor->set_ptr(nullptr);
input_device_tensor->set_from_mem_pool(false);
} else {
const auto &host_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{"CPU", device_contexts_[i]->device_context_key().device_id_});
MS_EXCEPTION_IF_NULL(host_device_context);
if (!host_device_context->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *host_device_context,
GetAID().Name(), new_device_tensor->GetSize());
}
Copy(new_device_tensor.get(), input_device_tensor);
device_contexts_[i]->FreeMemory(input_device_tensor);
input_device_tensor->set_ptr(nullptr);
input_device_tensor->set_from_mem_pool(false);
}
}
MS_LOG(DEBUG) << GetAID().Name() << " creates the dynamic ref device address:" << new_device_tensor.get()
<< ", ptr:" << new_device_tensor->GetPtr()