!49921 PyTrace Bugfix
Merge pull request !49921 from caifubi/master-pynative-pytrace-async-ci-dev-new
This commit is contained in:
commit
08bf8f9d2c
|
@ -46,6 +46,7 @@ class COMMON_EXPORT StubNode : public Value {
|
|||
|
||||
virtual bool SetAbstract(const AbstractBasePtr &abs);
|
||||
virtual void SetValue(const ValuePtr &val);
|
||||
void SetException(const std::exception_ptr &e_ptr);
|
||||
|
||||
AbstractBasePtr WaitAbstract();
|
||||
ValuePtr WaitValue();
|
||||
|
@ -60,6 +61,7 @@ class COMMON_EXPORT StubNode : public Value {
|
|||
ValuePtr value_;
|
||||
std::atomic<bool> wait_flag_{false};
|
||||
StubNodePtr top_node_;
|
||||
std::exception_ptr e_ptr_{};
|
||||
};
|
||||
|
||||
class TensorNode : public StubNode {
|
||||
|
|
|
@ -87,12 +87,8 @@ void InferOperation::PynativeInfer(const FrontendOpRunInfoPtr &op_run_info) cons
|
|||
auto eval_impl = abstract::GetFrontendPrimitiveInferImpl(prim);
|
||||
bool need_call_python_code = false;
|
||||
// Charge if the primitive should call the python code, when infer abstract.
|
||||
if (!eval_impl.has_value()) {
|
||||
eval_impl = abstract::GetBackendPrimitiveInferImpl(prim);
|
||||
if (!eval_impl.has_value()) {
|
||||
MS_LOG(DEBUG) << "Can't found infer function from Frontend and Backend, try to infer with python";
|
||||
need_call_python_code = true;
|
||||
}
|
||||
if (prim->prim_type() == kPrimTypePyCheck || !eval_impl.has_value()) {
|
||||
need_call_python_code = true;
|
||||
}
|
||||
// Only cache the abstract when the primitive should call the python code.
|
||||
if (need_call_python_code && GetOutputAbstractByCache(op_run_info)) {
|
||||
|
@ -103,6 +99,7 @@ void InferOperation::PynativeInfer(const FrontendOpRunInfoPtr &op_run_info) cons
|
|||
|
||||
// Call Python func
|
||||
if (need_call_python_code) {
|
||||
py::gil_scoped_acquire acquire;
|
||||
CallPyInferFunc(prim, op_run_info);
|
||||
if (op_run_info->base_op_run_info.abstract != nullptr) {
|
||||
return;
|
||||
|
|
|
@ -171,7 +171,6 @@ void ForwardExecutor::Init() {
|
|||
compile::SetMindRTEnable();
|
||||
python_adapter::set_python_env_flag(true);
|
||||
init_ = true;
|
||||
forward_queue_ = std::make_shared<AsyncQueue>();
|
||||
runtime::OpExecutor::GetInstance().RegisterForwardCallback([this]() { forward_queue_->Wait(); });
|
||||
}
|
||||
|
||||
|
@ -548,7 +547,10 @@ ValuePtr ForwardExecutor::RunOpInMsInner(const FrontendOpRunInfoPtr &op_run_info
|
|||
|
||||
void ForwardExecutor::ClearRes() {
|
||||
MS_LOG(DEBUG) << "Clear forward res";
|
||||
forward_queue_->Reset();
|
||||
{
|
||||
GilReleaseWithCheck gil_release;
|
||||
forward_queue_->Clear();
|
||||
}
|
||||
for (const auto &item : mindrt_backends_) {
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
item.second->ClearOpExecutorResource();
|
||||
|
|
|
@ -42,7 +42,8 @@ class ForwardExecutor {
|
|||
ForwardExecutor()
|
||||
: cast_operation_(std::make_shared<CastOperation>()),
|
||||
infer_operation_(std::make_shared<InferOperation>()),
|
||||
enable_async_(std::getenv("ENABLE_ASYNC")) {}
|
||||
enable_async_(std::getenv("ENABLE_ASYNC")),
|
||||
forward_queue_(std::make_shared<AsyncQueue>()) {}
|
||||
~ForwardExecutor() = default;
|
||||
|
||||
void Init();
|
||||
|
|
|
@ -19,5 +19,7 @@
|
|||
namespace mindspore {
|
||||
namespace pynative {
|
||||
void ForwardTask::Run() { run_func_(op_run_info_); }
|
||||
|
||||
void ForwardTask::SetException(const std::exception_ptr &e) { op_run_info_->stub_output->SetException(e); }
|
||||
} // namespace pynative
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,6 +30,7 @@ class ForwardTask : public AsyncTask {
|
|||
: AsyncTask(kForwardTask), run_func_(std::move(run_func)), op_run_info_(std::move(op_run_info)) {}
|
||||
~ForwardTask() = default;
|
||||
void Run() override;
|
||||
void SetException(const std::exception_ptr &e) override;
|
||||
|
||||
private:
|
||||
std::function<void(const FrontendOpRunInfoPtr &op_run_info)> run_func_;
|
||||
|
|
|
@ -91,11 +91,10 @@ bool PyNativeExecutor::DisablePyTraceAsync(const FrontendOpRunInfoPtr &op_run_in
|
|||
#ifdef ENABLE_TEST
|
||||
return true;
|
||||
#else
|
||||
return forward_executor()->IsVmOp(op_run_info->base_op_run_info.op_name) ||
|
||||
op_run_info->op_prim->name() == "Custom" || ScopedFallbackRunning::on() ||
|
||||
op_run_info->op_prim->HasAttr("side_effect_mem") ||
|
||||
(!abstract::GetFrontendPrimitiveInferImpl(op_run_info->op_prim).has_value() &&
|
||||
!abstract::GetBackendPrimitiveInferImpl(op_run_info->op_prim).has_value()) ||
|
||||
const auto &op_prim = op_run_info->op_prim;
|
||||
return forward_executor()->IsVmOp(op_run_info->base_op_run_info.op_name) || op_prim->name() == "Custom" ||
|
||||
ScopedFallbackRunning::on() || op_prim->HasAttr("side_effect_mem") ||
|
||||
(op_prim->prim_type() == kPrimTypePyCheck || !abstract::GetFrontendPrimitiveInferImpl(op_prim).has_value()) ||
|
||||
MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -66,9 +66,17 @@ void AsyncQueue::WorkerLoop() {
|
|||
MS_LOG(ERROR) << "Run task failed, error msg:" << e.what();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(task_mutex_);
|
||||
std::queue<std::shared_ptr<AsyncTask>> empty;
|
||||
std::swap(tasks_, empty);
|
||||
|
||||
MsException::Instance().SetException();
|
||||
// MsException is unreliable because it gets modified everywhere.
|
||||
auto e_ptr = std::current_exception();
|
||||
task->SetException(e_ptr);
|
||||
while (!tasks_.empty()) {
|
||||
auto &t = tasks_.front();
|
||||
t->SetException(e_ptr);
|
||||
tasks_.pop();
|
||||
}
|
||||
|
||||
task_cond_var_.notify_all();
|
||||
}
|
||||
}
|
||||
|
@ -110,8 +118,13 @@ void AsyncQueue::Clear() {
|
|||
}
|
||||
std::queue<std::shared_ptr<AsyncTask>> empty;
|
||||
std::swap(tasks_, empty);
|
||||
auto task = std::make_shared<WaitTask>();
|
||||
tasks_.push(task);
|
||||
|
||||
// Avoid to push task after WorkerJoin.
|
||||
if (worker_ != nullptr && worker_->joinable()) {
|
||||
auto task = std::make_shared<WaitTask>();
|
||||
tasks_.push(task);
|
||||
}
|
||||
|
||||
task_cond_var_.notify_all();
|
||||
}
|
||||
// There is still one task in progress
|
||||
|
|
|
@ -26,6 +26,7 @@ class AsyncTask {
|
|||
explicit AsyncTask(TaskType task_type) : task_type_(task_type) {}
|
||||
virtual ~AsyncTask() = default;
|
||||
virtual void Run() = 0;
|
||||
virtual void SetException(const std::exception_ptr &e) {}
|
||||
|
||||
TaskType task_type() const { return task_type_; }
|
||||
|
||||
|
|
|
@ -70,30 +70,6 @@ py::object MakeOutput(StubNodePtr node) {
|
|||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
class StubException : public ExceptionListener {
|
||||
public:
|
||||
StubException() {
|
||||
MsException::Instance().SetExceptionListener(this);
|
||||
MsException::Instance().CheckException();
|
||||
}
|
||||
~StubException() = default;
|
||||
|
||||
void Finalize() {
|
||||
MsException::Instance().SetExceptionListener(nullptr);
|
||||
MsException::Instance().CheckException();
|
||||
}
|
||||
bool HasException() const { return has_exception_; }
|
||||
|
||||
void OnException() override {
|
||||
has_exception_ = true;
|
||||
std::unique_lock<std::mutex> lock(stub_mutex_);
|
||||
stub_cond_var_.notify_all();
|
||||
}
|
||||
|
||||
private:
|
||||
bool has_exception_{false};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
bool StubNode::SetAbstract(const AbstractBasePtr &abs) {
|
||||
|
@ -113,6 +89,14 @@ void StubNode::SetValue(const ValuePtr &val) {
|
|||
}
|
||||
}
|
||||
|
||||
void StubNode::SetException(const std::exception_ptr &e_ptr) {
|
||||
e_ptr_ = e_ptr;
|
||||
if (wait_flag_.load()) {
|
||||
std::unique_lock<std::mutex> lock(stub_mutex_);
|
||||
stub_cond_var_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr StubNode::WaitAbstract() {
|
||||
GilReleaseWithCheck gil_release;
|
||||
if (abstract_.get() == nullptr) {
|
||||
|
@ -120,15 +104,15 @@ AbstractBasePtr StubNode::WaitAbstract() {
|
|||
if (top) {
|
||||
top->WaitAbstract();
|
||||
} else {
|
||||
StubException e;
|
||||
wait_flag_.store(true);
|
||||
std::unique_lock<std::mutex> lock(stub_mutex_);
|
||||
stub_cond_var_.wait(lock, [this, &e] { return abstract_.get() != nullptr || e.HasException(); });
|
||||
if (e.HasException()) {
|
||||
abstract_ = std::make_shared<abstract::AbstractNone>();
|
||||
}
|
||||
stub_cond_var_.wait(lock, [this] { return abstract_.get() != nullptr || e_ptr_ != nullptr; });
|
||||
wait_flag_.store(false);
|
||||
e.Finalize();
|
||||
if (e_ptr_ != nullptr) {
|
||||
// Need to clear exception in the instance.
|
||||
MsException::Instance().CheckException();
|
||||
std::rethrow_exception(e_ptr_);
|
||||
}
|
||||
}
|
||||
}
|
||||
return abstract_;
|
||||
|
@ -141,15 +125,15 @@ ValuePtr StubNode::WaitValue() {
|
|||
if (top) {
|
||||
top->WaitValue();
|
||||
} else {
|
||||
StubException e;
|
||||
wait_flag_.store(true);
|
||||
std::unique_lock<std::mutex> lock(stub_mutex_);
|
||||
stub_cond_var_.wait(lock, [this, &e] { return value_.get() != nullptr || e.HasException(); });
|
||||
if (e.HasException()) {
|
||||
value_ = std::make_shared<tensor::Tensor>();
|
||||
}
|
||||
stub_cond_var_.wait(lock, [this] { return value_.get() != nullptr || e_ptr_ != nullptr; });
|
||||
wait_flag_.store(false);
|
||||
e.Finalize();
|
||||
if (e_ptr_ != nullptr) {
|
||||
// Need to clear exception in the instance.
|
||||
MsException::Instance().CheckException();
|
||||
std::rethrow_exception(e_ptr_);
|
||||
}
|
||||
}
|
||||
}
|
||||
return value_;
|
||||
|
|
|
@ -40,7 +40,12 @@ def run_case(run_mode):
|
|||
net = ConcatOffsetNet(1)
|
||||
net.set_inputs(x0_dyn, x1_dyn)
|
||||
output = net(x0, x1)
|
||||
assert np.allclose(expect, output.asnumpy())
|
||||
if run_mode == context.GRAPH_MODE:
|
||||
assert np.allclose(expect, output.asnumpy())
|
||||
else:
|
||||
# In PyNative, set_inputs will be ignored. Static shape for ConcatOffset
|
||||
# infer output is not a tensor, get constant value output.
|
||||
assert np.allclose(expect, output)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue