diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc index 52c60b527b5..dcb9a8eeaa9 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_common.cc +++ b/mindspore/ccsrc/runtime/framework/actor/actor_common.cc @@ -143,15 +143,15 @@ bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_ << ", output size:" << dst_device_tensor->GetSize(); } - // Exist the size alignment in some device, so get the min device size. - size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize()); - if (src_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) { // CPU device tensor copy to other device tensor. - return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr()); + return dst_device_tensor->SyncHostToDevice(src_device_tensor->host_shape(), src_device_tensor->GetSize(), + src_device_tensor->type_id(), src_device_tensor->GetPtr(), + src_device_tensor->format()); } else if (dst_device_tensor->DeviceType() == device::DeviceAddressType::kCPU) { // Other device tensor copy to CPU device tensor. - return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr()); + return src_device_tensor->SyncDeviceToHost(dst_device_tensor->host_shape(), dst_device_tensor->GetSize(), + dst_device_tensor->type_id(), dst_device_tensor->GetMutablePtr()); } else if (dst_device_tensor->DeviceType() == src_device_tensor->DeviceType()) { return dst_device_tensor->SyncDeviceToDevice(src_device_tensor); } else { diff --git a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc index 323dd26c627..b5118969e76 100644 --- a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc @@ -16,6 +16,7 @@ #include "runtime/framework/actor/control_flow/exit_actor.h" #include "runtime/framework/actor/output_actor.h" +#include "runtime/hardware/device_context_manager.h" namespace mindspore { namespace runtime { @@ -152,8 +153,19 @@ void ExitActor::CopyDeviceAddress(OpContext *const context) { } MS_EXCEPTION_IF_NULL(device_contexts_[i]); // Create the new device tensor to take over the input_device_tensors which are the outputs of kernel graphs. - auto new_device_tensor = device_contexts_[i]->CreateDeviceAddress( - nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id()); + device::DeviceAddressPtr new_device_tensor; + if (common::GetEnv("ENABLE_HOST_MEM_STACK") == "1") { + // Create new device tensor in host. + const auto &host_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {"CPU", device_contexts_[i]->device_context_key().device_id_}); + MS_EXCEPTION_IF_NULL(host_device_context); + host_device_context->Initialize(); + new_device_tensor = host_device_context->CreateDeviceAddress( + nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id()); + } else { + new_device_tensor = device_contexts_[i]->CreateDeviceAddress( + nullptr, input_device_tensor->GetSize(), input_device_tensor->format(), input_device_tensor->type_id()); + } MS_EXCEPTION_IF_NULL(new_device_tensor); (void)created_device_tensors_.emplace_back(new_device_tensor); (void)new_device_tensors.emplace_back(new_device_tensor.get()); @@ -175,11 +187,25 @@ void ExitActor::CopyDeviceAddress(OpContext *const context) { SET_OPCONTEXT_FAIL_RET_WITH_ERROR(*context, "Sync device to device failed."); } } else { - // Move the device ptr from input_device_tensor to new_device_tensor. - new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr()); - new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool()); - input_device_tensor->set_ptr(nullptr); - input_device_tensor->set_from_mem_pool(false); + if (new_device_tensor->DeviceType() == input_device_tensor->DeviceType()) { + // Move the device ptr from input_device_tensor to new_device_tensor. + new_device_tensor->set_ptr(input_device_tensor->GetMutablePtr()); + new_device_tensor->set_from_mem_pool(input_device_tensor->from_mem_pool()); + input_device_tensor->set_ptr(nullptr); + input_device_tensor->set_from_mem_pool(false); + } else { + const auto &host_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {"CPU", device_contexts_[i]->device_context_key().device_id_}); + MS_EXCEPTION_IF_NULL(host_device_context); + if (!host_device_context->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize())) { + SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *host_device_context, + GetAID().Name(), new_device_tensor->GetSize()); + } + Copy(new_device_tensor.get(), input_device_tensor); + device_contexts_[i]->FreeMemory(input_device_tensor); + input_device_tensor->set_ptr(nullptr); + input_device_tensor->set_from_mem_pool(false); + } } MS_LOG(DEBUG) << GetAID().Name() << " creates the dynamic ref device address:" << new_device_tensor.get() << ", ptr:" << new_device_tensor->GetPtr()