forked from mindspore-Ecosystem/mindspore
!48477 fix async reset problem
Merge pull request !48477 from luochao60/fix_async_reset_error_20230206
This commit is contained in:
commit
d9976caa65
|
@ -851,7 +851,7 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
|
|||
if (!already_run_top_cell->need_compile_graph()) {
|
||||
MS_LOG(DEBUG) << "No need compile graph";
|
||||
// If no need compile, we can finish construct left bprop
|
||||
async_executor_->Reset();
|
||||
async_executor_->Clear();
|
||||
top_cell_list_.pop_back();
|
||||
set_top_cell(already_run_top_cell);
|
||||
top_cell()->UpdateTopCellInfo(false, false, false);
|
||||
|
|
|
@ -94,16 +94,33 @@ bool AsyncQueue::Empty() {
|
|||
return tasks_.empty();
|
||||
}
|
||||
|
||||
void AsyncQueue::Reset() {
|
||||
void AsyncQueue::Clear() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(task_mutex_);
|
||||
if (tasks_.empty()) {
|
||||
return;
|
||||
}
|
||||
std::queue<std::shared_ptr<AsyncTask>> empty;
|
||||
std::swap(tasks_, empty);
|
||||
auto task = std::make_shared<WaitTask>();
|
||||
tasks_.push(task);
|
||||
task_cond_var_.notify_all();
|
||||
}
|
||||
// There is still one task in progress
|
||||
Wait();
|
||||
}
|
||||
|
||||
void AsyncQueue::Reset() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(task_mutex_);
|
||||
if (tasks_.empty()) {
|
||||
return;
|
||||
}
|
||||
std::queue<std::shared_ptr<AsyncTask>> empty;
|
||||
std::swap(tasks_, empty);
|
||||
}
|
||||
}
|
||||
|
||||
void AsyncQueue::WorkerJoin() {
|
||||
try {
|
||||
// Avoid worker thread join itself which will cause deadlock
|
||||
|
|
|
@ -43,6 +43,9 @@ class BACKEND_EXPORT AsyncQueue {
|
|||
// Check if the queue is empty.
|
||||
bool Empty();
|
||||
|
||||
// clear tasks of queue, and wait last task.
|
||||
void Clear();
|
||||
|
||||
// When an exception occurs, the state needs to be reset.
|
||||
void Reset();
|
||||
|
||||
|
|
|
@ -19,13 +19,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
enum TaskType {
|
||||
kUnknownTask = 0,
|
||||
kOpRunTask,
|
||||
kOpBuildTask,
|
||||
kBpropTask,
|
||||
kExitTask,
|
||||
};
|
||||
enum TaskType { kUnknownTask = 0, kOpRunTask, kOpBuildTask, kBpropTask, kExitTask, kWaitTask };
|
||||
class AsyncTask {
|
||||
public:
|
||||
explicit AsyncTask(TaskType task_type) : task_type_(task_type) {}
|
||||
|
@ -44,6 +38,13 @@ class ExitTask : public AsyncTask {
|
|||
~ExitTask() override = default;
|
||||
void Run() override {}
|
||||
};
|
||||
|
||||
class WaitTask : public AsyncTask {
|
||||
public:
|
||||
WaitTask() : AsyncTask(kWaitTask) {}
|
||||
~WaitTask() override = default;
|
||||
void Run() override {}
|
||||
};
|
||||
} // namespace pynative
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue