Implemented and tested panic handling in WriteTask. Fixed bug in panic handling of `run` and `spawn`.

This commit is contained in:
Samuel Guerra 2021-07-26 21:17:47 -03:00
parent 715c29cb91
commit c8d744a18d
1 changed files with 234 additions and 119 deletions

View File

@ -219,18 +219,27 @@ pub fn spawn<F>(task: F)
where
F: Future<Output = ()> + Send + 'static,
{
type Fut = Pin<Box<dyn Future<Output = ()> + Send>>;
// A future that is its own waker that polls inside the rayon primary thread-pool.
struct RayonTask(Mutex<Pin<Box<dyn Future<Output = ()> + Send>>>);
struct RayonTask(Mutex<Option<Fut>>);
impl RayonTask {
fn poll(self: Arc<RayonTask>) {
rayon::spawn(move || {
let r = panic::catch_unwind(panic::AssertUnwindSafe(move || {
// this `Option<Fut>` dance is used to avoid a `poll` after `Ready` or panic.
let mut task = self.0.lock();
if let Some(mut t) = task.take() {
let waker = self.clone().into();
let mut cx = std::task::Context::from_waker(&waker);
let _ = self.0.lock().as_mut().poll(&mut cx);
}));
if let Err(p) = r {
log::error!("panic in `task::spawn`: {}", panic_str(&p));
let r = panic::catch_unwind(panic::AssertUnwindSafe(move || {
if t.as_mut().poll(&mut cx).is_pending() {
*task = Some(t);
}
}));
if let Err(p) = r {
log::error!("panic in `task::spawn`: {}", panic_str(&p));
}
}
})
}
@ -241,7 +250,7 @@ where
}
}
Arc::new(RayonTask(Mutex::new(Box::pin(task)))).poll()
Arc::new(RayonTask(Mutex::new(Some(Box::pin(task))))).poll()
}
/// Spawn a parallel async task that can also be `.await` for the task result.
@ -331,8 +340,10 @@ where
R: Send + 'static,
T: Future<Output = R> + Send + 'static,
{
type Fut<R> = Pin<Box<dyn Future<Output = R> + Send>>;
// A future that is its own waker that polls inside the rayon primary thread-pool.
struct RayonCatchTask<R>(Mutex<Pin<Box<dyn Future<Output = R> + Send>>>, flume::Sender<PanicResult<R>>);
struct RayonCatchTask<R>(Mutex<Option<Fut<R>>>, flume::Sender<PanicResult<R>>);
impl<R: Send + 'static> RayonCatchTask<R> {
fn poll(self: Arc<Self>) {
let sender = self.1.clone();
@ -340,20 +351,27 @@ where
return; // cancel.
}
rayon::spawn(move || {
let r = panic::catch_unwind(panic::AssertUnwindSafe(|| {
// this `Option<Fut>` dance is used to avoid a `poll` after `Ready` or panic.
let mut task = self.0.lock();
if let Some(mut t) = task.take() {
let waker = self.clone().into();
let mut cx = std::task::Context::from_waker(&waker);
self.0.lock().as_mut().poll(&mut cx)
}));
match r {
Ok(Poll::Ready(r)) => {
let _ = sender.send(Ok(r));
let r = panic::catch_unwind(panic::AssertUnwindSafe(|| t.as_mut().poll(&mut cx)));
match r {
Ok(Poll::Ready(r)) => {
drop(task);
let _ = sender.send(Ok(r));
}
Ok(Poll::Pending) => {
*task = Some(t);
}
Err(p) => {
drop(task);
let _ = sender.send(Err(p));
}
}
Err(p) => {
let _ = sender.send(Err(p));
}
_ => {}
}
})
}
@ -366,7 +384,7 @@ where
let (sender, receiver) = channel::bounded(1);
Arc::new(RayonCatchTask(Mutex::new(Box::pin(task)), sender.into())).poll();
Arc::new(RayonCatchTask(Mutex::new(Some(Box::pin(task))), sender.into())).poll();
receiver.recv().await.unwrap()
}
@ -2821,7 +2839,8 @@ where
let (w, p, r) = match r {
Ok((w, p, r)) => (w, p, r),
Err(p) => {
let _ = f_sender.send(WriteTaskFinishMsg::Panic { payload: p, receiver }).await;
drop(receiver);
let _ = f_sender.send(WriteTaskFinishMsg::Panic(p)).await;
return;
}
};
@ -2850,7 +2869,8 @@ where
let (w, r) = match r {
Ok((w, r)) => (w, r),
Err(p) => {
let _ = f_sender.send(WriteTaskFinishMsg::Panic { payload: p, receiver }).await;
drop(receiver);
let _ = f_sender.send(WriteTaskFinishMsg::Panic(p)).await;
return;
}
};
@ -2880,7 +2900,8 @@ where
let (w, r) = match r {
Ok((w, r)) => (w, r),
Err(p) => {
let _ = f_sender.send(WriteTaskFinishMsg::Panic { payload: p, receiver }).await;
drop(receiver);
let _ = f_sender.send(WriteTaskFinishMsg::Panic(p)).await;
return;
}
};
@ -2895,17 +2916,10 @@ where
}
}
let payloads = Self::drain_payloads(receiver, error_payload);
// send non-panic finish message.
let _ = f_sender
.send(WriteTaskFinishMsg::Io {
write,
result: match error {
Some(e) => Err((error_payload, e)),
None => Ok(()),
},
receiver,
})
.await;
let _ = f_sender.send(WriteTaskFinishMsg::Io { write, error, payloads }).await;
});
WriteTask {
sender,
@ -2929,9 +2943,7 @@ where
} else {
unreachable!()
}
})?;
Ok(())
})
}
/// Awaits until all previous requested [`write`] are finished.
@ -2953,25 +2965,24 @@ where
///
/// [`write`]: Self::write
pub async fn finish(self) -> Result<W, WriteTaskError<W>> {
self.sender.send(WriteTaskMsg::Finish).await.unwrap();
let _ = self.sender.send(WriteTaskMsg::Finish).await;
let msg = self.finish.recv().await.unwrap();
match msg {
WriteTaskFinishMsg::Io { write, result, receiver } => match result {
Ok(_) => Ok(write),
Err((payload, io_err)) => Err(WriteTaskError::io(
WriteTaskFinishMsg::Io { write, error, payloads } => match error {
None => Ok(write),
Some(error) => Err(WriteTaskError::Io {
write,
self.state.bytes_written(),
Self::drain_payloads(receiver, payload),
io_err,
)),
error,
bytes_written: self.state.bytes_written(),
payloads,
}),
},
WriteTaskFinishMsg::Panic { payload, receiver } => Err(WriteTaskError::panic(
self.state.bytes_written(),
Self::drain_payloads(receiver, vec![]),
payload,
)),
WriteTaskFinishMsg::Panic(panic_payload) => Err(WriteTaskError::Panic {
panic_payload,
bytes_written: self.state.bytes_written(),
}),
}
}
@ -3046,13 +3057,10 @@ enum WriteTaskMsg {
enum WriteTaskFinishMsg<W> {
Io {
write: W,
result: Result<(), (Vec<u8>, io::Error)>,
receiver: channel::Receiver<WriteTaskMsg>,
},
Panic {
payload: PanicPayload,
receiver: channel::Receiver<WriteTaskMsg>,
error: Option<io::Error>,
payloads: Vec<Vec<u8>>,
},
Panic(PanicPayload),
}
/// Error from [`WriteTask::finish`].
@ -3060,52 +3068,51 @@ enum WriteTaskFinishMsg<W> {
/// The write task worker closes on the first IO error or panic, the [`WriteTask`] send methods
/// return [`WriteTaskClosed`] when this happens and the [`WriteTask::finish`]
/// method returns this error that contains the actual error.
pub struct WriteTaskError<W> {
/// Error source.
///
/// If is an IO error also contains the [`io::Write`].
pub source: WriteTaskErrorSource<W>,
pub enum WriteTaskError<W> {
/// A write error.
Io {
/// The [`io::Write`].
write: W,
/// The error.
error: io::Error,
/// Number of bytes that where written before the error.
///
/// Note that some bytes from the last payload where probably written too, but
/// only confirmed written payloads are counted here.
pub bytes_written: u64,
/// Number of bytes that where written before the error.
///
/// Note that some bytes from the last payload where probably written too, but
/// only confirmed written payloads are counted here.
bytes_written: u64,
/// The payloads that where not written.
pub payloads: Vec<Vec<u8>>,
/// The payloads that where not written.
payloads: Vec<Vec<u8>>,
},
/// A panic in the [`io::Write`].
///
/// You can propagate the panic using [`std::panic::resume_unwind`].
Panic {
/// The panic message object.
panic_payload: PanicPayload,
/// Number of bytes that where written before the error.
///
/// Note that some bytes from the last payload where probably written too, and
/// given there was a panic some bytes could be corrupted.
bytes_written: u64,
},
}
impl<W> WriteTaskError<W> {
fn io(write: W, bytes_written: u64, payloads: Vec<Vec<u8>>, error: io::Error) -> Self {
Self {
source: WriteTaskErrorSource::Io { write, error },
bytes_written,
payloads,
}
}
fn panic(bytes_written: u64, payloads: Vec<Vec<u8>>, payload: PanicPayload) -> Self {
Self {
source: WriteTaskErrorSource::Panic(payload),
bytes_written,
payloads,
}
}
/// Returns the error of [`Io`] or panics.
///
/// # Panics
///
/// If the error is a [`Panic`] the panic is propagated using [`resume_unwind`].
///
/// [`Io`]: WriteTaskErrorSource::Io
/// [`Panic`]: WriteTaskErrorSource::Panic
/// [`Io`]: Self::Io
/// [`Panic`]: Self::Panic
/// [`resume_unwind`]: std::panic::resume_unwind
#[track_caller]
pub fn unwrap_io(self) -> io::Error {
match self.source {
WriteTaskErrorSource::Io { error, .. } => error,
WriteTaskErrorSource::Panic(p) => panic::resume_unwind(p),
match self {
Self::Io { error, .. } => error,
Self::Panic { panic_payload, .. } => panic::resume_unwind(panic_payload),
}
}
@ -3115,56 +3122,52 @@ impl<W> WriteTaskError<W> {
///
/// If the error is a [`Panic`] the panic is propagated using [`resume_unwind`].
///
/// [`Io`]: WriteTaskErrorSource::Io
/// [`Panic`]: WriteTaskErrorSource::Panic
/// [`Io`]: Self::Io
/// [`Panic`]: Self::Panic
/// [`resume_unwind`]: std::panic::resume_unwind
pub fn unwrap_write(self) -> (W, io::Error) {
match self.source {
WriteTaskErrorSource::Io { write, error } => (write, error),
WriteTaskErrorSource::Panic(p) => panic::resume_unwind(p),
match self {
Self::Io { write, error, .. } => (write, error),
Self::Panic { panic_payload, .. } => panic::resume_unwind(panic_payload),
}
}
}
impl<W: io::Write> fmt::Debug for WriteTaskError<W> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.source {
WriteTaskErrorSource::Io { error, .. } => write!(f, "Io({:?})", error),
WriteTaskErrorSource::Panic(p) => write!(f, "Panic({:?})", panic_str(p)),
match &self {
Self::Io { error, bytes_written, .. } => f
.debug_struct("Io")
.field("error", error)
.field("bytes_written", bytes_written)
.finish_non_exhaustive(),
Self::Panic {
panic_payload: p,
bytes_written,
} => f
.debug_struct("Panic")
.field("panic_payload", &panic_str(p))
.field("bytes_written", bytes_written)
.finish(),
}
}
}
impl<W: io::Write> fmt::Display for WriteTaskError<W> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.source {
WriteTaskErrorSource::Io { error, .. } => write!(f, "{}", error),
WriteTaskErrorSource::Panic(p) => write!(f, "{}", panic_str(p)),
match &self {
Self::Io { error, .. } => write!(f, "{}", error),
Self::Panic { panic_payload: p, .. } => write!(f, "{}", panic_str(p)),
}
}
}
impl<W: io::Write> std::error::Error for WriteTaskError<W> {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.source {
WriteTaskErrorSource::Io { error, .. } => Some(error),
WriteTaskErrorSource::Panic(_) => None,
match &self {
Self::Io { error, .. } => Some(error),
Self::Panic { .. } => None,
}
}
}
/// Source of an [`WriteTaskError<W>`].
pub enum WriteTaskErrorSource<W> {
/// A write error.
Io {
/// The [`io::Write`].
write: W,
/// The error.
error: io::Error,
},
/// A panic in the [`io::Write`].
///
/// You can propagate the panic using [`std::panic::resume_unwind`].
Panic(PanicPayload),
}
/// Error from [`WriteTask`].
///
/// This error is returned to indicate that the task worker has permanently stopped because
@ -3747,6 +3750,8 @@ pub mod http {
#[cfg(test)]
pub mod tests {
use rayon::prelude::*;
use crate::units::TimeUnits;
use std::sync::atomic::AtomicBool;
@ -3894,17 +3899,106 @@ pub mod tests {
#[test]
pub fn write_task() {
todo!()
async_test(async {
let write = TestWrite::default();
let task = WriteTask::default().spawn(write);
for byte in 0u8..20 {
task.write(vec![byte, byte + 100]).await.unwrap();
}
let write = task.finish().await.unwrap();
assert_eq!(20, write.write_calls);
assert_eq!(40, write.bytes_written);
assert_eq!(1, write.flush_calls);
})
}
#[test]
pub fn write_task_flush() {
async_test(async {
let write = TestWrite::default();
let task = WriteTask::default().spawn(write);
for byte in 0u8..20 {
task.write(vec![byte, byte + 100]).await.unwrap();
}
task.flush().await.unwrap();
let task_bytes_written = task.bytes_written();
let write = task.finish().await.unwrap();
assert_eq!(40, task_bytes_written);
assert_eq!(2, write.flush_calls);
assert_eq!(20, write.write_calls);
assert_eq!(40, write.bytes_written);
})
}
#[test]
pub fn write_error() {
todo!()
async_test(async {
let write = TestWrite::default();
let flag = write.cause_error.clone();
let task = WriteTask::default().spawn(write);
for byte in 0u8..20 {
if byte == 10 {
flag.set();
}
if task.write(vec![byte, byte + 100]).await.is_err() {
break;
}
}
let e = task.finish().await.unwrap_err();
if let WriteTaskError::Io { bytes_written, write, .. } = &e {
assert_eq!(write.bytes_written as u64, *bytes_written);
} else {
panic!("expected WriteTaskError::Io")
}
let (write, e) = e.unwrap_write();
assert!(write.bytes_written > 0);
assert_eq!("test error", e.to_string());
})
}
#[test]
pub fn write_panic() {
todo!()
async_test(async {
let write = TestWrite::default();
let flag = write.cause_panic.clone();
let task = WriteTask::default().spawn(write);
for byte in 0u8..20 {
if byte == 10 {
flag.set();
}
if task.write(vec![byte, byte + 100]).await.is_err() {
break;
}
}
let e = task.finish().await.unwrap_err();
if let WriteTaskError::Panic {
bytes_written,
panic_payload,
} = &e
{
assert!(*bytes_written > 0);
assert_eq!("test panic", panic_str(&panic_payload))
} else {
panic!("expected WriteTaskError::Panic")
}
})
}
#[test]
@ -3933,6 +4027,26 @@ pub mod tests {
})
}
#[test]
pub fn run_panic_handling_parallel() {
async_test(async {
let r = run_catch(async {
run(async {
timeout(Duration::from_millis(1)).await;
(0..100000).into_par_iter().for_each(|i| {
if i == 50005 {
panic!("test panic");
}
});
})
.await;
})
.await;
assert!(r.is_err());
})
}
#[derive(Default, Debug)]
pub struct TestRead {
bytes_read: usize,
@ -3977,6 +4091,7 @@ pub mod tests {
} else if self.cause_panic.is_set() {
panic!("test panic");
} else {
std::thread::sleep(Duration::from_millis(2));
self.bytes_written += buf.len();
Ok(buf.len())
}