Use one struct instead of three vectors for kernel input, output and workspace

This commit is contained in:
tanghuikang 2021-09-26 16:28:54 +08:00
parent 89eb412644
commit 3035f9d367
10 changed files with 69 additions and 84 deletions

View File

@@ -204,12 +204,10 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
}
device::DynamicKernelPtr AicpuOpKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
KernelLaunchInfo kernel_launch_info;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info);
CreateCpuKernelInfo(kernel_inputs, kernel_outputs);
CreateCpuKernelInfo(kernel_launch_info.inputs_, kernel_launch_info.outputs_);
return std::make_shared<AicpuDynamicKernel>(stream_ptr, cnode_ptr, args_, ext_info_, node_so_, node_name_);
}
} // namespace kernel

View File

@@ -273,23 +273,21 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
}
device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
AddressPtrList inputs;
AddressPtrList workspaces;
AddressPtrList outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs);
KernelLaunchInfo kernel_launch_info;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info);
std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_.lock()));
if (inputs.empty()) {
if (kernel_launch_info.inputs_.empty()) {
MS_LOG(EXCEPTION) << "Hccl kernel input is empty";
}
if (hccl_data_type_list_.empty()) {
MS_LOG(EXCEPTION) << "Hccl data type list is empty";
}
MS_EXCEPTION_IF_NULL(inputs.at(0));
auto input_data_addr = inputs.at(0)->addr;
MS_EXCEPTION_IF_NULL(outputs.at(0));
auto output_data_addr = outputs.at(0)->addr;
MS_EXCEPTION_IF_NULL(kernel_launch_info.inputs_.at(0));
auto input_data_addr = kernel_launch_info.inputs_.at(0)->addr;
MS_EXCEPTION_IF_NULL(kernel_launch_info.outputs_.at(0));
auto output_data_addr = kernel_launch_info.outputs_.at(0)->addr;
HcclDataType data_type = hccl_data_type_list_[0];
auto executor = std::make_shared<device::ascend::HcclDynamicKernel>(

View File

@@ -166,12 +166,13 @@ struct Address {
size_t size;
};
using AddressPtr = std::shared_ptr<Address>;
using AddressPtrList = std::vector<AddressPtr>;
// The memory info of kernel launch.
struct KernelLaunchInfo {
std::vector<AddressPtr> inputs_;
std::vector<AddressPtr> outputs_;
std::vector<AddressPtr> workspaces_;
AddressPtrList inputs_;
AddressPtrList outputs_;
AddressPtrList workspaces_;
};
class KernelMod {
@@ -179,6 +180,10 @@ class KernelMod {
virtual const std::vector<size_t> &GetInputSizeList() const = 0;
virtual const std::vector<size_t> &GetOutputSizeList() const = 0;
virtual const std::vector<size_t> &GetWorkspaceSizeList() const = 0;
bool Launch(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
return Launch(kernel_launch_address.inputs_, kernel_launch_address.workspaces_, kernel_launch_address.outputs_,
stream_ptr);
}
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) = 0;
virtual device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { return nullptr; }

View File

@@ -133,11 +133,11 @@ std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr
}
device::DynamicKernelPtr MemCpyAsyncKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
KernelLaunchInfo kernel_launch_info;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info);
const auto &kernel_inputs = kernel_launch_info.inputs_;
const auto &kernel_outputs = kernel_launch_info.outputs_;
if (kernel_inputs.size() != 1) {
MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one, got " << kernel_inputs.size();
}

View File

@@ -112,10 +112,8 @@ std::vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &in
}
device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
KernelLaunchInfo kernel_launch_info;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info);
auto dynamic_flag = AnfAlgo::IsDynamicShape(cnode_ptr);
// Get para_size from json
@@ -124,10 +122,13 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt
// Generate args
std::vector<void *> runtime_args;
const auto &kernel_inputs = kernel_launch_info.inputs_;
(void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args),
[](const AddressPtr &input) -> void * { return input->addr; });
const auto &kernel_outputs = kernel_launch_info.outputs_;
(void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args),
[](const AddressPtr &output) -> void * { return output->addr; });
const auto &kernel_workspaces = kernel_launch_info.workspaces_;
if (!kernel_workspaces.empty()) {
(void)std::transform(std::begin(kernel_workspaces), std::end(kernel_workspaces), std::back_inserter(runtime_args),
[](const AddressPtr &addr) -> void * { return addr->addr; });

View File

@@ -108,15 +108,15 @@ void AiCoreDynamicKernel::UpdateArgs() {
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
KernelLaunchInfo kernel_launch_info;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
runtime_args_.clear();
const auto &kernel_inputs = kernel_launch_info.inputs_;
(void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args_),
[](const AddressPtr &input) { return input->addr; });
const auto &kernel_outputs = kernel_launch_info.outputs_;
(void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args_),
[](const AddressPtr &output) { return output->addr; });
// Update workspace

View File

@@ -44,15 +44,13 @@ void HcclDynamicKernel::UpdateArgs() {
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
MS_EXCEPTION_IF_NULL(kernel_mod);
// Update input, output, count
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
if (kernel_inputs.empty() || kernel_outputs.empty()) {
KernelLaunchInfo kernel_launch_info;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
if (kernel_launch_info.inputs_.empty() || kernel_launch_info.outputs_.empty()) {
MS_LOG(EXCEPTION) << "Inputs or outputs is empty. Node info: " << cnode->DebugString();
}
auto input0 = kernel_inputs.at(0);
auto output0 = kernel_outputs.at(0);
auto input0 = kernel_launch_info.inputs_.at(0);
auto output0 = kernel_launch_info.outputs_.at(0);
MS_EXCEPTION_IF_NULL(input0);
MS_EXCEPTION_IF_NULL(output0);
@@ -82,11 +80,9 @@ void HcclDynamicKernel::StaticShapeExecute() {
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
KernelLaunchInfo kernel_launch_info;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
kernel_mod->Launch(kernel_launch_info, stream_);
}
void HcclDynamicKernel::Execute() {

View File

@@ -891,9 +891,10 @@ bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *grap
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
KernelLaunchInfo kernel_launch_info;
GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
MS_EXCEPTION_IF_NULL(stream_);
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
auto ret = kernel_mod->Launch(kernel_launch_info, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;

View File

@@ -1104,16 +1104,13 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
}
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
AddressPtrList *kernel_outputs) {
KernelLaunchInfo *kernel_launch_info) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_inputs);
MS_EXCEPTION_IF_NULL(kernel_workspaces);
MS_EXCEPTION_IF_NULL(kernel_outputs);
MS_EXCEPTION_IF_NULL(kernel_launch_info);
auto cnode = kernel->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
return GenAddrCleanLaunchArgs(cnode, &(kernel_launch_info->inputs_));
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
@@ -1140,7 +1137,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
input->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(input->addr);
input->size = device_address->size_;
kernel_inputs->emplace_back(input);
kernel_launch_info->inputs_.emplace_back(input);
}
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
@@ -1150,7 +1147,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
output->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(output->addr);
output->size = device_address->size_;
kernel_outputs->emplace_back(output);
kernel_launch_info->outputs_.emplace_back(output);
}
for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
@@ -1160,7 +1157,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
workspace->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(workspace->addr);
workspace->size = device_address->size_;
kernel_workspaces->emplace_back(workspace);
kernel_launch_info->workspaces_.emplace_back(workspace);
}
}
@@ -1230,9 +1227,7 @@ void KernelRuntime::LaunchKernelEvent(const std::vector<std::vector<std::functio
}
bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream) {
const KernelLaunchInfo &kernel_launch_info, void *stream) {
MS_EXCEPTION_IF_NULL(kernel_mod);
MS_EXCEPTION_IF_NULL(stream);
float cost_time = 0;
@@ -1243,7 +1238,7 @@ bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_
start->set_record_stream(stream);
end->set_record_stream(stream);
start->RecordEvent();
bool ret = kernel_mod->Launch(inputs, workspace, outputs, stream);
bool ret = kernel_mod->Launch(kernel_launch_info, stream);
end->RecordEvent();
start->SyncEvent();
end->SyncEvent();
@@ -1283,16 +1278,13 @@ void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_
}
void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs) {
KernelLaunchInfo *kernel_launch_info) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(kernel_inputs);
MS_EXCEPTION_IF_NULL(kernel_workspaces);
MS_EXCEPTION_IF_NULL(kernel_outputs);
MS_EXCEPTION_IF_NULL(kernel_launch_info);
auto cnode = kernel->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
return GenAddrCleanLaunchArgs(cnode, kernel_inputs, mem_scheduler);
return GenAddrCleanLaunchArgs(cnode, &(kernel_launch_info->inputs_), mem_scheduler);
}
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
@@ -1307,7 +1299,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
kernel::AddressPtr input = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, input);
input->size = device_address->size_;
kernel_inputs->emplace_back(input);
kernel_launch_info->inputs_.emplace_back(input);
}
for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
@@ -1315,7 +1307,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
kernel::AddressPtr output = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, output);
output->size = device_address->size_;
kernel_outputs->emplace_back(output);
kernel_launch_info->outputs_.emplace_back(output);
}
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
@@ -1323,7 +1315,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
GetOrMallocAddress(mem_scheduler, device_address, workspace);
workspace->size = device_address->size_;
kernel_workspaces->emplace_back(workspace);
kernel_launch_info->workspaces_.emplace_back(workspace);
}
}
@@ -1400,9 +1392,7 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
MS_EXCEPTION_IF_NULL(kernel);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
KernelLaunchInfo kernel_launch_info;
auto stream = kernel_mod->GetStream();
if (stream == nullptr) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
@@ -1417,20 +1407,19 @@
if (!ret) {
return ret;
}
AssignKernelAddress(mem_scheduler, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
kernel_inputs = kernel_mod->GetInputsAddr();
kernel_outputs = kernel_mod->GetOutputsAddr();
kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();
kernel_launch_info.workspaces_ = kernel_mod->GetWorkSpacesAddr();
} else {
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
}
if (!mock) {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, stream);
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_launch_info, stream);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream);
ret = kernel_mod->Launch(kernel_launch_info, stream);
}
}
if (mem_scheduler != nullptr) {

View File

@@ -41,7 +41,8 @@ using mindspore::tensor::Tensor;
using std::vector;
using TensorPtr = std::shared_ptr<Tensor>;
using mindspore::kernel::AddressPtr;
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
using mindspore::kernel::AddressPtrList;
using mindspore::kernel::KernelLaunchInfo;
namespace mindspore {
#ifndef ENABLE_DEBUGGER
@@ -87,8 +88,7 @@ class KernelRuntime {
return mem_manager_->MallocCommunicationMemFromMemPool(size);
}
static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs);
KernelLaunchInfo *kernel_launch_info);
// for GPU and D to impl
virtual void ReleaseDeviceRes() {}
@@ -135,9 +135,7 @@
void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream);
const KernelLaunchInfo &kernel_launch_address, void *stream);
virtual void KernelLaunchProfiling(const std::string &kernel_name) {}
@@ -147,8 +145,7 @@
const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock = false);
void ResetNodeAddress(const session::KernelGraph &graph);
void AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
AddressPtrList *kernel_outputs);
KernelLaunchInfo *kernel_launch_address);
static void GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr);
void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);