From a3d04728eec30ce2d8eb3b556bff126265f974f7 Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Tue, 22 Jun 2021 10:50:49 +0800 Subject: [PATCH] Remove useless SyncHostToDevice for PyNative mode --- .../framework/actor/data_source_actor.cc | 7 +++- .../runtime/framework/graph_scheduler.cc | 37 +++++++++++++++---- .../hardware/device_context_manager.cc | 2 +- .../runtime/hardware/device_context_manager.h | 2 +- .../hardware/gpu/gpu_device_context.cc | 2 +- mindspore/ccsrc/vm/backend.cc | 2 +- 6 files changed, 39 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc index ed4c952c828..c044adae4b9 100644 --- a/mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/data_source_actor.cc @@ -247,8 +247,11 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext *cont auto &device_tensor = device_tensors[i]; MS_EXCEPTION_IF_NULL(host_tensor); MS_EXCEPTION_IF_NULL(device_tensor); - - if (std::dynamic_pointer_cast(host_tensor->device_address()) != nullptr) { + auto tensor_device_address = std::dynamic_pointer_cast(host_tensor->device_address()); + if (tensor_device_address != nullptr) { + if (tensor_device_address.get() != device_tensor) { + MS_LOG(EXCEPTION) << "The device tensor of host queue node should be equal to device address of input tensor"; + } continue; } diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index 4cce0967dd5..2a8d4b2adc6 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -285,16 +285,39 @@ TensorPtr FetchInputTensor(const GraphCompilerInfo &graph_compiler_info, size_t void PrepareDataForHostDataSourceActor(const std::unordered_map &data_node_position_map, const AnfNodePtr &node, const TensorPtr &tensor, - std::vector *host_tensors) { + std::vector *host_tensors, + const DeviceContext *device_context = nullptr, + GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline) { MS_EXCEPTION_IF_NULL(tensor); // Fill the host tensors for non weighted parameters. const auto &iter = data_node_position_map.find(node); - if (iter != data_node_position_map.end()) { - (*host_tensors)[iter->second] = tensor; - auto device_address = std::dynamic_pointer_cast(tensor->device_address()); - if (device_address != nullptr) { - AnfAlgo::SetOutputAddr(device_address, 0, node.get()); + if (iter == data_node_position_map.end()) { + return; + } + + (*host_tensors)[iter->second] = tensor; + auto device_address = std::dynamic_pointer_cast(tensor->device_address()); + if (device_address != nullptr) { + AnfAlgo::SetOutputAddr(device_address, 0, node.get()); + return; + } + + if (strategy == GraphExecutionStrategy::kStep) { + auto node_device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false); + MS_EXCEPTION_IF_NULL(node_device_address); + tensor->set_device_address(node_device_address); + UpdateRefCount(node_device_address.get(), true); + + MS_EXCEPTION_IF_NULL(device_context); + if (!device_context->AllocateMemory(node_device_address.get(), node_device_address->GetSize())) { + MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope(); + } + + if (!node_device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c(), tensor->device_info().host_format_)) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; } } } @@ -492,7 +515,7 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn strategy)) { MS_EXCEPTION_IF_NULL(host_data_source_actor); PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor, - &host_tensors); + &host_tensors, device_context, strategy); } } } diff --git a/mindspore/ccsrc/runtime/hardware/device_context_manager.cc b/mindspore/ccsrc/runtime/hardware/device_context_manager.cc index 6ec49c26a99..a89425b31ab 100644 --- a/mindspore/ccsrc/runtime/hardware/device_context_manager.cc +++ b/mindspore/ccsrc/runtime/hardware/device_context_manager.cc @@ -54,7 +54,7 @@ DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContex return device_context.get(); } -void DeviceContextManager::UpdataDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key) { +void DeviceContextManager::UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key) { std::string old_key_str = old_key.ToString(); std::string new_key_str = new_key.ToString(); diff --git a/mindspore/ccsrc/runtime/hardware/device_context_manager.h b/mindspore/ccsrc/runtime/hardware/device_context_manager.h index 825446eed36..c3f5af5bcd3 100644 --- a/mindspore/ccsrc/runtime/hardware/device_context_manager.h +++ b/mindspore/ccsrc/runtime/hardware/device_context_manager.h @@ -37,7 +37,7 @@ class DeviceContextManager { } void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator); DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key); - void UpdataDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key); + void UpdateDeviceContextKey(const DeviceContextKey &old_key, const DeviceContextKey &new_key); void ClearDeviceContexts(); private: diff --git a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc index 45be9ddfecd..05811f2c8a0 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc +++ b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc @@ -65,7 +65,7 @@ bool GPUDeviceContext::Initialize() { MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); device_context_key_.device_id_ = IntToUint((*get_local_rank_funcptr)()); - DeviceContextManager::GetInstance().UpdataDeviceContextKey(old_key, device_context_key_); + DeviceContextManager::GetInstance().UpdateDeviceContextKey(old_key, device_context_key_); auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 4748ad86912..2648ef25ec3 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -253,7 +253,7 @@ void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } #endif MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id) - : Backend(backend_name), device_name_(device_name), device_id_(device_id) { + : Backend(backend_name), device_name_(device_name) { root_graph_ = nullptr; auto ms_context = MsContext::GetInstance(); const bool pynative_mode = (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode);