parent
fe8588f32f
commit
7ba0e2ed8f
|
@ -143,15 +143,15 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_
|
|||
<< ", output size:" << dst_device_tensor->GetSize();
|
||||
}
|
||||
|
||||
// Size alignment exists on some devices, so use the smaller of the two device sizes.
|
||||
size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize());
|
||||
|
||||
if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// Copy from a CPU device tensor to another device tensor.
|
||||
return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr());
|
||||
return dst_device_tensor->SyncHostToDevice(src_device_tensor->host_shape(), src_device_tensor->GetSize(),
|
||||
src_device_tensor->type_id(), src_device_tensor->GetPtr(),
|
||||
src_device_tensor->format());
|
||||
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||
// Copy from another device tensor to a CPU device tensor.
|
||||
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
|
||||
return src_device_tensor->SyncDeviceToHost(dst_device_tensor->host_shape(), dst_device_tensor->GetSize(),
|
||||
dst_device_tensor->type_id(), dst_device_tensor->GetMutablePtr());
|
||||
} else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
|
||||
return dst_device_tensor->SyncDeviceToDevice(src_device_tensor);
|
||||
} else {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "runtime/framework/actor/control_flow/exit_actor.h"
|
||||
#include "runtime/framework/actor/output_actor.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
@ -152,8 +153,19 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
|
||||
// Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
|
||||
auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress(
|
||||
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
|
||||
device::DeviceAddressPtr new_device_tensor;
|
||||
if (common::GetEnv("ENABLE_HOST_MEM_STACK") == "1") {
|
||||
// Create the new device tensor on the host.
|
||||
const auto &host_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{"CPU", device_contexts_[i]->device_context_key().device_id_});
|
||||
MS_EXCEPTION_IF_NULL(host_device_context);
|
||||
host_device_context->Initialize();
|
||||
new_device_tensor = host_device_context->CreateDeviceAddress(
|
||||
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
|
||||
} else {
|
||||
new_device_tensor = device_contexts_[i]->CreateDeviceAddress(
|
||||
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(new_device_tensor);
|
||||
(void)created_device_tensors_.emplace_back(new_device_tensor);
|
||||
(void)new_device_tensors.emplace_back(new_device_tensor.get());
|
||||
|
@ -175,11 +187,25 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
|
|||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
|
||||
}
|
||||
} else {
|
||||
// Move the device ptr from input_device_tensor to new_device_tensor.
|
||||
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
|
||||
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
|
||||
input_device_tensor->set_ptr(nullptr);
|
||||
input_device_tensor->set_from_mem_pool(false);
|
||||
if (new_device_tensor->DeviceType() == input_device_tensor->DeviceType()) {
|
||||
// Move the device ptr from input_device_tensor to new_device_tensor.
|
||||
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
|
||||
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
|
||||
input_device_tensor->set_ptr(nullptr);
|
||||
input_device_tensor->set_from_mem_pool(false);
|
||||
} else {
|
||||
const auto &host_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{"CPU", device_contexts_[i]->device_context_key().device_id_});
|
||||
MS_EXCEPTION_IF_NULL(host_device_context);
|
||||
if (!host_device_context->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) {
|
||||
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *host_device_context,
|
||||
GetAID().Name(), new_device_tensor->GetSize());
|
||||
}
|
||||
Copy(new_device_tensor.get(), input_device_tensor);
|
||||
device_contexts_[i]->FreeMemory(input_device_tensor);
|
||||
input_device_tensor->set_ptr(nullptr);
|
||||
input_device_tensor->set_from_mem_pool(false);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << GetAID().Name() << " creates the dynamic ref device address:" << new_device_tensor.get()
|
||||
<< ", ptr:" << new_device_tensor->GetPtr()
|
||||
|
|
Loading…
Reference in New Issue