forked from mindspore-Ecosystem/mindspore
!1335 check memcpy_async task src size and dst size
Merge pull request !1335 from caifubi/get-device-datatype-in-memcpy
This commit is contained in:
commit
a8b2c5a0bc
|
@ -137,8 +137,9 @@ bool TaskGenerator::LaunchAllKernel(const std::vector<CNodePtr> &anf_node_list,
|
|||
for (const auto &anf_node_ptr : anf_node_list) {
|
||||
size_t old_size = task_info_list->size();
|
||||
uint32_t stream_id = AnfAlgo::GetStreamId(anf_node_ptr);
|
||||
MS_EXCEPTION_IF_NULL(anf_node_ptr);
|
||||
MS_LOG(INFO) << "Task gen launch begin, current_op_idx:" << current_op_index
|
||||
<< " type:" << (AnfAlgo::GetCNodeName(anf_node_ptr)) << ", stream id:" << stream_id;
|
||||
<< " name:" << anf_node_ptr->fullname_with_scope() << ", stream id:" << stream_id;
|
||||
if (!LaunchKernel(anf_node_ptr, stream_id, task_info_list)) {
|
||||
MS_LOG(ERROR) << "LaunchKernel failed.";
|
||||
return false;
|
||||
|
|
|
@ -658,31 +658,10 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
|
|||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto start_time = std::chrono::steady_clock::now();
|
||||
#else
|
||||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
#endif
|
||||
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Launch kernel failed.";
|
||||
return false;
|
||||
} else {
|
||||
if (AnfAlgo::GetKernelType(kernel) == TBE_KERNEL && !SyncStream()) {
|
||||
MS_LOG(EXCEPTION) << "SyncStream failed.";
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
auto end_time = std::chrono::steady_clock::now();
|
||||
std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
|
||||
MS_LOG(DEBUG) << "d " << kernel->fullname_with_scope() << " in " << cost.count() << " us";
|
||||
#else
|
||||
(void)gettimeofday(&end_time, nullptr);
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
||||
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
||||
MS_LOG(DEBUG) << "d " << kernel->fullname_with_scope() << " in " << cost << " us";
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -48,6 +48,13 @@ bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
|||
MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async";
|
||||
return true;
|
||||
}
|
||||
if (outputs[0]->size < inputs[0]->size) {
|
||||
MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
|
||||
}
|
||||
// input x -> memcpy_async -> AllReduce
|
||||
if (outputs[0]->size > inputs[0]->size) {
|
||||
MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
|
||||
}
|
||||
rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
|
||||
RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
|
||||
if (status != RT_ERROR_NONE) {
|
||||
|
@ -70,7 +77,7 @@ void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) {
|
|||
if (input_size != 1) {
|
||||
MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
|
||||
}
|
||||
input_type_id_ = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, 0);
|
||||
input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0);
|
||||
}
|
||||
|
||||
void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
|
||||
|
@ -102,6 +109,14 @@ std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr
|
|||
MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one";
|
||||
}
|
||||
|
||||
if (outputs[0]->size < inputs[0]->size) {
|
||||
MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
|
||||
}
|
||||
// input x -> memcpy_async -> AllReduce
|
||||
if (outputs[0]->size > inputs[0]->size) {
|
||||
MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
|
||||
}
|
||||
|
||||
stream_id_ = stream_id;
|
||||
std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr = std::make_shared<MemcpyAsyncTaskInfo>(
|
||||
stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE);
|
||||
|
|
Loading…
Reference in New Issue