forked from mindspore-Ecosystem/mindspore
fix tensoradd grad op run fail
This commit is contained in:
parent
6721541c9d
commit
762bf9ac25
|
@ -260,6 +260,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
|||
auto anf_node_list = graph->execution_order();
|
||||
TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id());
|
||||
|
||||
// Store the task_info_list
|
||||
auto iter = task_map_.find(graph);
|
||||
if (iter != task_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "graph TaskInfo list already exist";
|
||||
}
|
||||
task_map_[graph] = task_info_list;
|
||||
|
||||
// Graph may have no compute node, such TensorAddGrad.
|
||||
if (task_info_list.empty()) {
|
||||
MS_LOG(WARNING) << "graph " << graph->graph_id() << " have no compute node";
|
||||
return true;
|
||||
}
|
||||
|
||||
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
|
||||
// the streams' flag not HEAD_STREAM
|
||||
std::vector<uint32_t> wait_active_stream_list = assign_instance.GetWaitStreams();
|
||||
|
@ -278,10 +291,6 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
|||
graph_model_map_[graph] = model;
|
||||
graph_model_id_map_[graph] = graph->graph_id();
|
||||
MS_LOG(INFO) << "TaskGenerator GetTaskInfo end...";
|
||||
|
||||
// Store the task_info_list
|
||||
task_map_.insert(std::make_pair(graph, task_info_list));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -305,6 +314,11 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
if (GraphWithEmptyTaskList(graph)) {
|
||||
MS_LOG(WARNING) << "LoadTask end, task list is empty";
|
||||
return true;
|
||||
}
|
||||
|
||||
auto task_iter = graph_model_map_.find(graph);
|
||||
if (task_iter == graph_model_map_.end()) {
|
||||
MS_LOG(ERROR) << "task not exist";
|
||||
|
@ -333,6 +347,11 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
|
|||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
ge::InputData input_tensors = ge::InputData();
|
||||
ge::OutputData *output_tensors = nullptr;
|
||||
if (GraphWithEmptyTaskList(graph)) {
|
||||
MS_LOG(WARNING) << "RunTask end, no task info found";
|
||||
return true;
|
||||
}
|
||||
|
||||
auto model_id = GetGraphModelId(graph);
|
||||
bool status = ge::model_runner::ModelRunner::Instance().RunModel(model_id, input_tensors, output_tensors);
|
||||
if (!status) {
|
||||
|
@ -468,6 +487,14 @@ bool AscendKernelRuntime::DestroyHccl() {
|
|||
context_ptr->set_enable_hccl(false);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const {
|
||||
auto iter = task_map_.find(graph);
|
||||
if (iter == task_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unknown graph ptr";
|
||||
}
|
||||
return iter->second.empty();
|
||||
}
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -55,6 +55,8 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
void ClearGraphModelMap();
|
||||
void ReleaseDeviceRes() override;
|
||||
uint32_t GetGraphModelId(const session::KernelGraph *kernel_graph);
|
||||
bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const;
|
||||
|
||||
rtContext_t rt_context_{nullptr};
|
||||
bool initialized_{false};
|
||||
unordered_map<const session::KernelGraph *, vector<std::shared_ptr<TaskInfo>>> task_map_;
|
||||
|
|
Loading…
Reference in New Issue