!48477 fix async reset problem

Merge pull request !48477 from luochao60/fix_async_reset_error_20230206
This commit is contained in:
i-robot 2023-02-07 12:38:33 +00:00 committed by Gitee
commit d9976caa65
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 30 additions and 9 deletions

View File

@ -851,7 +851,7 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
if (!already_run_top_cell->need_compile_graph()) {
MS_LOG(DEBUG) << "No need compile graph";
// If no need compile, we can finish construct left bprop
async_executor_->Reset();
async_executor_->Clear();
top_cell_list_.pop_back();
set_top_cell(already_run_top_cell);
top_cell()->UpdateTopCellInfo(false, false, false);

View File

@ -94,16 +94,33 @@ bool AsyncQueue::Empty() {
return tasks_.empty();
}
void AsyncQueue::Reset() {
// Drops every task currently held in the queue, leaves a single WaitTask in
// their place, wakes threads blocked on task_cond_var_ and then calls Wait().
// Returns immediately, without waiting, when the queue is already empty.
void AsyncQueue::Clear() {
  {
    std::lock_guard<std::mutex> guard(task_mutex_);
    if (tasks_.empty()) {
      return;
    }
    // Exchange the pending tasks for an empty queue; the old entries are
    // released when `discarded` leaves this scope, while the lock is still held.
    std::queue<std::shared_ptr<AsyncTask>> discarded;
    tasks_.swap(discarded);
    // NOTE(review): the WaitTask presumably gives the worker, which is still
    // busy with the previous front task, an entry to consume once it finishes
    // - confirm against the worker loop.
    tasks_.push(std::make_shared<WaitTask>());
    task_cond_var_.notify_all();
  }
  // There is still one task in progress
  Wait();
}
// Discards every task held in the queue while holding task_mutex_.
// Unlike Clear(), it neither enqueues a WaitTask nor calls Wait().
void AsyncQueue::Reset() {
  std::lock_guard<std::mutex> guard(task_mutex_);
  if (!tasks_.empty()) {
    // Swap the contents into a local queue so they are released at the end of
    // this scope, before the lock is dropped.
    std::queue<std::shared_ptr<AsyncTask>> discarded;
    tasks_.swap(discarded);
  }
}
void AsyncQueue::WorkerJoin() {
try {
// Avoid worker thread join itself which will cause deadlock

View File

@ -43,6 +43,9 @@ class BACKEND_EXPORT AsyncQueue {
// Check if the queue is empty.
bool Empty();
// Clear all tasks in the queue, then wait for the last task to finish.
void Clear();
// When an exception occurs, the state needs to be reset.
void Reset();

View File

@ -19,13 +19,7 @@
namespace mindspore {
namespace pynative {
enum TaskType {
kUnknownTask = 0,
kOpRunTask,
kOpBuildTask,
kBpropTask,
kExitTask,
};
// Tag identifying the kind of an AsyncTask; a value of this type is passed to
// the AsyncTask constructor. kWaitTask is the tag used by WaitTask.
enum TaskType { kUnknownTask = 0, kOpRunTask, kOpBuildTask, kBpropTask, kExitTask, kWaitTask };
class AsyncTask {
public:
explicit AsyncTask(TaskType task_type) : task_type_(task_type) {}
@ -44,6 +38,13 @@ class ExitTask : public AsyncTask {
~ExitTask() override = default;
void Run() override {}
};
// Task tagged kWaitTask whose Run() does nothing. AsyncQueue::Clear() pushes
// one onto the queue after dropping the pending tasks.
class WaitTask : public AsyncTask {
public:
// Registers this task under the kWaitTask type.
WaitTask() : AsyncTask(kWaitTask) {}
~WaitTask() override = default;
// Intentionally empty: the task carries no work of its own.
void Run() override {}
};
} // namespace pynative
} // namespace mindspore