From 5a447c9adcc76ea78c4d87efe22f65199e349658 Mon Sep 17 00:00:00 2001 From: caifubi Date: Thu, 21 May 2020 19:06:35 +0800 Subject: [PATCH] fix memcpy_async dst size less than src size --- .../device/ascend/tasksink/task_generator.cc | 3 ++- mindspore/ccsrc/device/kernel_runtime.cc | 21 ------------------- mindspore/ccsrc/kernel/rts/memcpy_async.cc | 17 ++++++++++++++- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc index e7b3298b91e..18da9665750 100644 --- a/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc +++ b/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc @@ -137,8 +137,9 @@ bool TaskGenerator::LaunchAllKernel(const std::vector &anf_node_list, for (const auto &anf_node_ptr : anf_node_list) { size_t old_size = task_info_list->size(); uint32_t stream_id = AnfAlgo::GetStreamId(anf_node_ptr); + MS_EXCEPTION_IF_NULL(anf_node_ptr); MS_LOG(INFO) << "Task gen launch begin, current_op_idx:" << current_op_index - << " type:" << (AnfAlgo::GetCNodeName(anf_node_ptr)) << ", stream id:" << stream_id; + << " name:" << anf_node_ptr->fullname_with_scope() << ", stream id:" << stream_id; if (!LaunchKernel(anf_node_ptr, stream_id, task_info_list)) { MS_LOG(ERROR) << "LaunchKernel failed."; return false; diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index 0f5f282fd1b..283d3c2f42f 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -658,31 +658,10 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { AddressPtrList kernel_workspaces; AddressPtrList kernel_outputs; GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); if (!ret) { MS_LOG(ERROR) << "Launch kernel failed."; return false; - } else { - if (AnfAlgo::GetKernelType(kernel) == TBE_KERNEL && !SyncStream()) { - MS_LOG(EXCEPTION) << "SyncStream failed."; - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(DEBUG) << "d " << kernel->fullname_with_scope() << " in " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(DEBUG) << "d " << kernel->fullname_with_scope() << " in " << cost << " us"; -#endif } } return true; diff --git a/mindspore/ccsrc/kernel/rts/memcpy_async.cc b/mindspore/ccsrc/kernel/rts/memcpy_async.cc index 3d5a7c88abc..1079a493401 100644 --- a/mindspore/ccsrc/kernel/rts/memcpy_async.cc +++ b/mindspore/ccsrc/kernel/rts/memcpy_async.cc @@ -48,6 +48,13 @@ bool MemCpyAsyncKernel::Launch(const std::vector &inputs, const std: MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async"; return true; } + if (outputs[0]->size < inputs[0]->size) { + MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size"; + } + // input x -> memcpy_async -> AllReduce + if (outputs[0]->size > inputs[0]->size) { + MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; + } rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr); if (status != RT_ERROR_NONE) { @@ -70,7 +77,7 @@ void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) { if (input_size != 1) { MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1"; } - input_type_id_ = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, 0); + input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0); } void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) { @@ -102,6 +109,14 @@ std::vector MemCpyAsyncKernel::GenTask(const std::vectorsize < inputs[0]->size) { + MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size"; + } + // input x -> memcpy_async -> AllReduce + if (outputs[0]->size > inputs[0]->size) { + MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; + } + stream_id_ = stream_id; std::shared_ptr task_info_ptr = std::make_shared( stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE);