Implemented and tested panic handling in WriteTask. Fixed bug in panic handling of `run` and `spawn`.
This commit is contained in:
parent
715c29cb91
commit
c8d744a18d
|
@ -219,18 +219,27 @@ pub fn spawn<F>(task: F)
|
|||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
type Fut = Pin<Box<dyn Future<Output = ()> + Send>>;
|
||||
|
||||
// A future that is its own waker that polls inside the rayon primary thread-pool.
|
||||
struct RayonTask(Mutex<Pin<Box<dyn Future<Output = ()> + Send>>>);
|
||||
struct RayonTask(Mutex<Option<Fut>>);
|
||||
impl RayonTask {
|
||||
fn poll(self: Arc<RayonTask>) {
|
||||
rayon::spawn(move || {
|
||||
let r = panic::catch_unwind(panic::AssertUnwindSafe(move || {
|
||||
// this `Option<Fut>` dance is used to avoid a `poll` after `Ready` or panic.
|
||||
let mut task = self.0.lock();
|
||||
if let Some(mut t) = task.take() {
|
||||
let waker = self.clone().into();
|
||||
let mut cx = std::task::Context::from_waker(&waker);
|
||||
let _ = self.0.lock().as_mut().poll(&mut cx);
|
||||
}));
|
||||
if let Err(p) = r {
|
||||
log::error!("panic in `task::spawn`: {}", panic_str(&p));
|
||||
|
||||
let r = panic::catch_unwind(panic::AssertUnwindSafe(move || {
|
||||
if t.as_mut().poll(&mut cx).is_pending() {
|
||||
*task = Some(t);
|
||||
}
|
||||
}));
|
||||
if let Err(p) = r {
|
||||
log::error!("panic in `task::spawn`: {}", panic_str(&p));
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -241,7 +250,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
Arc::new(RayonTask(Mutex::new(Box::pin(task)))).poll()
|
||||
Arc::new(RayonTask(Mutex::new(Some(Box::pin(task))))).poll()
|
||||
}
|
||||
|
||||
/// Spawn a parallel async task that can also be `.await` for the task result.
|
||||
|
@ -331,8 +340,10 @@ where
|
|||
R: Send + 'static,
|
||||
T: Future<Output = R> + Send + 'static,
|
||||
{
|
||||
type Fut<R> = Pin<Box<dyn Future<Output = R> + Send>>;
|
||||
|
||||
// A future that is its own waker that polls inside the rayon primary thread-pool.
|
||||
struct RayonCatchTask<R>(Mutex<Pin<Box<dyn Future<Output = R> + Send>>>, flume::Sender<PanicResult<R>>);
|
||||
struct RayonCatchTask<R>(Mutex<Option<Fut<R>>>, flume::Sender<PanicResult<R>>);
|
||||
impl<R: Send + 'static> RayonCatchTask<R> {
|
||||
fn poll(self: Arc<Self>) {
|
||||
let sender = self.1.clone();
|
||||
|
@ -340,20 +351,27 @@ where
|
|||
return; // cancel.
|
||||
}
|
||||
rayon::spawn(move || {
|
||||
let r = panic::catch_unwind(panic::AssertUnwindSafe(|| {
|
||||
// this `Option<Fut>` dance is used to avoid a `poll` after `Ready` or panic.
|
||||
let mut task = self.0.lock();
|
||||
if let Some(mut t) = task.take() {
|
||||
let waker = self.clone().into();
|
||||
let mut cx = std::task::Context::from_waker(&waker);
|
||||
self.0.lock().as_mut().poll(&mut cx)
|
||||
}));
|
||||
|
||||
match r {
|
||||
Ok(Poll::Ready(r)) => {
|
||||
let _ = sender.send(Ok(r));
|
||||
let r = panic::catch_unwind(panic::AssertUnwindSafe(|| t.as_mut().poll(&mut cx)));
|
||||
|
||||
match r {
|
||||
Ok(Poll::Ready(r)) => {
|
||||
drop(task);
|
||||
let _ = sender.send(Ok(r));
|
||||
}
|
||||
Ok(Poll::Pending) => {
|
||||
*task = Some(t);
|
||||
}
|
||||
Err(p) => {
|
||||
drop(task);
|
||||
let _ = sender.send(Err(p));
|
||||
}
|
||||
}
|
||||
Err(p) => {
|
||||
let _ = sender.send(Err(p));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -366,7 +384,7 @@ where
|
|||
|
||||
let (sender, receiver) = channel::bounded(1);
|
||||
|
||||
Arc::new(RayonCatchTask(Mutex::new(Box::pin(task)), sender.into())).poll();
|
||||
Arc::new(RayonCatchTask(Mutex::new(Some(Box::pin(task))), sender.into())).poll();
|
||||
|
||||
receiver.recv().await.unwrap()
|
||||
}
|
||||
|
@ -2821,7 +2839,8 @@ where
|
|||
let (w, p, r) = match r {
|
||||
Ok((w, p, r)) => (w, p, r),
|
||||
Err(p) => {
|
||||
let _ = f_sender.send(WriteTaskFinishMsg::Panic { payload: p, receiver }).await;
|
||||
drop(receiver);
|
||||
let _ = f_sender.send(WriteTaskFinishMsg::Panic(p)).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
@ -2850,7 +2869,8 @@ where
|
|||
let (w, r) = match r {
|
||||
Ok((w, r)) => (w, r),
|
||||
Err(p) => {
|
||||
let _ = f_sender.send(WriteTaskFinishMsg::Panic { payload: p, receiver }).await;
|
||||
drop(receiver);
|
||||
let _ = f_sender.send(WriteTaskFinishMsg::Panic(p)).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
@ -2880,7 +2900,8 @@ where
|
|||
let (w, r) = match r {
|
||||
Ok((w, r)) => (w, r),
|
||||
Err(p) => {
|
||||
let _ = f_sender.send(WriteTaskFinishMsg::Panic { payload: p, receiver }).await;
|
||||
drop(receiver);
|
||||
let _ = f_sender.send(WriteTaskFinishMsg::Panic(p)).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
@ -2895,17 +2916,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
let payloads = Self::drain_payloads(receiver, error_payload);
|
||||
|
||||
// send non-panic finish message.
|
||||
let _ = f_sender
|
||||
.send(WriteTaskFinishMsg::Io {
|
||||
write,
|
||||
result: match error {
|
||||
Some(e) => Err((error_payload, e)),
|
||||
None => Ok(()),
|
||||
},
|
||||
receiver,
|
||||
})
|
||||
.await;
|
||||
let _ = f_sender.send(WriteTaskFinishMsg::Io { write, error, payloads }).await;
|
||||
});
|
||||
WriteTask {
|
||||
sender,
|
||||
|
@ -2929,9 +2943,7 @@ where
|
|||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Awaits until all previous requested [`write`] are finished.
|
||||
|
@ -2953,25 +2965,24 @@ where
|
|||
///
|
||||
/// [`write`]: Self::write
|
||||
pub async fn finish(self) -> Result<W, WriteTaskError<W>> {
|
||||
self.sender.send(WriteTaskMsg::Finish).await.unwrap();
|
||||
let _ = self.sender.send(WriteTaskMsg::Finish).await;
|
||||
|
||||
let msg = self.finish.recv().await.unwrap();
|
||||
|
||||
match msg {
|
||||
WriteTaskFinishMsg::Io { write, result, receiver } => match result {
|
||||
Ok(_) => Ok(write),
|
||||
Err((payload, io_err)) => Err(WriteTaskError::io(
|
||||
WriteTaskFinishMsg::Io { write, error, payloads } => match error {
|
||||
None => Ok(write),
|
||||
Some(error) => Err(WriteTaskError::Io {
|
||||
write,
|
||||
self.state.bytes_written(),
|
||||
Self::drain_payloads(receiver, payload),
|
||||
io_err,
|
||||
)),
|
||||
error,
|
||||
bytes_written: self.state.bytes_written(),
|
||||
payloads,
|
||||
}),
|
||||
},
|
||||
WriteTaskFinishMsg::Panic { payload, receiver } => Err(WriteTaskError::panic(
|
||||
self.state.bytes_written(),
|
||||
Self::drain_payloads(receiver, vec![]),
|
||||
payload,
|
||||
)),
|
||||
WriteTaskFinishMsg::Panic(panic_payload) => Err(WriteTaskError::Panic {
|
||||
panic_payload,
|
||||
bytes_written: self.state.bytes_written(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3046,13 +3057,10 @@ enum WriteTaskMsg {
|
|||
enum WriteTaskFinishMsg<W> {
|
||||
Io {
|
||||
write: W,
|
||||
result: Result<(), (Vec<u8>, io::Error)>,
|
||||
receiver: channel::Receiver<WriteTaskMsg>,
|
||||
},
|
||||
Panic {
|
||||
payload: PanicPayload,
|
||||
receiver: channel::Receiver<WriteTaskMsg>,
|
||||
error: Option<io::Error>,
|
||||
payloads: Vec<Vec<u8>>,
|
||||
},
|
||||
Panic(PanicPayload),
|
||||
}
|
||||
|
||||
/// Error from [`WriteTask::finish`].
|
||||
|
@ -3060,52 +3068,51 @@ enum WriteTaskFinishMsg<W> {
|
|||
/// The write task worker closes on the first IO error or panic, the [`WriteTask`] send methods
|
||||
/// return [`WriteTaskClosed`] when this happens and the [`WriteTask::finish`]
|
||||
/// method returns this error that contains the actual error.
|
||||
pub struct WriteTaskError<W> {
|
||||
/// Error source.
|
||||
///
|
||||
/// If is an IO error also contains the [`io::Write`].
|
||||
pub source: WriteTaskErrorSource<W>,
|
||||
pub enum WriteTaskError<W> {
|
||||
/// A write error.
|
||||
Io {
|
||||
/// The [`io::Write`].
|
||||
write: W,
|
||||
/// The error.
|
||||
error: io::Error,
|
||||
|
||||
/// Number of bytes that where written before the error.
|
||||
///
|
||||
/// Note that some bytes from the last payload where probably written too, but
|
||||
/// only confirmed written payloads are counted here.
|
||||
pub bytes_written: u64,
|
||||
/// Number of bytes that where written before the error.
|
||||
///
|
||||
/// Note that some bytes from the last payload where probably written too, but
|
||||
/// only confirmed written payloads are counted here.
|
||||
bytes_written: u64,
|
||||
|
||||
/// The payloads that where not written.
|
||||
pub payloads: Vec<Vec<u8>>,
|
||||
/// The payloads that where not written.
|
||||
payloads: Vec<Vec<u8>>,
|
||||
},
|
||||
/// A panic in the [`io::Write`].
|
||||
///
|
||||
/// You can propagate the panic using [`std::panic::resume_unwind`].
|
||||
Panic {
|
||||
/// The panic message object.
|
||||
panic_payload: PanicPayload,
|
||||
/// Number of bytes that where written before the error.
|
||||
///
|
||||
/// Note that some bytes from the last payload where probably written too, and
|
||||
/// given there was a panic some bytes could be corrupted.
|
||||
bytes_written: u64,
|
||||
},
|
||||
}
|
||||
impl<W> WriteTaskError<W> {
|
||||
fn io(write: W, bytes_written: u64, payloads: Vec<Vec<u8>>, error: io::Error) -> Self {
|
||||
Self {
|
||||
source: WriteTaskErrorSource::Io { write, error },
|
||||
bytes_written,
|
||||
payloads,
|
||||
}
|
||||
}
|
||||
|
||||
fn panic(bytes_written: u64, payloads: Vec<Vec<u8>>, payload: PanicPayload) -> Self {
|
||||
Self {
|
||||
source: WriteTaskErrorSource::Panic(payload),
|
||||
bytes_written,
|
||||
payloads,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the error of [`Io`] or panics.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the error is a [`Panic`] the panic is propagated using [`resume_unwind`].
|
||||
///
|
||||
/// [`Io`]: WriteTaskErrorSource::Io
|
||||
/// [`Panic`]: WriteTaskErrorSource::Panic
|
||||
/// [`Io`]: Self::Io
|
||||
/// [`Panic`]: Self::Panic
|
||||
/// [`resume_unwind`]: std::panic::resume_unwind
|
||||
#[track_caller]
|
||||
pub fn unwrap_io(self) -> io::Error {
|
||||
match self.source {
|
||||
WriteTaskErrorSource::Io { error, .. } => error,
|
||||
WriteTaskErrorSource::Panic(p) => panic::resume_unwind(p),
|
||||
match self {
|
||||
Self::Io { error, .. } => error,
|
||||
Self::Panic { panic_payload, .. } => panic::resume_unwind(panic_payload),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3115,56 +3122,52 @@ impl<W> WriteTaskError<W> {
|
|||
///
|
||||
/// If the error is a [`Panic`] the panic is propagated using [`resume_unwind`].
|
||||
///
|
||||
/// [`Io`]: WriteTaskErrorSource::Io
|
||||
/// [`Panic`]: WriteTaskErrorSource::Panic
|
||||
/// [`Io`]: Self::Io
|
||||
/// [`Panic`]: Self::Panic
|
||||
/// [`resume_unwind`]: std::panic::resume_unwind
|
||||
pub fn unwrap_write(self) -> (W, io::Error) {
|
||||
match self.source {
|
||||
WriteTaskErrorSource::Io { write, error } => (write, error),
|
||||
WriteTaskErrorSource::Panic(p) => panic::resume_unwind(p),
|
||||
match self {
|
||||
Self::Io { write, error, .. } => (write, error),
|
||||
Self::Panic { panic_payload, .. } => panic::resume_unwind(panic_payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<W: io::Write> fmt::Debug for WriteTaskError<W> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.source {
|
||||
WriteTaskErrorSource::Io { error, .. } => write!(f, "Io({:?})", error),
|
||||
WriteTaskErrorSource::Panic(p) => write!(f, "Panic({:?})", panic_str(p)),
|
||||
match &self {
|
||||
Self::Io { error, bytes_written, .. } => f
|
||||
.debug_struct("Io")
|
||||
.field("error", error)
|
||||
.field("bytes_written", bytes_written)
|
||||
.finish_non_exhaustive(),
|
||||
Self::Panic {
|
||||
panic_payload: p,
|
||||
bytes_written,
|
||||
} => f
|
||||
.debug_struct("Panic")
|
||||
.field("panic_payload", &panic_str(p))
|
||||
.field("bytes_written", bytes_written)
|
||||
.finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<W: io::Write> fmt::Display for WriteTaskError<W> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.source {
|
||||
WriteTaskErrorSource::Io { error, .. } => write!(f, "{}", error),
|
||||
WriteTaskErrorSource::Panic(p) => write!(f, "{}", panic_str(p)),
|
||||
match &self {
|
||||
Self::Io { error, .. } => write!(f, "{}", error),
|
||||
Self::Panic { panic_payload: p, .. } => write!(f, "{}", panic_str(p)),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<W: io::Write> std::error::Error for WriteTaskError<W> {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match &self.source {
|
||||
WriteTaskErrorSource::Io { error, .. } => Some(error),
|
||||
WriteTaskErrorSource::Panic(_) => None,
|
||||
match &self {
|
||||
Self::Io { error, .. } => Some(error),
|
||||
Self::Panic { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Source of an [`WriteTaskError<W>`].
|
||||
pub enum WriteTaskErrorSource<W> {
|
||||
/// A write error.
|
||||
Io {
|
||||
/// The [`io::Write`].
|
||||
write: W,
|
||||
/// The error.
|
||||
error: io::Error,
|
||||
},
|
||||
/// A panic in the [`io::Write`].
|
||||
///
|
||||
/// You can propagate the panic using [`std::panic::resume_unwind`].
|
||||
Panic(PanicPayload),
|
||||
}
|
||||
|
||||
/// Error from [`WriteTask`].
|
||||
///
|
||||
/// This error is returned to indicate that the task worker has permanently stopped because
|
||||
|
@ -3747,6 +3750,8 @@ pub mod http {
|
|||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use rayon::prelude::*;
|
||||
|
||||
use crate::units::TimeUnits;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
|
@ -3894,17 +3899,106 @@ pub mod tests {
|
|||
|
||||
#[test]
|
||||
pub fn write_task() {
|
||||
todo!()
|
||||
async_test(async {
|
||||
let write = TestWrite::default();
|
||||
|
||||
let task = WriteTask::default().spawn(write);
|
||||
|
||||
for byte in 0u8..20 {
|
||||
task.write(vec![byte, byte + 100]).await.unwrap();
|
||||
}
|
||||
|
||||
let write = task.finish().await.unwrap();
|
||||
|
||||
assert_eq!(20, write.write_calls);
|
||||
assert_eq!(40, write.bytes_written);
|
||||
assert_eq!(1, write.flush_calls);
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn write_task_flush() {
|
||||
async_test(async {
|
||||
let write = TestWrite::default();
|
||||
|
||||
let task = WriteTask::default().spawn(write);
|
||||
|
||||
for byte in 0u8..20 {
|
||||
task.write(vec![byte, byte + 100]).await.unwrap();
|
||||
}
|
||||
|
||||
task.flush().await.unwrap();
|
||||
let task_bytes_written = task.bytes_written();
|
||||
|
||||
let write = task.finish().await.unwrap();
|
||||
|
||||
assert_eq!(40, task_bytes_written);
|
||||
assert_eq!(2, write.flush_calls);
|
||||
|
||||
assert_eq!(20, write.write_calls);
|
||||
assert_eq!(40, write.bytes_written);
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn write_error() {
|
||||
todo!()
|
||||
async_test(async {
|
||||
let write = TestWrite::default();
|
||||
let flag = write.cause_error.clone();
|
||||
|
||||
let task = WriteTask::default().spawn(write);
|
||||
|
||||
for byte in 0u8..20 {
|
||||
if byte == 10 {
|
||||
flag.set();
|
||||
}
|
||||
if task.write(vec![byte, byte + 100]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let e = task.finish().await.unwrap_err();
|
||||
if let WriteTaskError::Io { bytes_written, write, .. } = &e {
|
||||
assert_eq!(write.bytes_written as u64, *bytes_written);
|
||||
} else {
|
||||
panic!("expected WriteTaskError::Io")
|
||||
}
|
||||
|
||||
let (write, e) = e.unwrap_write();
|
||||
assert!(write.bytes_written > 0);
|
||||
assert_eq!("test error", e.to_string());
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn write_panic() {
|
||||
todo!()
|
||||
async_test(async {
|
||||
let write = TestWrite::default();
|
||||
let flag = write.cause_panic.clone();
|
||||
|
||||
let task = WriteTask::default().spawn(write);
|
||||
|
||||
for byte in 0u8..20 {
|
||||
if byte == 10 {
|
||||
flag.set();
|
||||
}
|
||||
if task.write(vec![byte, byte + 100]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let e = task.finish().await.unwrap_err();
|
||||
if let WriteTaskError::Panic {
|
||||
bytes_written,
|
||||
panic_payload,
|
||||
} = &e
|
||||
{
|
||||
assert!(*bytes_written > 0);
|
||||
assert_eq!("test panic", panic_str(&panic_payload))
|
||||
} else {
|
||||
panic!("expected WriteTaskError::Panic")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -3933,6 +4027,26 @@ pub mod tests {
|
|||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn run_panic_handling_parallel() {
|
||||
async_test(async {
|
||||
let r = run_catch(async {
|
||||
run(async {
|
||||
timeout(Duration::from_millis(1)).await;
|
||||
(0..100000).into_par_iter().for_each(|i| {
|
||||
if i == 50005 {
|
||||
panic!("test panic");
|
||||
}
|
||||
});
|
||||
})
|
||||
.await;
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(r.is_err());
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct TestRead {
|
||||
bytes_read: usize,
|
||||
|
@ -3977,6 +4091,7 @@ pub mod tests {
|
|||
} else if self.cause_panic.is_set() {
|
||||
panic!("test panic");
|
||||
} else {
|
||||
std::thread::sleep(Duration::from_millis(2));
|
||||
self.bytes_written += buf.len();
|
||||
Ok(buf.len())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue