diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc index 4bc9c4cc08b..37b6095e84b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc @@ -204,12 +204,10 @@ std::vector AicpuOpKernelMod::GenTask(const std::vector } device::DynamicKernelPtr AicpuOpKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + KernelLaunchInfo kernel_launch_info; + device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info); - CreateCpuKernelInfo(kernel_inputs, kernel_outputs); + CreateCpuKernelInfo(kernel_launch_info.inputs_, kernel_launch_info.outputs_); return std::make_shared(stream_ptr, cnode_ptr, args_, ext_info_, node_so_, node_name_); } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index e71eab88cdf..8f65b9c449d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -273,23 +273,21 @@ std::vector HcclKernel::GenTask(const std::vector &inpu } device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { - AddressPtrList inputs; - AddressPtrList workspaces; - AddressPtrList outputs; - device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs); + KernelLaunchInfo kernel_launch_info; + device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info); std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_.lock())); - if (inputs.empty()) { + if (kernel_launch_info.inputs_.empty()) { MS_LOG(EXCEPTION) << "Hccl kernel input is empty"; } if (hccl_data_type_list_.empty()) { MS_LOG(EXCEPTION) << "Hccl data type list is empty"; } - MS_EXCEPTION_IF_NULL(inputs.at(0)); - auto input_data_addr = inputs.at(0)->addr; - MS_EXCEPTION_IF_NULL(outputs.at(0)); - auto output_data_addr = outputs.at(0)->addr; + MS_EXCEPTION_IF_NULL(kernel_launch_info.inputs_.at(0)); + auto input_data_addr = kernel_launch_info.inputs_.at(0)->addr; + MS_EXCEPTION_IF_NULL(kernel_launch_info.outputs_.at(0)); + auto output_data_addr = kernel_launch_info.outputs_.at(0)->addr; HcclDataType data_type = hccl_data_type_list_[0]; auto executor = std::make_shared( diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel.h b/mindspore/ccsrc/backend/kernel_compiler/kernel.h index 35fab774d67..e727a4884e4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel.h @@ -166,12 +166,13 @@ struct Address { size_t size; }; using AddressPtr = std::shared_ptr
; +using AddressPtrList = std::vector; // The memory info of kernel launch. struct KernelLaunchInfo { - std::vector inputs_; - std::vector outputs_; - std::vector workspaces_; + AddressPtrList inputs_; + AddressPtrList outputs_; + AddressPtrList workspaces_; }; class KernelMod { @@ -179,6 +180,10 @@ class KernelMod { virtual const std::vector &GetInputSizeList() const = 0; virtual const std::vector &GetOutputSizeList() const = 0; virtual const std::vector &GetWorkspaceSizeList() const = 0; + bool Launch(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) { + return Launch(kernel_launch_address.inputs_, kernel_launch_address.workspaces_, kernel_launch_address.outputs_, + stream_ptr); + } virtual bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) = 0; virtual device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { return nullptr; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc index debad5d08b8..8dccb8aa3e8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc @@ -133,11 +133,11 @@ std::vector MemCpyAsyncKernel::GenTask(const std::vector TbeKernelMod::GenTask(const std::vector &in } device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + KernelLaunchInfo kernel_launch_info; + device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info); auto dynamic_flag = AnfAlgo::IsDynamicShape(cnode_ptr); // Get para_size from json @@ -124,10 +122,13 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt // Generate args std::vector runtime_args; + const auto &kernel_inputs = kernel_launch_info.inputs_; (void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args), [](const AddressPtr &input) -> void * { return input->addr; }); + const auto &kernel_outputs = kernel_launch_info.outputs_; (void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args), [](const AddressPtr &output) -> void * { return output->addr; }); + const auto &kernel_workspaces = kernel_launch_info.workspaces_; if (!kernel_workspaces.empty()) { (void)std::transform(std::begin(kernel_workspaces), std::end(kernel_workspaces), std::back_inserter(runtime_args), [](const AddressPtr &addr) -> void * { return addr->addr; }); diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc index d786836d60d..51d7513b3b8 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc @@ -108,15 +108,15 @@ void AiCoreDynamicKernel::UpdateArgs() { auto kernel_mod = AnfAlgo::GetKernelMod(cnode); MS_EXCEPTION_IF_NULL(kernel_mod); - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + KernelLaunchInfo kernel_launch_info; + KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info); runtime_args_.clear(); + const auto &kernel_inputs = kernel_launch_info.inputs_; (void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args_), [](const AddressPtr &input) { return input->addr; }); + const auto &kernel_outputs = kernel_launch_info.outputs_; (void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args_), [](const AddressPtr &output) { return output->addr; }); // Update workspace diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc index c9aadf14c1e..60a75259694 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc +++ b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc @@ -44,15 +44,13 @@ void HcclDynamicKernel::UpdateArgs() { auto kernel_mod = AnfAlgo::GetKernelMod(cnode); MS_EXCEPTION_IF_NULL(kernel_mod); // Update input, output, count - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs); - if (kernel_inputs.empty() || kernel_outputs.empty()) { + KernelLaunchInfo kernel_launch_info; + KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info); + if (kernel_launch_info.inputs_.empty() || kernel_launch_info.outputs_.empty()) { MS_LOG(EXCEPTION) << "Inputs or outputs is empty. Node info: " << cnode->DebugString(); } - auto input0 = kernel_inputs.at(0); - auto output0 = kernel_outputs.at(0); + auto input0 = kernel_launch_info.inputs_.at(0); + auto output0 = kernel_launch_info.outputs_.at(0); MS_EXCEPTION_IF_NULL(input0); MS_EXCEPTION_IF_NULL(output0); @@ -82,11 +80,9 @@ void HcclDynamicKernel::StaticShapeExecute() { MS_EXCEPTION_IF_NULL(cnode); auto kernel_mod = AnfAlgo::GetKernelMod(cnode); MS_EXCEPTION_IF_NULL(kernel_mod); - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs); - kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); + KernelLaunchInfo kernel_launch_info; + KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info); + kernel_mod->Launch(kernel_launch_info, stream_); } void HcclDynamicKernel::Execute() { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index f1fd46ec2d3..55602f28883 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -891,9 +891,10 @@ bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *grap AddressPtrList kernel_inputs; AddressPtrList kernel_workspaces; AddressPtrList kernel_outputs; - GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + KernelLaunchInfo kernel_launch_info; + GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info); MS_EXCEPTION_IF_NULL(stream_); - auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); + auto ret = kernel_mod->Launch(kernel_launch_info, stream_); if (!ret) { MS_LOG(ERROR) << "Launch kernel failed."; return false; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 9dd1eea1fd9..fa852d70a5e 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -1104,16 +1104,13 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) { } void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, - AddressPtrList *kernel_outputs) { + KernelLaunchInfo *kernel_launch_info) { MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_inputs); - MS_EXCEPTION_IF_NULL(kernel_workspaces); - MS_EXCEPTION_IF_NULL(kernel_outputs); + MS_EXCEPTION_IF_NULL(kernel_launch_info); auto cnode = kernel->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { - return GenAddrCleanLaunchArgs(cnode, kernel_inputs); + return GenAddrCleanLaunchArgs(cnode, &(kernel_launch_info->inputs_)); } auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); @@ -1140,7 +1137,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod input->addr = device_address->ptr_; MS_EXCEPTION_IF_NULL(input->addr); input->size = device_address->size_; - kernel_inputs->emplace_back(input); + kernel_launch_info->inputs_.emplace_back(input); } for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { @@ -1150,7 +1147,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod output->addr = device_address->ptr_; MS_EXCEPTION_IF_NULL(output->addr); output->size = device_address->size_; - kernel_outputs->emplace_back(output); + kernel_launch_info->outputs_.emplace_back(output); } for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { @@ -1160,7 +1157,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod workspace->addr = device_address->ptr_; MS_EXCEPTION_IF_NULL(workspace->addr); workspace->size = device_address->size_; - kernel_workspaces->emplace_back(workspace); + kernel_launch_info->workspaces_.emplace_back(workspace); } } @@ -1230,9 +1227,7 @@ void KernelRuntime::LaunchKernelEvent(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream) { + const KernelLaunchInfo &kernel_launch_info, void *stream) { MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(stream); float cost_time = 0; @@ -1243,7 +1238,7 @@ bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_ start->set_record_stream(stream); end->set_record_stream(stream); start->RecordEvent(); - bool ret = kernel_mod->Launch(inputs, workspace, outputs, stream); + bool ret = kernel_mod->Launch(kernel_launch_info, stream); end->RecordEvent(); start->SyncEvent(); end->SyncEvent(); @@ -1283,16 +1278,13 @@ void KernelRuntime::GetOrMallocAddress(const std::shared_ptr &mem_ } void KernelRuntime::AssignKernelAddress(const std::shared_ptr &mem_scheduler, const AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, - AddressPtrList *kernel_outputs) { + KernelLaunchInfo *kernel_launch_info) { MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_inputs); - MS_EXCEPTION_IF_NULL(kernel_workspaces); - MS_EXCEPTION_IF_NULL(kernel_outputs); + MS_EXCEPTION_IF_NULL(kernel_launch_info); auto cnode = kernel->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { - return GenAddrCleanLaunchArgs(cnode, kernel_inputs, mem_scheduler); + return GenAddrCleanLaunchArgs(cnode, &(kernel_launch_info->inputs_), mem_scheduler); } auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); @@ -1307,7 +1299,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr &mem kernel::AddressPtr input = std::make_shared(); GetOrMallocAddress(mem_scheduler, device_address, input); input->size = device_address->size_; - kernel_inputs->emplace_back(input); + kernel_launch_info->inputs_.emplace_back(input); } for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) { @@ -1315,7 +1307,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr &mem kernel::AddressPtr output = std::make_shared(); GetOrMallocAddress(mem_scheduler, device_address, output); output->size = device_address->size_; - kernel_outputs->emplace_back(output); + kernel_launch_info->outputs_.emplace_back(output); } for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { @@ -1323,7 +1315,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr &mem kernel::AddressPtr workspace = std::make_shared(); GetOrMallocAddress(mem_scheduler, device_address, workspace); workspace->size = device_address->size_; - kernel_workspaces->emplace_back(workspace); + kernel_launch_info->workspaces_.emplace_back(workspace); } } @@ -1400,9 +1392,7 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod MS_EXCEPTION_IF_NULL(kernel); auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; + KernelLaunchInfo kernel_launch_info; auto stream = kernel_mod->GetStream(); if (stream == nullptr) { if (AnfAlgo::IsCommunicationOp(kernel)) { @@ -1417,20 +1407,19 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod if (!ret) { return ret; } - AssignKernelAddress(mem_scheduler, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info); } else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) { - kernel_inputs = kernel_mod->GetInputsAddr(); - kernel_outputs = kernel_mod->GetOutputsAddr(); - kernel_workspaces = kernel_mod->GetWorkSpacesAddr(); + kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr(); + kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr(); + kernel_launch_info.workspaces_ = kernel_mod->GetWorkSpacesAddr(); } else { - GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info); } if (!mock) { if (pynative_mode_profiling_flag_) { - ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs, - kernel_workspaces, kernel_outputs, stream); + ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_launch_info, stream); } else { - ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream); + ret = kernel_mod->Launch(kernel_launch_info, stream); } } if (mem_scheduler != nullptr) { diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index bdef5a0252d..a61e09815aa 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -41,7 +41,8 @@ using mindspore::tensor::Tensor; using std::vector; using TensorPtr = std::shared_ptr; using mindspore::kernel::AddressPtr; -using AddressPtrList = std::vector; +using mindspore::kernel::AddressPtrList; +using mindspore::kernel::KernelLaunchInfo; namespace mindspore { #ifndef ENABLE_DEBUGGER @@ -87,8 +88,7 @@ class KernelRuntime { return mem_manager_->MallocCommunicationMemFromMemPool(size); } static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, - AddressPtrList *kernel_outputs); + KernelLaunchInfo *kernel_launch_info); // for GPU and D to impl virtual void ReleaseDeviceRes() {} @@ -135,9 +135,7 @@ class KernelRuntime { void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node); void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node); bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name, - const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream); + const KernelLaunchInfo &kernel_launch_address, void *stream); virtual void KernelLaunchProfiling(const std::string &kernel_name) {} @@ -147,8 +145,7 @@ class KernelRuntime { const std::shared_ptr &mem_scheduler, bool mock = false); void ResetNodeAddress(const session::KernelGraph &graph); void AssignKernelAddress(const std::shared_ptr &mem_scheduler, const AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, - AddressPtrList *kernel_outputs); + KernelLaunchInfo *kernel_launch_address); static void GetOrMallocAddress(const std::shared_ptr &mem_scheduler, const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr); void InitGraphInputTensors(const std::shared_ptr &mem_scheduler, const session::KernelGraph &graph);