forked from mindspore-Ecosystem/mindspore
Use one struct instead of three vectors for kernel input, output and workspace
This commit is contained in:
parent
89eb412644
commit
3035f9d367
|
@ -204,12 +204,10 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
|
|||
}
|
||||
|
||||
// Builds the dynamic kernel object for this AICPU op.
//
// Launch addresses for `cnode_ptr` are collected into a single KernelLaunchInfo
// by KernelRuntime::GenLaunchArgs; its input and output address lists are then
// handed to CreateCpuKernelInfo before the AicpuDynamicKernel is constructed.
// The workspace list is filled by GenLaunchArgs but is not consumed here.
//
// @param cnode_ptr  Node the kernel is generated for.
// @param stream_ptr Stream passed through to the AicpuDynamicKernel constructor.
// @return Shared pointer to the newly created AicpuDynamicKernel.
device::DynamicKernelPtr AicpuOpKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
  KernelLaunchInfo kernel_launch_info;
  device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info);

  CreateCpuKernelInfo(kernel_launch_info.inputs_, kernel_launch_info.outputs_);
  return std::make_shared<AicpuDynamicKernel>(stream_ptr, cnode_ptr, args_, ext_info_, node_so_, node_name_);
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -273,23 +273,21 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
|
|||
}
|
||||
|
||||
device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
|
||||
AddressPtrList inputs;
|
||||
AddressPtrList workspaces;
|
||||
AddressPtrList outputs;
|
||||
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs);
|
||||
KernelLaunchInfo kernel_launch_info;
|
||||
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info);
|
||||
|
||||
std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_.lock()));
|
||||
|
||||
if (inputs.empty()) {
|
||||
if (kernel_launch_info.inputs_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Hccl kernel input is empty";
|
||||
}
|
||||
if (hccl_data_type_list_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Hccl data type list is empty";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(inputs.at(0));
|
||||
auto input_data_addr = inputs.at(0)->addr;
|
||||
MS_EXCEPTION_IF_NULL(outputs.at(0));
|
||||
auto output_data_addr = outputs.at(0)->addr;
|
||||
MS_EXCEPTION_IF_NULL(kernel_launch_info.inputs_.at(0));
|
||||
auto input_data_addr = kernel_launch_info.inputs_.at(0)->addr;
|
||||
MS_EXCEPTION_IF_NULL(kernel_launch_info.outputs_.at(0));
|
||||
auto output_data_addr = kernel_launch_info.outputs_.at(0)->addr;
|
||||
HcclDataType data_type = hccl_data_type_list_[0];
|
||||
|
||||
auto executor = std::make_shared<device::ascend::HcclDynamicKernel>(
|
||||
|
|
|
@ -166,12 +166,13 @@ struct Address {
|
|||
size_t size;
|
||||
};
|
||||
using AddressPtr = std::shared_ptr<Address>;
|
||||
using AddressPtrList = std::vector<AddressPtr>;
|
||||
|
||||
// The memory info of kernel launch.
|
||||
struct KernelLaunchInfo {
|
||||
std::vector<AddressPtr> inputs_;
|
||||
std::vector<AddressPtr> outputs_;
|
||||
std::vector<AddressPtr> workspaces_;
|
||||
AddressPtrList inputs_;
|
||||
AddressPtrList outputs_;
|
||||
AddressPtrList workspaces_;
|
||||
};
|
||||
|
||||
class KernelMod {
|
||||
|
@ -179,6 +180,10 @@ class KernelMod {
|
|||
virtual const std::vector<size_t> &GetInputSizeList() const = 0;
|
||||
virtual const std::vector<size_t> &GetOutputSizeList() const = 0;
|
||||
virtual const std::vector<size_t> &GetWorkspaceSizeList() const = 0;
|
||||
bool Launch(const KernelLaunchInfo &kernel_launch_address, void *stream_ptr) {
|
||||
return Launch(kernel_launch_address.inputs_, kernel_launch_address.workspaces_, kernel_launch_address.outputs_,
|
||||
stream_ptr);
|
||||
}
|
||||
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) = 0;
|
||||
virtual device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { return nullptr; }
|
||||
|
|
|
@ -133,11 +133,11 @@ std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr
|
|||
}
|
||||
|
||||
device::DynamicKernelPtr MemCpyAsyncKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
KernelLaunchInfo kernel_launch_info;
|
||||
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info);
|
||||
|
||||
const auto &kernel_inputs = kernel_launch_info.inputs_;
|
||||
const auto &kernel_outputs = kernel_launch_info.outputs_;
|
||||
if (kernel_inputs.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one, got " << kernel_inputs.size();
|
||||
}
|
||||
|
|
|
@ -112,10 +112,8 @@ std::vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &in
|
|||
}
|
||||
|
||||
device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
KernelLaunchInfo kernel_launch_info;
|
||||
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_launch_info);
|
||||
auto dynamic_flag = AnfAlgo::IsDynamicShape(cnode_ptr);
|
||||
|
||||
// Get para_size from json
|
||||
|
@ -124,10 +122,13 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt
|
|||
|
||||
// Generate args
|
||||
std::vector<void *> runtime_args;
|
||||
const auto &kernel_inputs = kernel_launch_info.inputs_;
|
||||
(void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args),
|
||||
[](const AddressPtr &input) -> void * { return input->addr; });
|
||||
const auto &kernel_outputs = kernel_launch_info.outputs_;
|
||||
(void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args),
|
||||
[](const AddressPtr &output) -> void * { return output->addr; });
|
||||
const auto &kernel_workspaces = kernel_launch_info.workspaces_;
|
||||
if (!kernel_workspaces.empty()) {
|
||||
(void)std::transform(std::begin(kernel_workspaces), std::end(kernel_workspaces), std::back_inserter(runtime_args),
|
||||
[](const AddressPtr &addr) -> void * { return addr->addr; });
|
||||
|
|
|
@ -108,15 +108,15 @@ void AiCoreDynamicKernel::UpdateArgs() {
|
|||
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
KernelLaunchInfo kernel_launch_info;
|
||||
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
|
||||
|
||||
runtime_args_.clear();
|
||||
|
||||
const auto &kernel_inputs = kernel_launch_info.inputs_;
|
||||
(void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args_),
|
||||
[](const AddressPtr &input) { return input->addr; });
|
||||
const auto &kernel_outputs = kernel_launch_info.outputs_;
|
||||
(void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args_),
|
||||
[](const AddressPtr &output) { return output->addr; });
|
||||
// Update workspace
|
||||
|
|
|
@ -44,15 +44,13 @@ void HcclDynamicKernel::UpdateArgs() {
|
|||
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
// Update input, output, count
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
if (kernel_inputs.empty() || kernel_outputs.empty()) {
|
||||
KernelLaunchInfo kernel_launch_info;
|
||||
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
|
||||
if (kernel_launch_info.inputs_.empty() || kernel_launch_info.outputs_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Inputs or outputs is empty. Node info: " << cnode->DebugString();
|
||||
}
|
||||
auto input0 = kernel_inputs.at(0);
|
||||
auto output0 = kernel_outputs.at(0);
|
||||
auto input0 = kernel_launch_info.inputs_.at(0);
|
||||
auto output0 = kernel_launch_info.outputs_.at(0);
|
||||
MS_EXCEPTION_IF_NULL(input0);
|
||||
MS_EXCEPTION_IF_NULL(output0);
|
||||
|
||||
|
@ -82,11 +80,9 @@ void HcclDynamicKernel::StaticShapeExecute() {
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||
KernelLaunchInfo kernel_launch_info;
|
||||
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_launch_info);
|
||||
kernel_mod->Launch(kernel_launch_info, stream_);
|
||||
}
|
||||
|
||||
void HcclDynamicKernel::Execute() {
|
||||
|
|
|
@ -891,9 +891,10 @@ bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *grap
|
|||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
KernelLaunchInfo kernel_launch_info;
|
||||
GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
|
||||
MS_EXCEPTION_IF_NULL(stream_);
|
||||
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||
auto ret = kernel_mod->Launch(kernel_launch_info, stream_);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Launch kernel failed.";
|
||||
return false;
|
||||
|
|
|
@ -1104,16 +1104,13 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
|
|||
}
|
||||
|
||||
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
|
||||
AddressPtrList *kernel_outputs) {
|
||||
KernelLaunchInfo *kernel_launch_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_inputs);
|
||||
MS_EXCEPTION_IF_NULL(kernel_workspaces);
|
||||
MS_EXCEPTION_IF_NULL(kernel_outputs);
|
||||
MS_EXCEPTION_IF_NULL(kernel_launch_info);
|
||||
auto cnode = kernel->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
|
||||
return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
|
||||
return GenAddrCleanLaunchArgs(cnode, &(kernel_launch_info->inputs_));
|
||||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
@ -1140,7 +1137,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
input->addr = device_address->ptr_;
|
||||
MS_EXCEPTION_IF_NULL(input->addr);
|
||||
input->size = device_address->size_;
|
||||
kernel_inputs->emplace_back(input);
|
||||
kernel_launch_info->inputs_.emplace_back(input);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
|
||||
|
@ -1150,7 +1147,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
output->addr = device_address->ptr_;
|
||||
MS_EXCEPTION_IF_NULL(output->addr);
|
||||
output->size = device_address->size_;
|
||||
kernel_outputs->emplace_back(output);
|
||||
kernel_launch_info->outputs_.emplace_back(output);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
|
||||
|
@ -1160,7 +1157,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
workspace->addr = device_address->ptr_;
|
||||
MS_EXCEPTION_IF_NULL(workspace->addr);
|
||||
workspace->size = device_address->size_;
|
||||
kernel_workspaces->emplace_back(workspace);
|
||||
kernel_launch_info->workspaces_.emplace_back(workspace);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1230,9 +1227,7 @@ void KernelRuntime::LaunchKernelEvent(const std::vector<std::vector<std::functio
|
|||
}
|
||||
|
||||
bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
|
||||
const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream) {
|
||||
const KernelLaunchInfo &kernel_launch_info, void *stream) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
MS_EXCEPTION_IF_NULL(stream);
|
||||
float cost_time = 0;
|
||||
|
@ -1243,7 +1238,7 @@ bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_
|
|||
start->set_record_stream(stream);
|
||||
end->set_record_stream(stream);
|
||||
start->RecordEvent();
|
||||
bool ret = kernel_mod->Launch(inputs, workspace, outputs, stream);
|
||||
bool ret = kernel_mod->Launch(kernel_launch_info, stream);
|
||||
end->RecordEvent();
|
||||
start->SyncEvent();
|
||||
end->SyncEvent();
|
||||
|
@ -1283,16 +1278,13 @@ void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_
|
|||
}
|
||||
|
||||
void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
|
||||
AddressPtrList *kernel_outputs) {
|
||||
KernelLaunchInfo *kernel_launch_info) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_inputs);
|
||||
MS_EXCEPTION_IF_NULL(kernel_workspaces);
|
||||
MS_EXCEPTION_IF_NULL(kernel_outputs);
|
||||
MS_EXCEPTION_IF_NULL(kernel_launch_info);
|
||||
auto cnode = kernel->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
|
||||
return GenAddrCleanLaunchArgs(cnode, kernel_inputs, mem_scheduler);
|
||||
return GenAddrCleanLaunchArgs(cnode, &(kernel_launch_info->inputs_), mem_scheduler);
|
||||
}
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
|
@ -1307,7 +1299,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
|
|||
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
||||
GetOrMallocAddress(mem_scheduler, device_address, input);
|
||||
input->size = device_address->size_;
|
||||
kernel_inputs->emplace_back(input);
|
||||
kernel_launch_info->inputs_.emplace_back(input);
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
|
||||
|
@ -1315,7 +1307,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
|
|||
kernel::AddressPtr output = std::make_shared<kernel::Address>();
|
||||
GetOrMallocAddress(mem_scheduler, device_address, output);
|
||||
output->size = device_address->size_;
|
||||
kernel_outputs->emplace_back(output);
|
||||
kernel_launch_info->outputs_.emplace_back(output);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
|
||||
|
@ -1323,7 +1315,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
|
|||
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
||||
GetOrMallocAddress(mem_scheduler, device_address, workspace);
|
||||
workspace->size = device_address->size_;
|
||||
kernel_workspaces->emplace_back(workspace);
|
||||
kernel_launch_info->workspaces_.emplace_back(workspace);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1400,9 +1392,7 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
|
|||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
KernelLaunchInfo kernel_launch_info;
|
||||
auto stream = kernel_mod->GetStream();
|
||||
if (stream == nullptr) {
|
||||
if (AnfAlgo::IsCommunicationOp(kernel)) {
|
||||
|
@ -1417,20 +1407,19 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
|
|||
if (!ret) {
|
||||
return ret;
|
||||
}
|
||||
AssignKernelAddress(mem_scheduler, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
|
||||
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
|
||||
kernel_inputs = kernel_mod->GetInputsAddr();
|
||||
kernel_outputs = kernel_mod->GetOutputsAddr();
|
||||
kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
|
||||
kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
|
||||
kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();
|
||||
kernel_launch_info.workspaces_ = kernel_mod->GetWorkSpacesAddr();
|
||||
} else {
|
||||
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
|
||||
}
|
||||
if (!mock) {
|
||||
if (pynative_mode_profiling_flag_) {
|
||||
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
|
||||
kernel_workspaces, kernel_outputs, stream);
|
||||
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_launch_info, stream);
|
||||
} else {
|
||||
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream);
|
||||
ret = kernel_mod->Launch(kernel_launch_info, stream);
|
||||
}
|
||||
}
|
||||
if (mem_scheduler != nullptr) {
|
||||
|
|
|
@ -41,7 +41,8 @@ using mindspore::tensor::Tensor;
|
|||
using std::vector;
|
||||
using TensorPtr = std::shared_ptr<Tensor>;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
|
||||
using mindspore::kernel::AddressPtrList;
|
||||
using mindspore::kernel::KernelLaunchInfo;
|
||||
|
||||
namespace mindspore {
|
||||
#ifndef ENABLE_DEBUGGER
|
||||
|
@ -87,8 +88,7 @@ class KernelRuntime {
|
|||
return mem_manager_->MallocCommunicationMemFromMemPool(size);
|
||||
}
|
||||
static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
|
||||
AddressPtrList *kernel_outputs);
|
||||
KernelLaunchInfo *kernel_launch_info);
|
||||
|
||||
// for GPU and D to impl
|
||||
virtual void ReleaseDeviceRes() {}
|
||||
|
@ -135,9 +135,7 @@ class KernelRuntime {
|
|||
void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
|
||||
void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
|
||||
bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
|
||||
const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream);
|
||||
const KernelLaunchInfo &kernel_launch_address, void *stream);
|
||||
|
||||
virtual void KernelLaunchProfiling(const std::string &kernel_name) {}
|
||||
|
||||
|
@ -147,8 +145,7 @@ class KernelRuntime {
|
|||
const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock = false);
|
||||
void ResetNodeAddress(const session::KernelGraph &graph);
|
||||
void AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
|
||||
AddressPtrList *kernel_outputs);
|
||||
KernelLaunchInfo *kernel_launch_address);
|
||||
static void GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
|
||||
const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr);
|
||||
void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);
|
||||
|
|
Loading…
Reference in New Issue