parent
fe8588f32f
commit
7ba0e2ed8f
|
@ -143,15 +143,15 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_
|
||||||
<< ", output size:" << dst_device_tensor->GetSize();
|
<< ", output size:" << dst_device_tensor->GetSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exist the size alignment in some device, so get the min device size.
|
|
||||||
size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize());
|
|
||||||
|
|
||||||
if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||||
// CPU device tensor copy to other device tensor.
|
// CPU device tensor copy to other device tensor.
|
||||||
return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr());
|
return dst_device_tensor->SyncHostToDevice(src_device_tensor->host_shape(), src_device_tensor->GetSize(),
|
||||||
|
src_device_tensor->type_id(), src_device_tensor->GetPtr(),
|
||||||
|
src_device_tensor->format());
|
||||||
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
} else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) {
|
||||||
// Other device tensor copy to CPU device tensor.
|
// Other device tensor copy to CPU device tensor.
|
||||||
return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
|
return src_device_tensor->SyncDeviceToHost(dst_device_tensor->host_shape(), dst_device_tensor->GetSize(),
|
||||||
|
dst_device_tensor->type_id(), dst_device_tensor->GetMutablePtr());
|
||||||
} else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
|
} else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) {
|
||||||
return dst_device_tensor->SyncDeviceToDevice(src_device_tensor);
|
return dst_device_tensor->SyncDeviceToDevice(src_device_tensor);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include "runtime/framework/actor/control_flow/exit_actor.h"
|
#include "runtime/framework/actor/control_flow/exit_actor.h"
|
||||||
#include "runtime/framework/actor/output_actor.h"
|
#include "runtime/framework/actor/output_actor.h"
|
||||||
|
#include "runtime/hardware/device_context_manager.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
@ -152,8 +153,19 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
|
MS_EXCEPTION_IF_NULL(device_contexts_[i]);
|
||||||
// Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
|
// Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs.
|
||||||
auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress(
|
device::DeviceAddressPtr new_device_tensor;
|
||||||
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
|
if (common::GetEnv("ENABLE_HOST_MEM_STACK") == "1") {
|
||||||
|
// Create new device tensor in host.
|
||||||
|
const auto &host_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||||
|
{"CPU", device_contexts_[i]->device_context_key().device_id_});
|
||||||
|
MS_EXCEPTION_IF_NULL(host_device_context);
|
||||||
|
host_device_context->Initialize();
|
||||||
|
new_device_tensor = host_device_context->CreateDeviceAddress(
|
||||||
|
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
|
||||||
|
} else {
|
||||||
|
new_device_tensor = device_contexts_[i]->CreateDeviceAddress(
|
||||||
|
nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id());
|
||||||
|
}
|
||||||
MS_EXCEPTION_IF_NULL(new_device_tensor);
|
MS_EXCEPTION_IF_NULL(new_device_tensor);
|
||||||
(void)created_device_tensors_.emplace_back(new_device_tensor);
|
(void)created_device_tensors_.emplace_back(new_device_tensor);
|
||||||
(void)new_device_tensors.emplace_back(new_device_tensor.get());
|
(void)new_device_tensors.emplace_back(new_device_tensor.get());
|
||||||
|
@ -175,11 +187,25 @@ void ExitActor::CopyDeviceAddress(OpContext<DeviceTensor> *const context) {
|
||||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed.");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Move the device ptr from input_device_tensor to new_device_tensor.
|
if (new_device_tensor->DeviceType() == input_device_tensor->DeviceType()) {
|
||||||
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
|
// Move the device ptr from input_device_tensor to new_device_tensor.
|
||||||
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
|
new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
|
||||||
input_device_tensor->set_ptr(nullptr);
|
new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool());
|
||||||
input_device_tensor->set_from_mem_pool(false);
|
input_device_tensor->set_ptr(nullptr);
|
||||||
|
input_device_tensor->set_from_mem_pool(false);
|
||||||
|
} else {
|
||||||
|
const auto &host_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||||
|
{"CPU", device_contexts_[i]->device_context_key().device_id_});
|
||||||
|
MS_EXCEPTION_IF_NULL(host_device_context);
|
||||||
|
if (!host_device_context->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) {
|
||||||
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *host_device_context,
|
||||||
|
GetAID().Name(), new_device_tensor->GetSize());
|
||||||
|
}
|
||||||
|
Copy(new_device_tensor.get(), input_device_tensor);
|
||||||
|
device_contexts_[i]->FreeMemory(input_device_tensor);
|
||||||
|
input_device_tensor->set_ptr(nullptr);
|
||||||
|
input_device_tensor->set_from_mem_pool(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << GetAID().Name() << " creates the dynamic ref device address:" << new_device_tensor.get()
|
MS_LOG(DEBUG) << GetAID().Name() << " creates the dynamic ref device address:" << new_device_tensor.get()
|
||||||
<< ", ptr:" << new_device_tensor->GetPtr()
|
<< ", ptr:" << new_device_tensor->GetPtr()
|
||||||
|
|
Loading…
Reference in New Issue