!22110 opt ascend single op mode runtime code
Merge pull request !22110 from baihuawei/graph_mode_nonsink_part3-1
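Summary of the change, as reflected in the diff below: in Ascend graph mode without task sink, kernels are no longer all launched on the runtime's default stream. Each Ascend kernel mod caches the stream it should run on (SetStream/GetStream plus a stream_ member), communication ops are routed to per-group HCCL streams, independent ops to a dedicated stream, and AscendKernelRuntime::GenKernelEvents inserts record/wait device events so kernels on different streams still execute in dependency order. KernelRuntime::LaunchKernel(graph) is renamed to LaunchKernels, and a per-node virtual LaunchKernel is introduced.

A minimal sketch of the stream-caching pattern that AicpuOpKernelMod, AkgKernelMod, TbeKernelMod and the Hcom* kernels adopt in this diff (the class below is a standalone illustration, not the real AscendKernelMod):

// Hedged sketch: mirrors the stream_-caching pattern added in this commit.
class StreamCachingKernelModSketch {
 public:
  void SetStream(void *stream) { stream_ = stream; }  // runtime pre-assigns a stream per kernel
  void *GetStream() { return stream_; }

  bool Launch(void *stream_ptr) {
    if (stream_ == nullptr) {
      stream_ = stream_ptr;  // fall back to the stream passed at launch time
    }
    // ... enqueue the kernel on stream_ rather than on stream_ptr ...
    return true;
  }

 protected:
  void *stream_{nullptr};
};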
Commit: 02db74ab2c
@@ -139,7 +139,9 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
     MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
     return false;
   }
-
+  if (stream_ == nullptr) {
+    stream_ = stream_ptr;
+  }
   CreateCpuKernelInfo(inputs, outputs);
   if (node_name_ == kTopK) {
     node_name_ = kTopKV2;
@@ -152,7 +154,7 @@ bool AicpuOpKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::
   if (rtCpuKernelLaunch(reinterpret_cast<const void *>(node_so_.c_str()),
                         reinterpret_cast<const void *>(node_name_.c_str()), 1,
                         reinterpret_cast<const void *>(args_.data()), static_cast<uint32_t>(args_.length()), nullptr,
-                        stream_ptr) != RT_ERROR_NONE) {
+                        stream_) != RT_ERROR_NONE) {
     MS_LOG(ERROR) << "Aicpu op launch failed!";

     return false;
@@ -60,7 +60,9 @@ bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
     MS_LOG(ERROR) << "kernel pack should not be nullptr.";
     return false;
   }
-
+  if (stream_ == nullptr) {
+    stream_ = stream_ptr;
+  }
   uint32_t block_dim = DEFAULT_BLOCK_DIM;  // default blockdim equal to 1.
   auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
   if (func_stub == 0) {
@@ -80,7 +82,7 @@ bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
   }

   rtL2Ctrl_t *l2ctrl = nullptr;
-  auto stream = static_cast<rtStream_t *>(stream_ptr);
+  auto stream = static_cast<rtStream_t *>(stream_);
   if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast<void *>(func_stub), block_dim, runtime_args.data(),
                                       SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) {
     MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
@@ -37,8 +37,11 @@ class AscendKernelMod : public KernelMod {
     return dump_json.NeedDump(fullname_) && dump_json.async_dump_enabled() && dump_json.op_debug_mode() == 0 &&
            !is_monad_;
   }
+  void SetStream(void *stream) { stream_ = stream; }
+  void *GetStream() { return stream_; }

  protected:
+  void *stream_{nullptr};
   uint32_t block_dim_{1};
   uint32_t stream_id_{0};
 };
@@ -31,8 +31,11 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
   MS_EXCEPTION_IF_NULL(inputs[0]);
   MS_EXCEPTION_IF_NULL(outputs[0]);
   MS_EXCEPTION_IF_NULL(stream_ptr);
+  if (stream_ == nullptr) {
+    stream_ = stream_ptr;
+  }
   auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllGather(inputs[0]->addr, outputs[0]->addr, hccl_count_,
-                                                                    hccl_data_type_list_[0], stream_ptr, group_);
+                                                                    hccl_data_type_list_[0], stream_, group_);
   if (hccl_result != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "HcclAllGather failed, ret:" << hccl_result;
     return false;
@@ -31,8 +31,11 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
   MS_EXCEPTION_IF_NULL(inputs[0]);
   MS_EXCEPTION_IF_NULL(outputs[0]);
   MS_EXCEPTION_IF_NULL(stream_ptr);
-  auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(
-    inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_ptr, group_);
+  if (stream_ == nullptr) {
+    stream_ = stream_ptr;
+  }
+  auto hccl_result = hccl::HcclAdapter::GetInstance().HcclAllReduce(inputs[0]->addr, outputs[0]->addr, hccl_count_,
+                                                                    hccl_data_type_list_[0], op_type_, stream_, group_);
   if (hccl_result != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "HcclAllReduce failed, ret:" << hccl_result;
     return false;
@@ -31,8 +31,11 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs, c
   MS_EXCEPTION_IF_NULL(inputs[0]);
   MS_EXCEPTION_IF_NULL(outputs[0]);
   MS_EXCEPTION_IF_NULL(stream_ptr);
+  if (stream_ == nullptr) {
+    stream_ = stream_ptr;
+  }
   auto hccl_result = hccl::HcclAdapter::GetInstance().HcclReduceScatter(
-    inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_ptr, group_);
+    inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], op_type_, stream_, group_);
   if (hccl_result != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "HcclReduceScatter failed, ret:" << hccl_result;
     return false;
@@ -30,8 +30,11 @@ bool HcomReceiveKernel::Launch(const std::vector<AddressPtr> &, const std::vecto
   }
   MS_EXCEPTION_IF_NULL(outputs[0]);
   MS_EXCEPTION_IF_NULL(stream_ptr);
+  if (stream_ == nullptr) {
+    stream_ = stream_ptr;
+  }
   auto hccl_result = hccl::HcclAdapter::GetInstance().HcclRecv(outputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
-                                                               src_rank_, stream_ptr, group_);
+                                                               src_rank_, stream_, group_);
   if (hccl_result != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "HcomReceive failed, ret:" << hccl_result;
     return false;
@@ -30,8 +30,11 @@ bool HcomSendKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
   }
   MS_EXCEPTION_IF_NULL(inputs[0]);
   MS_EXCEPTION_IF_NULL(stream_ptr);
+  if (stream_ == nullptr) {
+    stream_ = stream_ptr;
+  }
   auto hccl_result = hccl::HcclAdapter::GetInstance().HcclSend(inputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
-                                                               dest_rank_, stream_ptr, group_);
+                                                               dest_rank_, stream_, group_);
   if (hccl_result != HCCL_SUCCESS) {
     MS_LOG(ERROR) << "HcomSend failed, ret:" << hccl_result;
     return false;
@@ -39,7 +39,9 @@ bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inpu
     MS_LOG(ERROR) << "kernel pack should not be nullptr.";
     return false;
   }
-
+  if (stream_ == nullptr) {
+    stream_ = stream_ptr;
+  }
   uint32_t blockdim = 1;  // default blockdim equal to 1.
   auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &blockdim);
   if (func_stub == 0) {
@@ -60,7 +62,8 @@ bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inpu
   rtL2Ctrl_t *l2ctrl = nullptr;
   const void *stubFunc = reinterpret_cast<void *>(func_stub);
   auto argsSize = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtimeargs.size());
-  if (RT_ERROR_NONE != rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_ptr)) {
+  auto ret = rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_);
+  if (ret != RT_ERROR_NONE) {
     MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
     return false;
   }
@@ -2073,6 +2073,24 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
   return device_shape;
 }

+void AnfRuntimeAlgorithm::GetAllVisitedCNode(const CNodePtr &anf_node, std::vector<AnfNodePtr> *used_kernels) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(used_kernels);
+  auto input_size = anf_node->inputs().size() - 1;
+  for (size_t i = 0; i < input_size; ++i) {
+    auto input = AnfAlgo::GetInputNode(anf_node, i);
+    if (!input->isa<CNode>()) {
+      continue;
+    }
+    auto input_cnode = input->cast<CNodePtr>();
+    if (!IsRealKernelCNode(input_cnode) || opt::IsNopNode(input_cnode)) {
+      GetAllVisitedCNode(input_cnode, used_kernels);
+    } else {
+      used_kernels->push_back(input);
+    }
+  }
+}
+
 void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
                                                std::set<AnfNodePtr> *visited) {
   MS_EXCEPTION_IF_NULL(anf_node);
@@ -286,6 +286,7 @@ class AnfRuntimeAlgorithm {
   // Find real input nodes.
   static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
                                    std::set<AnfNodePtr> *visited);
+  static void GetAllVisitedCNode(const CNodePtr &cnode, std::vector<AnfNodePtr> *used_kernels);
   static void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph);
   static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
   // Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc.
@@ -1136,11 +1136,11 @@ void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bo
   }
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  if (is_task) {
+  if (is_task && is_task_sink) {
     DumpSetup(kernel_graph);
   }
   bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
-  if (is_task) {
+  if (is_task && is_task_sink) {
     Dump(kernel_graph);
   }
   if (!ret_ok) {
@@ -385,6 +385,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size

 bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
   if (!is_task_sink) {
     MS_LOG(INFO) << "Graph mode with not task sink";
+    GenKernelEvents(graph);
     return true;
   }
@@ -657,12 +658,138 @@ bool AscendKernelRuntime::Run(session::KernelGraph *const graph, bool is_task_si
     MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
 #endif
   } else {
-    ret = LaunchKernel(graph);
+    ret = LaunchKernels(graph);
   }

   return ret;
 }

+bool AscendKernelRuntime::LaunchKernel(const AnfNodePtr &kernel) {
+  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  AddressPtrList kernel_inputs = kernel_mod->GetInputsAddr();
+  AddressPtrList kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
+  AddressPtrList kernel_outputs = kernel_mod->GetOutputsAddr();
+  bool ret;
+  if (pynative_mode_profiling_flag_) {
+    auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
+    MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
+    auto stream = ascend_kernel_mod->GetStream();
+    ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs, kernel_workspaces,
+                                            kernel_outputs, stream);
+  } else {
+    ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
+  }
+  return ret;
+}
+
+void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels) {
+  for (size_t i = 0; i < kernels.size(); ++i) {
+    auto &node = kernels[i];
+    auto kernel_mod = AnfAlgo::GetKernelMod(node);
+    auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
+    MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
+    if (AnfAlgo::IsCommunicationOp(node)) {
+      auto group = AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
+      auto iter = group_stream_id_map_.find(group);
+      if (iter == group_stream_id_map_.end()) {
+        void *stream = nullptr;
+        auto ret = rtStreamCreate(&stream, 0);
+        if (ret != RT_ERROR_NONE) {
+          MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
+        }
+        auto id = stream_id_map_.size();
+        group_stream_id_map_[group] = id;
+        stream_id_map_[id] = stream;
+        AnfAlgo::SetStreamId(id, node.get());
+        ascend_kernel_mod->SetStream(stream);
+      } else {
+        auto id = iter->second;
+        AnfAlgo::SetStreamId(id, node.get());
+        ascend_kernel_mod->SetStream(stream_id_map_[id]);
+      }
+    } else if (AnfAlgo::IsIndependentNode(node)) {
+      AnfAlgo::SetStreamId(1, node.get());
+      ascend_kernel_mod->SetStream(independent_stream_);
+    } else {
+      AnfAlgo::SetStreamId(0, node.get());
+      ascend_kernel_mod->SetStream(stream_);
+    }
+  }
+  for (size_t i = 1; i < kernels.size(); ++i) {
+    if (AnfAlgo::GetCNodeName(kernels[i - 1]) == kAtomicAddrCleanOpName) {
+      auto stream_id = AnfAlgo::GetStreamId(kernels[i]);
+      AnfAlgo::SetStreamId(stream_id, kernels[i - 1].get());
+      auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i - 1]);
+      auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
+      MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
+      ascend_kernel_mod->SetStream(stream_id_map_[stream_id]);
+    }
+  }
+}
+
+void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto &kernels = graph->execution_order();
+  if (kernels.empty() || graph_kernel_events_map_.find(graph->graph_id()) != graph_kernel_events_map_.end()) {
+    return;
+  }
+  SetKernelModStream(kernels);
+  auto kernel_events =
+    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
+  auto &kernel_pre_run_events = kernel_events.first;
+  auto &kernel_post_run_events = kernel_events.second;
+  kernel_pre_run_events.resize(kernels.size());
+  kernel_post_run_events.resize(kernels.size());
+  auto kernel_size = kernels.size() - 1;
+  for (auto &iter : stream_id_map_) {
+    auto stream = iter.second;
+    if (stream != stream_) {
+      auto pre_event = CreateDeviceEvent();
+      pre_event->set_wait_stream(stream);
+      pre_event->set_record_stream(stream_);
+      kernel_pre_run_events[0].emplace_back([pre_event]() { pre_event->RecordEvent(); });
+      kernel_pre_run_events[0].emplace_back([pre_event]() { pre_event->WaitEvent(); });
+      auto post_event = CreateDeviceEvent();
+      post_event->set_wait_stream(stream_);
+      post_event->set_record_stream(stream);
+      kernel_post_run_events[kernel_size].emplace_back([post_event]() { post_event->RecordEvent(); });
+      kernel_post_run_events[kernel_size].emplace_back([post_event]() { post_event->WaitEvent(); });
+    }
+  }
+  for (size_t i = 0; i < kernels.size(); ++i) {
+    auto &kernel = kernels[i];
+    auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
+    if (stream_id_map_.find(curr_stream_id) == stream_id_map_.end()) {
+      MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created";
+    }
+    auto wait_stream = stream_id_map_[curr_stream_id];
+    auto stream_num = stream_id_map_.size();
+    std::vector<bool> stream_hit(stream_num, false);
+    std::vector<AnfNodePtr> visited_kernels;
+    AnfAlgo::GetAllVisitedCNode(kernel, &visited_kernels);
+    for (int k = SizeToInt(i) - 1; k >= 0; --k) {
+      auto pre_cnode = kernels[k];
+      auto pre_cnode_stream_id = AnfAlgo::GetStreamId(pre_cnode);
+      if (pre_cnode_stream_id == curr_stream_id) {
+        continue;
+      }
+      for (auto &visited : visited_kernels) {
+        if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id]) {
+          stream_hit[pre_cnode_stream_id] = true;
+          auto record_stream = stream_id_map_[pre_cnode_stream_id];
+          auto event = CreateDeviceEvent();
+          event->set_wait_stream(wait_stream);
+          event->set_record_stream(record_stream);
+          kernel_post_run_events[k].emplace_back([event]() { event->RecordEvent(); });
+          kernel_pre_run_events[i].emplace_back([event]() { event->WaitEvent(); });
+        }
+      }
+    }
+  }
+  graph_kernel_events_map_[graph->graph_id()] = std::move(kernel_events);
+}
+
 bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "RunExecutorAsync start. GraphId:" << graph->graph_id();
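The GenKernelEvents path above replaces task-sink scheduling with explicit device events: when a kernel on one stream consumes the output of a kernel that ran on another stream, an event is recorded on the producer's stream right after the producer and waited on before the consumer. A hedged sketch of that record/wait pairing, with a stand-in Event type in place of the DeviceEvent returned by CreateDeviceEvent():

#include <cstddef>
#include <functional>
#include <vector>

// Stand-in for the device event used in the diff; the real one records/waits on rtStream handles.
struct Event {
  void Record() { /* record on the producer's stream */ }
  void Wait() { /* block the consumer's stream until the event is recorded */ }
};

using EventClosures = std::vector<std::vector<std::function<void()>>>;

// post_run[k] runs right after kernel k is enqueued; pre_run[i] runs right before kernel i.
void OrderAcrossStreams(size_t producer_k, size_t consumer_i, Event *event,
                        EventClosures *post_run, EventClosures *pre_run) {
  (*post_run)[producer_k].emplace_back([event]() { event->Record(); });
  (*pre_run)[consumer_i].emplace_back([event]() { event->Wait(); });
}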
@@ -747,18 +874,11 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {

 bool AscendKernelRuntime::SyncStream() {
   SetCurrentContext();
-  if (stream_ == nullptr) {
-    MS_LOG(ERROR) << "SyncStream failed. stream_ is nullptr";
-    return false;
-  }
-  if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) {  // o for switch stream
-    MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
-    return false;
-  }
-
-  if (RT_ERROR_NONE != rtStreamSynchronize(communication_stream_)) {  // o for switch stream
-    MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
-    return false;
+  for (auto &iter : stream_id_map_) {
+    if (rtStreamSynchronize(iter.second) != RT_ERROR_NONE) {  // o for switch stream
+      MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
+      return false;
+    }
   }
   return true;
 }
@@ -827,29 +947,30 @@ bool AscendKernelRuntime::InitDevice() {
   if (ret != RT_ERROR_NONE) {
     MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
   }
+  ret = rtStreamCreateWithFlags(&independent_stream_, 0, RT_STREAM_HUGE);
+  if (ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
+  }
   ret = rtStreamCreate(&communication_stream_, 0);
   if (ret != RT_ERROR_NONE) {
     MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
   }
+  stream_id_map_[0] = stream_;
+  stream_id_map_[1] = independent_stream_;
+  stream_id_map_[2] = communication_stream_;
+  group_stream_id_map_[kHcclWorldGroup] = 2;
   return true;
 }

 bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
   SetCurrentContext();
   int32_t ret;
-  if (stream_ != nullptr) {
-    ret = rtStreamDestroy(stream_);
+  for (auto &iter : stream_id_map_) {
+    ret = rtStreamDestroy(iter.second);
     if (ret != RT_ERROR_NONE) {
       MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
     }
-    stream_ = nullptr;
-  }
-  if (communication_stream_ != nullptr) {
-    ret = rtStreamDestroy(communication_stream_);
-    if (ret != RT_ERROR_NONE) {
-      MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
-    }
-    communication_stream_ = nullptr;
+    iter.second = nullptr;
   }
   ret = rtDeviceReset(device_id);
   if (ret != RT_ERROR_NONE) {
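InitDevice now seeds stream_id_map_ with three fixed entries and SetKernelModStream appends one more stream per HCCL group it encounters; ResetDevice and SyncStream then simply walk the map. An assumed sketch of that layout (the enum below is illustrative, not a type present in the diff):

#include <cstdint>

// Illustrative stream-id layout after InitDevice(); ids 3, 4, ... are created
// on demand by SetKernelModStream for additional HCCL groups.
enum AscendStreamIdSketch : uint32_t {
  kDefaultStreamId = 0,      // stream_
  kIndependentStreamId = 1,  // independent_stream_ (independent nodes)
  kWorldGroupStreamId = 2,   // communication_stream_ (kHcclWorldGroup)
};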
@@ -43,6 +43,9 @@ class AscendKernelRuntime : public KernelRuntime {
   uint32_t GetRankSize() override;
   bool LoadData(session::KernelGraph *graph) override;
   bool GenTask(const session::KernelGraph *graph);
+  void GenKernelEvents(const session::KernelGraph *graph) override;
+  void SetKernelModStream(const std::vector<CNodePtr> &kernels);
+  bool LaunchKernel(const AnfNodePtr &kernel) override;
   bool GenDynamicKernel(const session::KernelGraph *graph) override;
   bool RunDynamicKernelAsync(const session::KernelGraph *graph) override;
   bool LoadTask(const session::KernelGraph *graph);
@@ -104,6 +107,8 @@ class AscendKernelRuntime : public KernelRuntime {
   std::map<std::pair<uint32_t, uint32_t>, std::string> stream_id_task_id_op_name_map_;
   static std::map<std::string, uint32_t> overflow_tasks_;
   static std::vector<rtExceptionInfo> task_fail_infoes_;
+  std::map<uint32_t, void *> stream_id_map_;
+  std::map<std::string, uint32_t> group_stream_id_map_;
 };

 MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);
@@ -472,7 +472,7 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) {
       // run dynamic shape graph in pynative
       ret = RunOpLaunchKernelDynamic(graph);
     } else {
-      ret = LaunchKernel(graph);
+      ret = LaunchKernels(graph);
     }
   }
   (void)gettimeofday(&end_time, nullptr);
@@ -1008,6 +1008,32 @@ void KernelRuntime::DebugStreamSync(const CNodePtr &kernel) {
   }
 }

+bool KernelRuntime::LaunchKernel(const AnfNodePtr &kernel) {
+  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  AddressPtrList kernel_inputs;
+  AddressPtrList kernel_workspaces;
+  AddressPtrList kernel_outputs;
+  GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
+  bool ret;
+  if (AnfAlgo::IsCommunicationOp(kernel)) {
+    if (pynative_mode_profiling_flag_) {
+      ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
+                                              kernel_workspaces, kernel_outputs, communication_stream_);
+    } else {
+      ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_);
+    }
+  } else {
+    if (pynative_mode_profiling_flag_) {
+      ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
+                                              kernel_workspaces, kernel_outputs, stream_);
+    } else {
+      ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
+    }
+  }
+  return ret;
+}
+
 bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
   const auto &kernels = graph.execution_order();
   std::vector<DynamicKernelPtr> dynamic_kernel_list;
@@ -1042,8 +1068,6 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
     } else {
       auto &kernel = kernels[i];
       MS_EXCEPTION_IF_NULL(kernel);
-      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
-      MS_EXCEPTION_IF_NULL(kernel_mod);

       // Skip transpose kernel with "nop_op" attr which is not hidden or removed in PyNative infer scenario. Transpose
       // kernel, which is not supposed to be executed, is generated in TransDataSplit to support specific Transdata.
@@ -1056,34 +1080,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
         }
         continue;
       }
-      AddressPtrList kernel_inputs;
-      AddressPtrList kernel_workspaces;
-      AddressPtrList kernel_outputs;
-      auto ms_context = MsContext::GetInstance();
-      MS_EXCEPTION_IF_NULL(ms_context);
-      if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice) {
-        GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
-      } else {
-        kernel_inputs = kernel_mod->GetInputsAddr();
-        kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
-        kernel_outputs = kernel_mod->GetOutputsAddr();
-      }
-      bool ret;
-      if (AnfAlgo::IsCommunicationOp(kernel)) {
-        if (pynative_mode_profiling_flag_) {
-          ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
-                                                  kernel_workspaces, kernel_outputs, communication_stream_);
-        } else {
-          ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, communication_stream_);
-        }
-      } else {
-        if (pynative_mode_profiling_flag_) {
-          ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
-                                                  kernel_workspaces, kernel_outputs, stream_);
-        } else {
-          ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
-        }
-      }
+      auto ret = LaunchKernel(kernel);
       if (!ret) {
         MS_LOG(ERROR) << "Launch kernel failed.";
         return false;
@@ -1096,7 +1093,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
   return true;
 }

-bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
+bool KernelRuntime::LaunchKernels(const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   if (!LaunchKernelMod(*graph)) {
     MS_LOG(ERROR) << "LaunchKernelMod failed!";
@@ -64,7 +64,8 @@ class KernelRuntime {
   virtual bool Run(session::KernelGraph *graph, bool is_task_sink) = 0;
   virtual bool GenDynamicKernel(const session::KernelGraph *graph) = 0;
   virtual bool RunDynamicKernelAsync(const session::KernelGraph *graph) = 0;
-  bool LaunchKernel(const session::KernelGraph *graph);
+  bool LaunchKernels(const session::KernelGraph *graph);
+  virtual bool LaunchKernel(const AnfNodePtr &kernel);
   virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
   virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
   virtual void ClearGraphRuntimeResource(uint32_t graph_id);
@@ -98,7 +99,7 @@ class KernelRuntime {

   virtual void PreInit() {}
   virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
-  void GenKernelEvents(const session::KernelGraph *graph);
+  virtual void GenKernelEvents(const session::KernelGraph *graph);
   virtual std::shared_ptr<DeviceEvent> CreateDeviceEvent() { return nullptr; }
   virtual std::shared_ptr<DeviceEvent> CreateDeviceTimeEvent() { return nullptr; }
   virtual DeviceAddressType GetTargetDeviceAddressType() const = 0;
@@ -123,6 +124,10 @@ class KernelRuntime {
   void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node);
   void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
   void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
+  bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
+                                         const std::vector<AddressPtr> &inputs,
+                                         const std::vector<AddressPtr> &workspace,
+                                         const std::vector<AddressPtr> &outputs, void *stream);

   virtual void KernelLaunchProfiling(const std::string &kernel_name) {}

@@ -138,10 +143,6 @@ class KernelRuntime {
   void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
   void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
   DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
-  bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
-                                         const std::vector<AddressPtr> &inputs,
-                                         const std::vector<AddressPtr> &workspace,
-                                         const std::vector<AddressPtr> &outputs, void *stream);
 #if (ENABLE_CPU && !_WIN32)
   void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index,
                                 size_t *const first_cache_size);
@@ -156,6 +157,7 @@ class KernelRuntime {
   std::shared_ptr<Debugger> debugger_;
 #endif
   void *stream_{nullptr};
+  void *independent_stream_{nullptr};
   void *communication_stream_{nullptr};
   std::shared_ptr<MemoryManager> mem_manager_{nullptr};
   std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_;
@@ -28,8 +28,6 @@
 #include "utils/ms_utils.h"
 #include "utils/ms_context.h"
 #include "runtime/hccl_adapter/converter.h"
-#include "runtime/device/ascend/distribute/ascend_collective.h"
-using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;

 static constexpr const char *kHcclPluginFileName = "libhccl_plugin.so";
 static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
@@ -88,6 +86,10 @@ void HcclAdapter::InitPlugin() {
   single_op_hccl_get_rank_size_ = DlsymFuncObj(HcclGetRankSize, plugin_handle_);
   launch_hccl_broadcast_ = DlsymFuncObj(HcclBroadcast, plugin_handle_);
   launch_hccl_all_reduce_ = DlsymFuncObj(HcclAllReduce, plugin_handle_);
+  launch_hccl_reduce_scatter_ = DlsymFuncObj(HcclReduceScatter, plugin_handle_);
+  launch_hccl_all_gather_ = DlsymFuncObj(HcclAllGather, plugin_handle_);
+  launch_hccl_send_ = DlsymFuncObj(HcclSend, plugin_handle_);
+  launch_hccl_recv_ = DlsymFuncObj(HcclRecv, plugin_handle_);
   hccl_create_group_ = DlsymFuncObj(HcomCreateGroup, plugin_handle_);
   hccl_destroy_group_ = DlsymFuncObj(HcomDestroyGroup, plugin_handle_);
   hccl_get_rank_id_ = DlsymFuncObj(HcomGetRankId, plugin_handle_);
@@ -269,65 +271,40 @@ HcclResult HcclAdapter::HcclBroadcast(void *buf, uint64_t count, HcclDataType da
 HcclResult HcclAdapter::HcclAllReduce(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
                                       HcclReduceOp op, aclrtStream stream, const std::string &group) const {
   MS_EXCEPTION_IF_NULL(launch_hccl_all_reduce_);
-  HcclComm hccl_comm;
-  if (hccl_comm_ != nullptr) {
-    hccl_comm = hccl_comm_;
-  } else {
-    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
-    MS_EXCEPTION_IF_NULL(hccl_comm);
-  }
+  auto hccl_comm = GetHcomm(group);
+  MS_EXCEPTION_IF_NULL(hccl_comm);
   return launch_hccl_all_reduce_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream);
 }

 HcclResult HcclAdapter::HcclReduceScatter(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
                                           HcclReduceOp op, aclrtStream stream, const std::string &group) const {
   MS_EXCEPTION_IF_NULL(launch_hccl_reduce_scatter_);
-  HcclComm hccl_comm;
-  if (hccl_comm_ != nullptr) {
-    hccl_comm = hccl_comm_;
-  } else {
-    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
-    MS_EXCEPTION_IF_NULL(hccl_comm);
-  }
+  auto hccl_comm = GetHcomm(group);
+  MS_EXCEPTION_IF_NULL(hccl_comm);
   return launch_hccl_reduce_scatter_(send_buf, recv_buf, count, dataType, op, hccl_comm, stream);
 }

 HcclResult HcclAdapter::HcclAllGather(void *send_buf, void *recv_buf, uint64_t count, HcclDataType dataType,
                                       aclrtStream stream, const std::string &group) const {
   MS_EXCEPTION_IF_NULL(launch_hccl_all_gather_);
-  HcclComm hccl_comm;
-  if (hccl_comm_ != nullptr) {
-    hccl_comm = hccl_comm_;
-  } else {
-    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
-    MS_EXCEPTION_IF_NULL(hccl_comm);
-  }
+  auto hccl_comm = GetHcomm(group);
+  MS_EXCEPTION_IF_NULL(hccl_comm);
   return launch_hccl_all_gather_(send_buf, recv_buf, count, dataType, hccl_comm, stream);
 }

 HcclResult HcclAdapter::HcclSend(void *send_buf, uint64_t count, HcclDataType dataType, uint32_t destRank,
                                  aclrtStream stream, const std::string &group) const {
   MS_EXCEPTION_IF_NULL(launch_hccl_send_);
-  HcclComm hccl_comm;
-  if (hccl_comm_ != nullptr) {
-    hccl_comm = hccl_comm_;
-  } else {
-    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
-    MS_EXCEPTION_IF_NULL(hccl_comm);
-  }
+  auto hccl_comm = GetHcomm(group);
+  MS_EXCEPTION_IF_NULL(hccl_comm);
   return launch_hccl_send_(send_buf, count, dataType, destRank, hccl_comm, stream);
 }

 HcclResult HcclAdapter::HcclRecv(void *recv_buf, uint64_t count, HcclDataType dataType, uint32_t srcRank,
                                  aclrtStream stream, const std::string &group) const {
   MS_EXCEPTION_IF_NULL(launch_hccl_recv_);
-  HcclComm hccl_comm;
-  if (hccl_comm_ != nullptr) {
-    hccl_comm = hccl_comm_;
-  } else {
-    hccl_comm = HcclCollectiveGroup::instance().GetGroupComm(group);
-    MS_EXCEPTION_IF_NULL(hccl_comm);
-  }
+  auto hccl_comm = GetHcomm(group);
+  MS_EXCEPTION_IF_NULL(hccl_comm);
   return launch_hccl_recv_(recv_buf, count, dataType, srcRank, hccl_comm, stream);
 }
@@ -24,6 +24,8 @@
 #include "mindspore/core/ir/anf.h"
 #include "hccl/hccl_types.h"
 #include "runtime/hccl_adapter/plugin/hccl_plugin.h"
+#include "runtime/device/ascend/distribute/ascend_collective.h"
+using HcclCollectiveGroup = mindspore::device::ascend::collective::HcclCollectiveGroup;

 namespace ge {
 class OpsKernelInfoStore;
@@ -83,6 +85,14 @@ class HcclAdapter {
   void InitPlugin();
   void FinalizePlugin();

+  HcclComm GetHcomm(const std::string &group) const {
+    if (hccl_comm_ != nullptr) {
+      return hccl_comm_;
+    } else {
+      return HcclCollectiveGroup::instance().GetGroupComm(group);
+    }
+  }
+
   bool InitKernelInfoStore(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
   bool FinalizeKernelInfoStore();