!49921 PyTrace Bugfix

Merge pull request !49921 from caifubi/master-pynative-pytrace-async-ci-dev-new
This commit is contained in:
i-robot 2023-03-08 03:25:21 +00:00 committed by Gitee
commit 08bf8f9d2c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 62 additions and 55 deletions

View File

@ -46,6 +46,7 @@ class COMMON_EXPORT StubNode : public Value {
virtual bool SetAbstract(const AbstractBasePtr &abs);
virtual void SetValue(const ValuePtr &val);
void SetException(const std::exception_ptr &e_ptr);
AbstractBasePtr WaitAbstract();
ValuePtr WaitValue();
@ -60,6 +61,7 @@ class COMMON_EXPORT StubNode : public Value {
ValuePtr value_;
std::atomic<bool> wait_flag_{false};
StubNodePtr top_node_;
std::exception_ptr e_ptr_{};
};
class TensorNode : public StubNode {

View File

@ -87,12 +87,8 @@ void InferOperation::PynativeInfer(const FrontendOpRunInfoPtr &op_run_info) cons
auto eval_impl = abstract::GetFrontendPrimitiveInferImpl(prim);
bool need_call_python_code = false;
// Check if the primitive should call the python code, when infer abstract.
if (!eval_impl.has_value()) {
eval_impl = abstract::GetBackendPrimitiveInferImpl(prim);
if (!eval_impl.has_value()) {
MS_LOG(DEBUG) << "Can't found infer function from Frontend and Backend, try to infer with python";
need_call_python_code = true;
}
if (prim->prim_type() == kPrimTypePyCheck || !eval_impl.has_value()) {
need_call_python_code = true;
}
// Only cache the abstract when the primitive should call the python code.
if (need_call_python_code && GetOutputAbstractByCache(op_run_info)) {
@ -103,6 +99,7 @@ void InferOperation::PynativeInfer(const FrontendOpRunInfoPtr &op_run_info) cons
// Call Python func
if (need_call_python_code) {
py::gil_scoped_acquire acquire;
CallPyInferFunc(prim, op_run_info);
if (op_run_info->base_op_run_info.abstract != nullptr) {
return;

View File

@ -171,7 +171,6 @@ void ForwardExecutor::Init() {
compile::SetMindRTEnable();
python_adapter::set_python_env_flag(true);
init_ = true;
forward_queue_ = std::make_shared<AsyncQueue>();
runtime::OpExecutor::GetInstance().RegisterForwardCallback([this]() { forward_queue_->Wait(); });
}
@ -548,7 +547,10 @@ ValuePtr ForwardExecutor::RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info
void ForwardExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear forward res";
forward_queue_->Reset();
{
GilReleaseWithCheck gil_release;
forward_queue_->Clear();
}
for (const auto &item : mindrt_backends_) {
MS_EXCEPTION_IF_NULL(item.second);
item.second->ClearOpExecutorResource();

View File

@ -42,7 +42,8 @@ class ForwardExecutor {
ForwardExecutor()
: cast_operation_(std::make_shared<CastOperation>()),
infer_operation_(std::make_shared<InferOperation>()),
enable_async_(std::getenv("ENABLE_ASYNC")) {}
enable_async_(std::getenv("ENABLE_ASYNC")),
forward_queue_(std::make_shared<AsyncQueue>()) {}
~ForwardExecutor() = default;
void Init();

View File

@ -19,5 +19,7 @@
namespace mindspore {
namespace pynative {
void ForwardTask::Run() { run_func_(op_run_info_); }
void ForwardTask::SetException(const std::exception_ptr &e) { op_run_info_->stub_output->SetException(e); }
} // namespace pynative
} // namespace mindspore

View File

@ -30,6 +30,7 @@ class ForwardTask : public AsyncTask {
: AsyncTask(kForwardTask), run_func_(std::move(run_func)), op_run_info_(std::move(op_run_info)) {}
~ForwardTask() = default;
void Run() override;
void SetException(const std::exception_ptr &e) override;
private:
std::function<void(const FrontendOpRunInfoPtr &op_run_info)> run_func_;

View File

@ -91,11 +91,10 @@ bool PyNativeExecutor::DisablePyTraceAsync(const FrontendOpRunInfoPtr &op_run_in
#ifdef ENABLE_TEST
return true;
#else
return forward_executor()->IsVmOp(op_run_info->base_op_run_info.op_name) ||
op_run_info->op_prim->name() == "Custom" || ScopedFallbackRunning::on() ||
op_run_info->op_prim->HasAttr("side_effect_mem") ||
(!abstract::GetFrontendPrimitiveInferImpl(op_run_info->op_prim).has_value() &&
!abstract::GetBackendPrimitiveInferImpl(op_run_info->op_prim).has_value()) ||
const auto &op_prim = op_run_info->op_prim;
return forward_executor()->IsVmOp(op_run_info->base_op_run_info.op_name) || op_prim->name() == "Custom" ||
ScopedFallbackRunning::on() || op_prim->HasAttr("side_effect_mem") ||
(op_prim->prim_type() == kPrimTypePyCheck || !abstract::GetFrontendPrimitiveInferImpl(op_prim).has_value()) ||
MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
#endif
}

View File

@ -66,9 +66,17 @@ void AsyncQueue::WorkerLoop() {
MS_LOG(ERROR) << "Run task failed, error msg:" << e.what();
{
std::unique_lock<std::mutex> lock(task_mutex_);
std::queue<std::shared_ptr<AsyncTask>> empty;
std::swap(tasks_, empty);
MsException::Instance().SetException();
// MsException is unreliable because it gets modified everywhere.
auto e_ptr = std::current_exception();
task->SetException(e_ptr);
while (!tasks_.empty()) {
auto &t = tasks_.front();
t->SetException(e_ptr);
tasks_.pop();
}
task_cond_var_.notify_all();
}
}
@ -110,8 +118,13 @@ void AsyncQueue::Clear() {
}
std::queue<std::shared_ptr<AsyncTask>> empty;
std::swap(tasks_, empty);
auto task = std::make_shared<WaitTask>();
tasks_.push(task);
// Avoid pushing tasks after WorkerJoin.
if (worker_ != nullptr && worker_->joinable()) {
auto task = std::make_shared<WaitTask>();
tasks_.push(task);
}
task_cond_var_.notify_all();
}
// There is still one task in progress

View File

@ -26,6 +26,7 @@ class AsyncTask {
explicit AsyncTask(TaskType task_type) : task_type_(task_type) {}
virtual ~AsyncTask() = default;
virtual void Run() = 0;
virtual void SetException(const std::exception_ptr &e) {}
TaskType task_type() const { return task_type_; }

View File

@ -70,30 +70,6 @@ py::object MakeOutput(StubNodePtr node) {
return out;
}
}
class StubException : public ExceptionListener {
public:
StubException() {
MsException::Instance().SetExceptionListener(this);
MsException::Instance().CheckException();
}
~StubException() = default;
void Finalize() {
MsException::Instance().SetExceptionListener(nullptr);
MsException::Instance().CheckException();
}
bool HasException() const { return has_exception_; }
void OnException() override {
has_exception_ = true;
std::unique_lock<std::mutex> lock(stub_mutex_);
stub_cond_var_.notify_all();
}
private:
bool has_exception_{false};
};
} // namespace
bool StubNode::SetAbstract(const AbstractBasePtr &abs) {
@ -113,6 +89,14 @@ void StubNode::SetValue(const ValuePtr &val) {
}
}
void StubNode::SetException(const std::exception_ptr &e_ptr) {
e_ptr_ = e_ptr;
if (wait_flag_.load()) {
std::unique_lock<std::mutex> lock(stub_mutex_);
stub_cond_var_.notify_all();
}
}
AbstractBasePtr StubNode::WaitAbstract() {
GilReleaseWithCheck gil_release;
if (abstract_.get() == nullptr) {
@ -120,15 +104,15 @@ AbstractBasePtr StubNode::WaitAbstract() {
if (top) {
top->WaitAbstract();
} else {
StubException e;
wait_flag_.store(true);
std::unique_lock<std::mutex> lock(stub_mutex_);
stub_cond_var_.wait(lock, [this, &e] { return abstract_.get() != nullptr || e.HasException(); });
if (e.HasException()) {
abstract_ = std::make_shared<abstract::AbstractNone>();
}
stub_cond_var_.wait(lock, [this] { return abstract_.get() != nullptr || e_ptr_ != nullptr; });
wait_flag_.store(false);
e.Finalize();
if (e_ptr_ != nullptr) {
// Need to clear exception in the instance.
MsException::Instance().CheckException();
std::rethrow_exception(e_ptr_);
}
}
}
return abstract_;
@ -141,15 +125,15 @@ ValuePtr StubNode::WaitValue() {
if (top) {
top->WaitValue();
} else {
StubException e;
wait_flag_.store(true);
std::unique_lock<std::mutex> lock(stub_mutex_);
stub_cond_var_.wait(lock, [this, &e] { return value_.get() != nullptr || e.HasException(); });
if (e.HasException()) {
value_ = std::make_shared<tensor::Tensor>();
}
stub_cond_var_.wait(lock, [this] { return value_.get() != nullptr || e_ptr_ != nullptr; });
wait_flag_.store(false);
e.Finalize();
if (e_ptr_ != nullptr) {
// Need to clear exception in the instance.
MsException::Instance().CheckException();
std::rethrow_exception(e_ptr_);
}
}
}
return value_;

View File

@ -40,7 +40,12 @@ def run_case(run_mode):
net = ConcatOffsetNet(1)
net.set_inputs(x0_dyn, x1_dyn)
output = net(x0, x1)
assert np.allclose(expect, output.asnumpy())
if run_mode == context.GRAPH_MODE:
assert np.allclose(expect, output.asnumpy())
else:
# In PyNative mode, set_inputs is ignored, so ConcatOffset runs with a static
# shape; its inferred output is a constant value rather than a tensor.
assert np.allclose(expect, output)
@pytest.mark.level0