From 9fcfd6166e75cd27c31a03fd64bf32b749ad4076 Mon Sep 17 00:00:00 2001 From: caifubi Date: Sun, 5 Mar 2023 17:42:37 +0800 Subject: [PATCH] Bugfix for PyTrace 1. Handle stub_node exception. 2. Use python infer for StridedSlice. --- .../ccsrc/include/common/utils/stub_tensor.h | 2 + .../pipeline/pynative/forward/do_infer.cc | 9 +-- .../pipeline/pynative/forward/forward.cc | 6 +- .../ccsrc/pipeline/pynative/forward/forward.h | 3 +- .../pipeline/pynative/forward/forward_task.cc | 2 + .../pipeline/pynative/forward/forward_task.h | 1 + .../pipeline/pynative/pynative_execute.cc | 9 ++- .../runtime/pynative/async/async_queue.cc | 21 +++++-- mindspore/ccsrc/runtime/pynative/async/task.h | 1 + mindspore/ccsrc/utils/stub_tensor.cc | 56 +++++++------------ .../dynamic_shape/test_concat_offset_dyn.py | 7 ++- 11 files changed, 62 insertions(+), 55 deletions(-) diff --git a/mindspore/ccsrc/include/common/utils/stub_tensor.h b/mindspore/ccsrc/include/common/utils/stub_tensor.h index eb6f567b83b..463571b6871 100644 --- a/mindspore/ccsrc/include/common/utils/stub_tensor.h +++ b/mindspore/ccsrc/include/common/utils/stub_tensor.h @@ -46,6 +46,7 @@ class COMMON_EXPORT StubNode : public Value { virtual bool SetAbstract(const AbstractBasePtr &abs); virtual void SetValue(const ValuePtr &val); + void SetException(const std::exception_ptr &e_ptr); AbstractBasePtr WaitAbstract(); ValuePtr WaitValue(); @@ -60,6 +61,7 @@ class COMMON_EXPORT StubNode : public Value { ValuePtr value_; std::atomic wait_flag_{false}; StubNodePtr top_node_; + std::exception_ptr e_ptr_{}; }; class TensorNode : public StubNode { diff --git a/mindspore/ccsrc/pipeline/pynative/forward/do_infer.cc b/mindspore/ccsrc/pipeline/pynative/forward/do_infer.cc index bf5ddc4acbd..802f47ee757 100644 --- a/mindspore/ccsrc/pipeline/pynative/forward/do_infer.cc +++ b/mindspore/ccsrc/pipeline/pynative/forward/do_infer.cc @@ -87,12 +87,8 @@ void InferOperation::PynativeInfer(const FrontendOpRunInfoPtr &op_run_info) cons auto eval_impl = abstract::GetFrontendPrimitiveInferImpl(prim); bool need_call_python_code = false; // Charge if the primitive should call the python code, when infer abstract. - if (!eval_impl.has_value()) { - eval_impl = abstract::GetBackendPrimitiveInferImpl(prim); - if (!eval_impl.has_value()) { - MS_LOG(DEBUG) << "Can't found infer function from Frontend and Backend, try to infer with python"; - need_call_python_code = true; - } + if (prim->prim_type() == kPrimTypePyCheck || !eval_impl.has_value()) { + need_call_python_code = true; } // Only cache the abstract when the primitive should call the python code. if (need_call_python_code && GetOutputAbstractByCache(op_run_info)) { @@ -103,6 +99,7 @@ void InferOperation::PynativeInfer(const FrontendOpRunInfoPtr &op_run_info) cons // Call Python func if (need_call_python_code) { + py::gil_scoped_acquire acquire; CallPyInferFunc(prim, op_run_info); if (op_run_info->base_op_run_info.abstract != nullptr) { return; diff --git a/mindspore/ccsrc/pipeline/pynative/forward/forward.cc b/mindspore/ccsrc/pipeline/pynative/forward/forward.cc index 8e2415479d0..5175c317eeb 100644 --- a/mindspore/ccsrc/pipeline/pynative/forward/forward.cc +++ b/mindspore/ccsrc/pipeline/pynative/forward/forward.cc @@ -171,7 +171,6 @@ void ForwardExecutor::Init() { compile::SetMindRTEnable(); python_adapter::set_python_env_flag(true); init_ = true; - forward_queue_ = std::make_shared(); runtime::OpExecutor::GetInstance().RegisterForwardCallback([this]() { forward_queue_->Wait(); }); } @@ -548,7 +547,10 @@ ValuePtr ForwardExecutor::RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info void ForwardExecutor::ClearRes() { MS_LOG(DEBUG) << "Clear forward res"; - forward_queue_->Reset(); + { + GilReleaseWithCheck gil_release; + forward_queue_->Clear(); + } for (const auto &item : mindrt_backends_) { MS_EXCEPTION_IF_NULL(item.second); item.second->ClearOpExecutorResource(); diff --git a/mindspore/ccsrc/pipeline/pynative/forward/forward.h b/mindspore/ccsrc/pipeline/pynative/forward/forward.h index 941cbd9e45b..68634ac3aa9 100644 --- a/mindspore/ccsrc/pipeline/pynative/forward/forward.h +++ b/mindspore/ccsrc/pipeline/pynative/forward/forward.h @@ -42,7 +42,8 @@ class ForwardExecutor { ForwardExecutor() : cast_operation_(std::make_shared()), infer_operation_(std::make_shared()), - enable_async_(std::getenv("ENABLE_ASYNC")) {} + enable_async_(std::getenv("ENABLE_ASYNC")), + forward_queue_(std::make_shared()) {} ~ForwardExecutor() = default; void Init(); diff --git a/mindspore/ccsrc/pipeline/pynative/forward/forward_task.cc b/mindspore/ccsrc/pipeline/pynative/forward/forward_task.cc index 171899078ae..8f3c8b38730 100644 --- a/mindspore/ccsrc/pipeline/pynative/forward/forward_task.cc +++ b/mindspore/ccsrc/pipeline/pynative/forward/forward_task.cc @@ -19,5 +19,7 @@ namespace mindspore { namespace pynative { void ForwardTask::Run() { run_func_(op_run_info_); } + +void ForwardTask::SetException(const std::exception_ptr &e) { op_run_info_->stub_output->SetException(e); } } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/forward/forward_task.h b/mindspore/ccsrc/pipeline/pynative/forward/forward_task.h index 3c3c8ae095d..d98402b8820 100644 --- a/mindspore/ccsrc/pipeline/pynative/forward/forward_task.h +++ b/mindspore/ccsrc/pipeline/pynative/forward/forward_task.h @@ -30,6 +30,7 @@ class ForwardTask : public AsyncTask { : AsyncTask(kForwardTask), run_func_(std::move(run_func)), op_run_info_(std::move(op_run_info)) {} ~ForwardTask() = default; void Run() override; + void SetException(const std::exception_ptr &e) override; private: std::function run_func_; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 72ac3fc37d8..9d53a3ecbef 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -91,11 +91,10 @@ bool PyNativeExecutor::DisablePyTraceAsync(const FrontendOpRunInfoPtr &op_run_in #ifdef ENABLE_TEST return true; #else - return forward_executor()->IsVmOp(op_run_info->base_op_run_info.op_name) || - op_run_info->op_prim->name() == "Custom" || ScopedFallbackRunning::on() || - op_run_info->op_prim->HasAttr("side_effect_mem") || - (!abstract::GetFrontendPrimitiveInferImpl(op_run_info->op_prim).has_value() && - !abstract::GetBackendPrimitiveInferImpl(op_run_info->op_prim).has_value()) || + const auto &op_prim = op_run_info->op_prim; + return forward_executor()->IsVmOp(op_run_info->base_op_run_info.op_name) || op_prim->name() == "Custom" || + ScopedFallbackRunning::on() || op_prim->HasAttr("side_effect_mem") || + (op_prim->prim_type() == kPrimTypePyCheck || !abstract::GetFrontendPrimitiveInferImpl(op_prim).has_value()) || MsContext::GetInstance()->get_param(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE); #endif } diff --git a/mindspore/ccsrc/runtime/pynative/async/async_queue.cc b/mindspore/ccsrc/runtime/pynative/async/async_queue.cc index 7e84676be49..9ed3f66165c 100644 --- a/mindspore/ccsrc/runtime/pynative/async/async_queue.cc +++ b/mindspore/ccsrc/runtime/pynative/async/async_queue.cc @@ -66,9 +66,17 @@ void AsyncQueue::WorkerLoop() { MS_LOG(ERROR) << "Run task failed, error msg:" << e.what(); { std::unique_lock lock(task_mutex_); - std::queue> empty; - std::swap(tasks_, empty); + MsException::Instance().SetException(); + // MsException is unreliable because it gets modified everywhere. + auto e_ptr = std::current_exception(); + task->SetException(e_ptr); + while (!tasks_.empty()) { + auto &t = tasks_.front(); + t->SetException(e_ptr); + tasks_.pop(); + } + task_cond_var_.notify_all(); } } @@ -110,8 +118,13 @@ void AsyncQueue::Clear() { } std::queue> empty; std::swap(tasks_, empty); - auto task = std::make_shared(); - tasks_.push(task); + + // Avoid to push task after WorkerJoin. + if (worker_ != nullptr && worker_->joinable()) { + auto task = std::make_shared(); + tasks_.push(task); + } + task_cond_var_.notify_all(); } // There is still one task in progress diff --git a/mindspore/ccsrc/runtime/pynative/async/task.h b/mindspore/ccsrc/runtime/pynative/async/task.h index c69bbcc2499..79066960bc5 100644 --- a/mindspore/ccsrc/runtime/pynative/async/task.h +++ b/mindspore/ccsrc/runtime/pynative/async/task.h @@ -26,6 +26,7 @@ class AsyncTask { explicit AsyncTask(TaskType task_type) : task_type_(task_type) {} virtual ~AsyncTask() = default; virtual void Run() = 0; + virtual void SetException(const std::exception_ptr &e) {} TaskType task_type() const { return task_type_; } diff --git a/mindspore/ccsrc/utils/stub_tensor.cc b/mindspore/ccsrc/utils/stub_tensor.cc index cc575dd4b6b..1b9ea78aedf 100644 --- a/mindspore/ccsrc/utils/stub_tensor.cc +++ b/mindspore/ccsrc/utils/stub_tensor.cc @@ -70,30 +70,6 @@ py::object MakeOutput(StubNodePtr node) { return out; } } - -class StubException : public ExceptionListener { - public: - StubException() { - MsException::Instance().SetExceptionListener(this); - MsException::Instance().CheckException(); - } - ~StubException() = default; - - void Finalize() { - MsException::Instance().SetExceptionListener(nullptr); - MsException::Instance().CheckException(); - } - bool HasException() const { return has_exception_; } - - void OnException() override { - has_exception_ = true; - std::unique_lock lock(stub_mutex_); - stub_cond_var_.notify_all(); - } - - private: - bool has_exception_{false}; -}; } // namespace bool StubNode::SetAbstract(const AbstractBasePtr &abs) { @@ -113,6 +89,14 @@ void StubNode::SetValue(const ValuePtr &val) { } } +void StubNode::SetException(const std::exception_ptr &e_ptr) { + e_ptr_ = e_ptr; + if (wait_flag_.load()) { + std::unique_lock lock(stub_mutex_); + stub_cond_var_.notify_all(); + } +} + AbstractBasePtr StubNode::WaitAbstract() { GilReleaseWithCheck gil_release; if (abstract_.get() == nullptr) { @@ -120,15 +104,15 @@ AbstractBasePtr StubNode::WaitAbstract() { if (top) { top->WaitAbstract(); } else { - StubException e; wait_flag_.store(true); std::unique_lock lock(stub_mutex_); - stub_cond_var_.wait(lock, [this, &e] { return abstract_.get() != nullptr || e.HasException(); }); - if (e.HasException()) { - abstract_ = std::make_shared(); - } + stub_cond_var_.wait(lock, [this] { return abstract_.get() != nullptr || e_ptr_ != nullptr; }); wait_flag_.store(false); - e.Finalize(); + if (e_ptr_ != nullptr) { + // Need to clear exception in the instance. + MsException::Instance().CheckException(); + std::rethrow_exception(e_ptr_); + } } } return abstract_; @@ -141,15 +125,15 @@ ValuePtr StubNode::WaitValue() { if (top) { top->WaitValue(); } else { - StubException e; wait_flag_.store(true); std::unique_lock lock(stub_mutex_); - stub_cond_var_.wait(lock, [this, &e] { return value_.get() != nullptr || e.HasException(); }); - if (e.HasException()) { - value_ = std::make_shared(); - } + stub_cond_var_.wait(lock, [this] { return value_.get() != nullptr || e_ptr_ != nullptr; }); wait_flag_.store(false); - e.Finalize(); + if (e_ptr_ != nullptr) { + // Need to clear exception in the instance. + MsException::Instance().CheckException(); + std::rethrow_exception(e_ptr_); + } } } return value_; diff --git a/tests/st/ops/dynamic_shape/test_concat_offset_dyn.py b/tests/st/ops/dynamic_shape/test_concat_offset_dyn.py index 50ebb13ec9b..33c8a84613e 100644 --- a/tests/st/ops/dynamic_shape/test_concat_offset_dyn.py +++ b/tests/st/ops/dynamic_shape/test_concat_offset_dyn.py @@ -40,7 +40,12 @@ def run_case(run_mode): net = ConcatOffsetNet(1) net.set_inputs(x0_dyn, x1_dyn) output = net(x0, x1) - assert np.allclose(expect, output.asnumpy()) + if run_mode == context.GRAPH_MODE: + assert np.allclose(expect, output.asnumpy()) + else: + # In PyNative, set_inputs will be ignored. Static shape for ConcatOffset + # infer output is not a tensor, get constant value output. + assert np.allclose(expect, output) @pytest.mark.level0