!22110 opt ascend single op mode runtime code

Merge pull request !22110 from baihuawei/graph_mode_nonsink_part3-1
i-robot 2021-08-23 06:21:26 +00:00 committed by Gitee
commit 02db74ab2c
19 changed files with 268 additions and 112 deletions

View File

@@ -139,7 +139,9 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
return false;
}
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
CreateCpuKernelInfo(inputs, outputs);
if (node_name_ == kTopK) {
node_name_ = kTopKV2;
@@ -152,7 +154,7 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
if (rtCpuKernelLaunch(reinterpret_cast<const void *>(node_so_.c_str()),
reinterpret_cast<const void *>(node_name_.c_str()), 1,
reinterpret_cast<const void *>(args_.data()), static_cast<uint32_t>(args_.length()), nullptr,
stream_ptr) != RT_ERROR_NONE) {
stream_) != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Aicpu op launch failed!";
return false;

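The same lazy stream-caching pattern recurs in the AKG, HCCL, and TBE kernel mods below: the stream passed to the first Launch() is cached in stream_, and every later launch uses the cached value, so a stream assigned earlier through SetStream() is never overwritten by the caller-provided stream_ptr. A minimal self-contained sketch of the idea (hypothetical simplified class, not the actual MindSpore interface):

#include <cstdio>

// Hypothetical stand-in for the kernel mods above (not the real MindSpore class):
// the stream seen at the first Launch() is adopted once, and every later launch
// reuses the cached stream_, so a value set earlier via SetStream() wins.
class StreamCachingKernelMod {
 public:
  void SetStream(void *stream) { stream_ = stream; }
  void *GetStream() const { return stream_; }
  bool Launch(void *stream_ptr) {
    if (stream_ == nullptr) {
      stream_ = stream_ptr;  // first launch: fall back to the caller's stream
    }
    std::printf("launch on stream %p\n", stream_);  // stand-in for the real kernel launch
    return true;
  }
 private:
  void *stream_{nullptr};
};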
View File

@@ -60,7 +60,9 @@ bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
MS_LOG(ERROR) << "kernel pack should not be nullptr.";
return false;
}
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1.
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
if (func_stub == 0) {
@@ -80,7 +82,7 @@ bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
}
rtL2Ctrl_t *l2ctrl = nullptr;
auto stream = static_cast<rtStream_t *>(stream_ptr);
auto stream = static_cast<rtStream_t *>(stream_);
if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast<void *>(func_stub), block_dim, runtime_args.data(),
SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) {
MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";

View File

@@ -37,8 +37,11 @@ class AscendKernelMod : public KernelMod {
return dump_json.NeedDump(fullname_) && dump_json.async_dump_enabled() && dump_json.op_debug_mode() == 0 &&
!is_monad_;
}
void SetStream(void *stream) { stream_ = stream; }
void *GetStream() { return stream_; }
protected:
void *stream_{nullptr};
uint32_t block_dim_{1};
uint32_t stream_id_{0};
};

View File

@@ -31,8 +31,11 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllGather(inputs[0]->addr, outputs[0]->addr, hccl_count_,
hccl_data_type_list_[0], stream_ptr, group_);
hccl_data_type_list_[0], stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclAllGather faled, ret:" << hccl_result;
return false;

View File

@@ -31,8 +31,11 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(
inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_ptr, group_);
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_,
hccl_data_type_list_[0], op_type_, stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclAllReduce faled, ret:" << hccl_result;
return false;

View File

@@ -31,8 +31,11 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs, c
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclReduceScatter(
inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_ptr, group_);
inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcclReduceScatter faled, ret:" << hccl_result;
return false;

View File

@@ -30,8 +30,11 @@ bool HcomReceiveKernel::Launch(const std::vector<AddressPtr> &, const std::vecto
}
MS_EXCEPTION_IF_NULL(outputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclRecv(outputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
src_rank_, stream_ptr, group_);
src_rank_, stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomReceive failed, ret:" << hccl_result;
return false;

View File

@@ -30,8 +30,11 @@ bool HcomSendKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
}
MS_EXCEPTION_IF_NULL(inputs[0]);
MS_EXCEPTION_IF_NULL(stream_ptr);
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
auto hccl_result = hccl::HcclAdapter::GetInstance().HcclSend(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
dest_rank_, stream_ptr, group_);
dest_rank_, stream_, group_);
if (hccl_result != HCCL_SUCCESS) {
MS_LOG(ERROR) << "HcomSend faled, ret:" << hccl_result;
return false;

View File

@@ -39,7 +39,9 @@ bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inpu
MS_LOG(ERROR) << "kernel pack should not be nullptr.";
return false;
}
if (stream_ == nullptr) {
stream_ = stream_ptr;
}
uint32_t blockdim = 1; // default blockdim equal to 1.
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &blockdim);
if (func_stub == 0) {
@@ -60,7 +62,8 @@ bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inpu
rtL2Ctrl_t *l2ctrl = nullptr;
const void *stubFunc = reinterpret_cast<void *>(func_stub);
auto argsSize = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtimeargs.size());
if (RT_ERROR_NONE != rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_ptr)) {
auto ret = rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
return false;
}

View File

@@ -2073,6 +2073,24 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
return device_shape;
}
void AnfRuntimeAlgorithm::GetAllVisitedCNode(const CNodePtr &anf_node, std::vector<AnfNodePtr> *used_kernels) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(used_kernels);
auto input_size = anf_node->inputs().size() - 1;
for (size_t i = 0; i < input_size; ++i) {
auto input = AnfAlgo::GetInputNode(anf_node, i);
if (!input->isa<CNode>()) {
continue;
}
auto input_cnode = input->cast<CNodePtr>();
if (!IsRealKernelCNode(input_cnode) || opt::IsNopNode(input_cnode)) {
GetAllVisitedCNode(input_cnode, used_kernels);
} else {
used_kernels->push_back(input);
}
}
}
void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited) {
MS_EXCEPTION_IF_NULL(anf_node);

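GetAllVisitedCNode above walks every input edge of a kernel and recurses through CNodes that are not real kernels (or are nop nodes) until it reaches the nearest real producers; GenKernelEvents later uses that list to decide which cross-stream dependencies need events. A toy sketch of the same traversal on a simplified node type (illustrative only, not the AnfNode/CNode API):

#include <memory>
#include <vector>

// Toy node standing in for CNode; is_real_kernel marks nodes that actually launch work.
struct ToyNode {
  bool is_real_kernel{false};
  std::vector<std::shared_ptr<ToyNode>> inputs;
};

// For every input edge, collect the nearest ancestor that is a real kernel,
// skipping transparently through nop / virtual nodes.
void CollectRealProducers(const std::shared_ptr<ToyNode> &node,
                          std::vector<std::shared_ptr<ToyNode>> *used_kernels) {
  for (const auto &input : node->inputs) {
    if (input->is_real_kernel) {
      used_kernels->push_back(input);             // stop at the first real producer
    } else {
      CollectRealProducers(input, used_kernels);  // keep looking through this node
    }
  }
}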
View File

@@ -286,6 +286,7 @@ class AnfRuntimeAlgorithm {
// Find real input nodes.
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited);
static void GetAllVisitedCNode(const CNodePtr &cnode, std::vector<AnfNodePtr> *used_kernels);
static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph);
static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
// Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc.

View File

@@ -1136,11 +1136,11 @@ void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bo
}
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
if (is_task) {
if (is_task && is_task_sink) {
DumpSetup(kernel_graph);
}
bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
if (is_task) {
if (is_task && is_task_sink) {
Dump(kernel_graph);
}
if (!ret_ok) {

View File

@@ -385,6 +385,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
if (!is_task_sink) {
MS_LOG(INFO) << "Graph mode with not task sink";
GenKernelEvents(graph);
return true;
}
@@ -657,12 +658,138 @@ bool AscendKernelRuntime::Run(session::KernelGraph *const graph, bool is_task_si
MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
#endif
} else {
ret = LaunchKernel(graph);
ret = LaunchKernels(graph);
}
return ret;
}
bool AscendKernelRuntime::LaunchKernel(const AnfNodePtr &kernel) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs = kernel_mod->GetInputsAddr();
AddressPtrList kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
AddressPtrList kernel_outputs = kernel_mod->GetOutputsAddr();
bool ret;
if (pynative_mode_profiling_flag_) {
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
auto stream = ascend_kernel_mod->GetStream();
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs, kernel_workspaces,
kernel_outputs, stream);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
}
return ret;
}
void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels) {
for (size_t i = 0; i < kernels.size(); ++i) {
auto &node = kernels[i];
auto kernel_mod = AnfAlgo::GetKernelMod(node);
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
if (AnfAlgo::IsCommunicationOp(node)) {
auto group = AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
auto iter = group_stream_id_map_.find(group);
if (iter == group_stream_id_map_.end()) {
void *stream = nullptr;
auto ret = rtStreamCreate(&stream, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
auto id = stream_id_map_.size();
group_stream_id_map_[group] = id;
stream_id_map_[id] = stream;
AnfAlgo::SetStreamId(id, node.get());
ascend_kernel_mod->SetStream(stream);
} else {
auto id = iter->second;
AnfAlgo::SetStreamId(id, node.get());
ascend_kernel_mod->SetStream(stream_id_map_[id]);
}
} else if (AnfAlgo::IsIndependentNode(node)) {
AnfAlgo::SetStreamId(1, node.get());
ascend_kernel_mod->SetStream(independent_stream_);
} else {
AnfAlgo::SetStreamId(0, node.get());
ascend_kernel_mod->SetStream(stream_);
}
}
for (size_t i = 1; i < kernels.size(); ++i) {
if (AnfAlgo::GetCNodeName(kernels[i - 1]) == kAtomicAddrCleanOpName) {
auto stream_id = AnfAlgo::GetStreamId(kernels[i]);
AnfAlgo::SetStreamId(stream_id, kernels[i - 1].get());
auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i - 1]);
auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
ascend_kernel_mod->SetStream(stream_id_map_[stream_id]);
}
}
}
void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto &kernels = graph->execution_order();
if (kernels.empty() || graph_kernel_events_map_.find(graph->graph_id()) != graph_kernel_events_map_.end()) {
return;
}
SetKernelModStream(kernels);
auto kernel_events =
std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
auto &kernel_pre_run_events = kernel_events.first;
auto &kernel_post_run_events = kernel_events.second;
kernel_pre_run_events.resize(kernels.size());
kernel_post_run_events.resize(kernels.size());
auto kernel_size = kernels.size() - 1;
for (auto &iter : stream_id_map_) {
auto stream = iter.second;
if (stream != stream_) {
auto pre_event = CreateDeviceEvent();
pre_event->set_wait_stream(stream);
pre_event->set_record_stream(stream_);
kernel_pre_run_events[0].emplace_back([pre_event]() { pre_event->RecordEvent(); });
kernel_pre_run_events[0].emplace_back([pre_event]() { pre_event->WaitEvent(); });
auto post_event = CreateDeviceEvent();
post_event->set_wait_stream(stream_);
post_event->set_record_stream(stream);
kernel_post_run_events[kernel_size].emplace_back([post_event]() { post_event->RecordEvent(); });
kernel_post_run_events[kernel_size].emplace_back([post_event]() { post_event->WaitEvent(); });
}
}
for (size_t i = 0; i < kernels.size(); ++i) {
auto &kernel = kernels[i];
auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
if (stream_id_map_.find(curr_stream_id) == stream_id_map_.end()) {
MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created";
}
auto wait_stream = stream_id_map_[curr_stream_id];
auto stream_num = stream_id_map_.size();
std::vector<bool> stream_hit(stream_num, false);
std::vector<AnfNodePtr> visited_kernels;
AnfAlgo::GetAllVisitedCNode(kernel, &visited_kernels);
for (int k = SizeToInt(i) - 1; k >= 0; --k) {
auto pre_cnode = kernels[k];
auto pre_cnode_stream_id = AnfAlgo::GetStreamId(pre_cnode);
if (pre_cnode_stream_id == curr_stream_id) {
continue;
}
for (auto &visited : visited_kernels) {
if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id]) {
stream_hit[pre_cnode_stream_id] = true;
auto record_stream = stream_id_map_[pre_cnode_stream_id];
auto event = CreateDeviceEvent();
event->set_wait_stream(wait_stream);
event->set_record_stream(record_stream);
kernel_post_run_events[k].emplace_back([event]() { event->RecordEvent(); });
kernel_pre_run_events[i].emplace_back([event]() { event->WaitEvent(); });
}
}
}
}
graph_kernel_events_map_[graph->graph_id()] = std::move(kernel_events);
}
bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "RunExecutorAsync start. GraphId:" << graph->graph_id();
@@ -747,18 +874,11 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
bool AscendKernelRuntime::SyncStream() {
SetCurrentContext();
if (stream_ == nullptr) {
MS_LOG(ERROR) << "SyncStream failed. stream_ is nullptr";
return false;
}
if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { // o for switch stream
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}
if (RT_ERROR_NONE != rtStreamSynchronize(communication_stream_)) { // o for switch stream
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
for (auto &iter : stream_id_map_) {
if (rtStreamSynchronize(iter.second) != RT_ERROR_NONE) { // o for switch stream
MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
return false;
}
}
return true;
}
@@ -827,29 +947,30 @@ bool AscendKernelRuntime::InitDevice() {
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
}
ret = rtStreamCreateWithFlags(&independent_stream_, 0, RT_STREAM_HUGE);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
}
ret = rtStreamCreate(&communication_stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
stream_id_map_[0] = stream_;
stream_id_map_[1] = independent_stream_;
stream_id_map_[2] = communication_stream_;
group_stream_id_map_[kHcclWorldGroup] = 2;
return true;
}
bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
SetCurrentContext();
int32_t ret;
if (stream_ != nullptr) {
ret = rtStreamDestroy(stream_);
for (auto &iter : stream_id_map_) {
ret = rtStreamDestroy(iter.second);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
}
stream_ = nullptr;
}
if (communication_stream_ != nullptr) {
ret = rtStreamDestroy(communication_stream_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
}
communication_stream_ = nullptr;
iter.second = nullptr;
}
ret = rtDeviceReset(device_id);
if (ret != RT_ERROR_NONE) {

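The runtime now keeps its streams (default, independent, communication, plus per-group communication streams) in stream_id_map_, and GenKernelEvents above serializes work across them with record/wait event pairs: when a kernel on one stream consumes the output of a kernel on another, an event is recorded on the producer's stream right after the producer and waited on by the consumer's stream right before the consumer. A minimal self-contained sketch of that pairing, with a stub event type standing in for the DeviceEvent used in the diff:

#include <functional>
#include <memory>
#include <vector>

// Stub standing in for DeviceEvent: RecordEvent() would be enqueued on the
// producer's stream, WaitEvent() on the consumer's stream.
struct StubEvent {
  void *record_stream{nullptr};
  void *wait_stream{nullptr};
  void RecordEvent() { /* enqueue record on record_stream */ }
  void WaitEvent() { /* enqueue wait on wait_stream */ }
};

// Mirror of how kernel_post_run_events / kernel_pre_run_events are filled:
// one event per producer->consumer dependency that crosses streams.
void AddCrossStreamDependency(void *producer_stream, void *consumer_stream,
                              std::vector<std::function<void()>> *producer_post_events,
                              std::vector<std::function<void()>> *consumer_pre_events) {
  auto event = std::make_shared<StubEvent>();
  event->record_stream = producer_stream;
  event->wait_stream = consumer_stream;
  producer_post_events->emplace_back([event]() { event->RecordEvent(); });
  consumer_pre_events->emplace_back([event]() { event->WaitEvent(); });
}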
View File

@@ -43,6 +43,9 @@ class AscendKernelRuntime : public KernelRuntime {
uint32_t GetRankSize() override;
bool LoadData(session::KernelGraph *graph) override;
bool GenTask(const session::KernelGraph *graph);
void GenKernelEvents(const session::KernelGraph *graph) override;
void SetKernelModStream(const std::vector<CNodePtr> &kernels);
bool LaunchKernel(const AnfNodePtr &kernel) override;
bool GenDynamicKernel(const session::KernelGraph *graph) override;
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override;
bool LoadTask(const session::KernelGraph *graph);
@@ -104,6 +107,8 @@ class AscendKernelRuntime : public KernelRuntime {
std::map<std::pair<uint32_t, uint32_t>, std::string> stream_id_task_id_op_name_map_;
static std::map<std::string, uint32_t> overflow_tasks_;
static std::vector<rtExceptionInfo> task_fail_infoes_;
std::map<uint32_t, void *> stream_id_map_;
std::map<std::string, uint32_t> group_stream_id_map_;
};
MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);

View File

@@ -472,7 +472,7 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) {
// run dynamic shape graph in pynative
ret = RunOpLaunchKernelDynamic(graph);
} else {
ret = LaunchKernel(graph);
ret = LaunchKernels(graph);
}
}
(void)gettimeofday(&end_time, nullptr);

View File

@@ -1008,6 +1008,32 @@ void KernelRuntime::DebugStreamSync(const CNodePtr &kernel) {
}
}
bool KernelRuntime::LaunchKernel(const AnfNodePtr &kernel) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
bool ret;
if (AnfAlgo::IsCommunicationOp(kernel)) {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, communication_stream_);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_);
}
} else {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, stream_);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
}
}
return ret;
}
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
const auto &kernels = graph.execution_order();
std::vector<DynamicKernelPtr> dynamic_kernel_list;
@@ -1042,8 +1068,6 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
} else {
auto &kernel = kernels[i];
MS_EXCEPTION_IF_NULL(kernel);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
// Skip transpose kernel with "nop_op" attr which is not hidden or removed in PyNative infer scenario. Transpose
// kernel, which is not supposed to be executed, is generated in TransDataSplit to support specific Transdata.
@@ -1056,34 +1080,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
}
continue;
}
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice) {
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
} else {
kernel_inputs = kernel_mod->GetInputsAddr();
kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
kernel_outputs = kernel_mod->GetOutputsAddr();
}
bool ret;
if (AnfAlgo::IsCommunicationOp(kernel)) {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, communication_stream_);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_);
}
} else {
if (pynative_mode_profiling_flag_) {
ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
kernel_workspaces, kernel_outputs, stream_);
} else {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
}
}
auto ret = LaunchKernel(kernel);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";
return false;
@@ -1096,7 +1093,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
return true;
}
bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
bool KernelRuntime::LaunchKernels(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
if (!LaunchKernelMod(*graph)) {
MS_LOG(ERROR) << "LaunchKernelMod failed!";

View File

@@ -64,7 +64,8 @@ class KernelRuntime {
virtual bool Run(session::KernelGraph *graph, bool is_task_sink) = 0;
virtual bool GenDynamicKernel(const session::KernelGraph *graph) = 0;
virtual bool RunDynamicKernelAsync(const session::KernelGraph *graph) = 0;
bool LaunchKernel(const session::KernelGraph *graph);
bool LaunchKernels(const session::KernelGraph *graph);
virtual bool LaunchKernel(const AnfNodePtr &kernel);
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id);
@@ -98,7 +99,7 @@ class KernelRuntime {
virtual void PreInit() {}
virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
void GenKernelEvents(const session::KernelGraph *graph);
virtual void GenKernelEvents(const session::KernelGraph *graph);
virtual std::shared_ptr<DeviceEvent> CreateDeviceEvent() { return nullptr; }
virtual std::shared_ptr<DeviceEvent> CreateDeviceTimeEvent() { return nullptr; }
virtual DeviceAddressType GetTargetDeviceAddressType() const = 0;
@@ -123,6 +124,10 @@ class KernelRuntime {
void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node);
void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream);
virtual void KernelLaunchProfiling(const std::string &kernel_name) {}
@@ -138,10 +143,6 @@ class KernelRuntime {
void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream);
#if (ENABLE_CPU && !_WIN32)
void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index,
size_t *const first_cache_size);
@@ -156,6 +157,7 @@ class KernelRuntime {
std::shared_ptr<Debugger> debugger_;
#endif
void *stream_{nullptr};
void *independent_stream_{nullptr};
void *communication_stream_{nullptr};
std::shared_ptr<MemoryManager> mem_manager_{nullptr};
std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_;

View File

@@ -28,8 +28,6 @@
#include "utils/ms_utils.h"
#include "utils/ms_context.h"
#include "runtime/hccl_adapter/converter.h"
#include "runtime/device/ascend/distribute/ascend_collective.h"
using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
@@ -88,6 +86,10 @@ void HcclAdapter::InitPlugin() {
single_op_hccl_get_rank_size_ = DlsymFuncObj(HcclGetRankSize, plugin_handle_);
launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast, plugin_handle_);
launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_);
launch_hccl_reduce_scatter_ = DlsymFuncObj(HcclReduceScatter, plugin_handle_);
launch_hccl_all_gather_ = DlsymFuncObj(HcclAllGather, plugin_handle_);
launch_hccl_send_ = DlsymFuncObj(HcclSend, plugin_handle_);
launch_hccl_recv_ = DlsymFuncObj(HcclRecv, plugin_handle_);
hccl_create_group_ = DlsymFuncObj(HcomCreateGroup, plugin_handle_);
hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup, plugin_handle_);
hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId, plugin_handle_);
@@ -269,65 +271,40 @@ HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType da
HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
HcclReduceOp op, aclrtStream stream, const std::string &group) const {
MS_EXCEPTION_IF_NULL(launch_hccl_all_reduce_);
HcclComm hccl_comm;
if (hccl_comm_ != nullptr) {
hccl_comm = hccl_comm_;
} else {
hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
}
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
return launch_hccl_all_reduce_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream);
}
HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
HcclReduceOp op, aclrtStream stream, const std::string &group) const {
MS_EXCEPTION_IF_NULL(launch_hccl_reduce_scatter_);
HcclComm hccl_comm;
if (hccl_comm_ != nullptr) {
hccl_comm = hccl_comm_;
} else {
hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
}
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
return launch_hccl_reduce_scatter_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream);
}
HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
aclrtStream stream, const std::string &group) const {
MS_EXCEPTION_IF_NULL(launch_hccl_all_gather_);
HcclComm hccl_comm;
if (hccl_comm_ != nullptr) {
hccl_comm = hccl_comm_;
} else {
hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
}
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
return launch_hccl_all_gather_(send_buf, recv_buf, count, dataType, hccl_comm, stream);
}
HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank,
aclrtStream stream, const std::string &group) const {
MS_EXCEPTION_IF_NULL(launch_hccl_send_);
HcclComm hccl_comm;
if (hccl_comm_ != nullptr) {
hccl_comm = hccl_comm_;
} else {
hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
}
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
return launch_hccl_send_(send_buf, count, dataType, destRank, hccl_comm, stream);
}
HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank,
aclrtStream stream, const std::string &group) const {
MS_EXCEPTION_IF_NULL(launch_hccl_recv_);
HcclComm hccl_comm;
if (hccl_comm_ != nullptr) {
hccl_comm = hccl_comm_;
} else {
hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
}
auto hccl_comm = GetHcomm(group);
MS_EXCEPTION_IF_NULL(hccl_comm);
return launch_hccl_recv_(recv_buf, count, dataType, srcRank, hccl_comm, stream);
}

View File

@@ -24,6 +24,8 @@
#include "mindspore/core/ir/anf.h"
#include "hccl/hccl_types.h"
#include "runtime/hccl_adapter/plugin/hccl_plugin.h"
#include "runtime/device/ascend/distribute/ascend_collective.h"
using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;
namespace ge {
class OpsKernelInfoStore;
@@ -83,6 +85,14 @@ class HcclAdapter {
void InitPlugin();
void FinalizePlugin();
HcclComm GetHcomm(const std::string &group) const {
if (hccl_comm_ != nullptr) {
return hccl_comm_;
} else {
return HcclCollectiveGroup::instance().GetGroupComm(group);
}
}
bool InitKernelInfoStore(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
bool FinalizeKernelInfoStore();