From 146a593c8e3abea8bc4c1888ae6781a3f2e1422e Mon Sep 17 00:00:00 2001 From: Jeff Mendez Date: Tue, 25 Jun 2024 11:42:11 -0400 Subject: [PATCH] fix(tokio): add safe access join handles (#85) Fixes: https://github.com/zkat/cacache-rs/issues/84 --- src/async_lib.rs | 31 +++-- src/content/write.rs | 294 +++++++++++++++++++++++++------------------ 2 files changed, 192 insertions(+), 133 deletions(-) diff --git a/src/async_lib.rs b/src/async_lib.rs index 94506d7..5af05bb 100644 --- a/src/async_lib.rs +++ b/src/async_lib.rs @@ -100,8 +100,8 @@ pub fn unwrap_joinhandle_value(value: T) -> T { pub use tokio::task::JoinHandle; #[cfg(feature = "tokio")] #[inline] -pub fn unwrap_joinhandle_value(value: Result) -> T { - value.unwrap() +pub fn unwrap_joinhandle_value(value: T) -> T { + value } use tempfile::NamedTempFile; @@ -110,19 +110,28 @@ use crate::errors::IoErrorExt; #[cfg(feature = "async-std")] #[inline] -pub async fn create_named_tempfile(tmp_path: std::path::PathBuf) -> crate::Result { +pub async fn create_named_tempfile( + tmp_path: std::path::PathBuf, +) -> Option> { let cloned = tmp_path.clone(); - spawn_blocking(|| NamedTempFile::new_in(tmp_path)) - .await - .with_context(|| format!("Failed to create a temp file at {}", cloned.display())) + + Some( + spawn_blocking(|| NamedTempFile::new_in(tmp_path)) + .await + .with_context(|| format!("Failed to create a temp file at {}", cloned.display())), + ) } #[cfg(feature = "tokio")] #[inline] -pub async fn create_named_tempfile(tmp_path: std::path::PathBuf) -> crate::Result { +pub async fn create_named_tempfile( + tmp_path: std::path::PathBuf, +) -> Option> { let cloned = tmp_path.clone(); - spawn_blocking(|| NamedTempFile::new_in(tmp_path)) - .await - .unwrap() - .with_context(|| format!("Failed to create a temp file at {}", cloned.display())) + match spawn_blocking(|| NamedTempFile::new_in(tmp_path)).await { + Ok(ctx) => Some( + ctx.with_context(|| format!("Failed to create a temp file at {}", cloned.display())), + ), + _ => None, + } } diff --git a/src/content/write.rs b/src/content/write.rs index 3241c13..8b7961e 100644 --- a/src/content/write.rs +++ b/src/content/write.rs @@ -19,6 +19,7 @@ use tempfile::NamedTempFile; use crate::async_lib::{AsyncWrite, JoinHandle}; use crate::content::path; use crate::errors::{IoErrorExt, Result}; +use crate::Error; #[cfg(feature = "mmap")] pub const MAX_MMAP_SIZE: usize = 1024 * 1024; @@ -171,16 +172,25 @@ impl AsyncWriter { tmp_path.display() ) })?; - let mut tmpfile = crate::async_lib::create_named_tempfile(tmp_path).await?; - let mmap = make_mmap(&mut tmpfile, size)?; - Ok(AsyncWriter(Mutex::new(State::Idle(Some(Inner { - cache: cache_path, - builder: IntegrityOpts::new().algorithm(algo), - mmap, - tmpfile, - buf: vec![], - last_op: None, - }))))) + + match crate::async_lib::create_named_tempfile(tmp_path).await { + Some(tmpfile) => { + let mut tmpfile = tmpfile?; + let mmap = make_mmap(&mut tmpfile, size)?; + Ok(AsyncWriter(Mutex::new(State::Idle(Some(Inner { + cache: cache_path, + builder: IntegrityOpts::new().algorithm(algo), + mmap, + tmpfile, + buf: vec![], + last_op: None, + }))))) + } + _ => Err(Error::IoError( + std::io::Error::new(std::io::ErrorKind::Other, "temp file create error"), + "Possible memory issues for file handle".into(), + )), + } } pub async fn close(self) -> Result { @@ -247,9 +257,11 @@ impl AsyncWriter { }, // Poll the asynchronous operation the file is currently blocked on. State::Busy(task) => { - *state = crate::async_lib::unwrap_joinhandle_value(futures::ready!( - Pin::new(task).poll(cx) - )) + let next_state = crate::async_lib::unwrap_joinhandle_value( + futures::ready!(Pin::new(task).poll(cx)), + ); + + update_state(state, next_state); } } } @@ -270,108 +282,119 @@ impl AsyncWrite for AsyncWriter { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let state = &mut *self.0.lock().unwrap(); - - loop { - match state { - State::Idle(opt) => { - // Grab a reference to the inner representation of the file or return an error - // if the file is closed. - let inner = opt - .as_mut() - .ok_or_else(|| crate::errors::io_error("file closed"))?; - - // Check if the operation has completed. - if let Some(Operation::Write(res)) = inner.last_op.take() { - let n = res?; - - // If more data was written than is available in the buffer, let's retry - // the write operation. - if n <= buf.len() { - return Poll::Ready(Ok(n)); - } - } else { - let mut inner = opt.take().unwrap(); + match self.0.lock() { + Ok(mut state) => { + let state = &mut *state; + + loop { + match state { + State::Idle(opt) => { + // Grab a reference to the inner representation of the file or return an error + // if the file is closed. + let inner = opt + .as_mut() + .ok_or_else(|| crate::errors::io_error("file closed"))?; + + // Check if the operation has completed. + if let Some(Operation::Write(res)) = inner.last_op.take() { + let n = res?; + + // If more data was written than is available in the buffer, let's retry + // the write operation. + if n <= buf.len() { + return Poll::Ready(Ok(n)); + } + } else { + let mut inner = opt.take().unwrap(); - // Set the length of the inner buffer to the length of the provided buffer. - if inner.buf.len() < buf.len() { - inner.buf.reserve(buf.len() - inner.buf.len()); - } - unsafe { - inner.buf.set_len(buf.len()); - } + // Set the length of the inner buffer to the length of the provided buffer. + if inner.buf.len() < buf.len() { + inner.buf.reserve(buf.len() - inner.buf.len()); + } + unsafe { + inner.buf.set_len(buf.len()); + } - // Copy the data to write into the inner buffer. - inner.buf[..buf.len()].copy_from_slice(buf); + // Copy the data to write into the inner buffer. + inner.buf[..buf.len()].copy_from_slice(buf); - // Start the operation asynchronously. - *state = State::Busy(crate::async_lib::spawn_blocking(|| { - inner.builder.input(&inner.buf); - if let Some(mmap) = &mut inner.mmap { - mmap.copy_from_slice(&inner.buf); - inner.last_op = Some(Operation::Write(Ok(inner.buf.len()))); - State::Idle(Some(inner)) - } else { - let res = inner.tmpfile.write(&inner.buf); - inner.last_op = Some(Operation::Write(res)); - State::Idle(Some(inner)) + // Start the operation asynchronously. + *state = State::Busy(crate::async_lib::spawn_blocking(|| { + inner.builder.input(&inner.buf); + if let Some(mmap) = &mut inner.mmap { + mmap.copy_from_slice(&inner.buf); + inner.last_op = Some(Operation::Write(Ok(inner.buf.len()))); + State::Idle(Some(inner)) + } else { + let res = inner.tmpfile.write(&inner.buf); + inner.last_op = Some(Operation::Write(res)); + State::Idle(Some(inner)) + } + })); } - })); + } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => { + let next_state = crate::async_lib::unwrap_joinhandle_value( + futures::ready!(Pin::new(task).poll(cx)), + ); + + update_state(state, next_state); + } } } - // Poll the asynchronous operation the file is currently blocked on. - State::Busy(task) => { - *state = crate::async_lib::unwrap_joinhandle_value(futures::ready!(Pin::new( - task - ) - .poll(cx))) - } } + _ => Poll::Pending, } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let state = &mut *self.0.lock().unwrap(); - - loop { - match state { - State::Idle(opt) => { - // Grab a reference to the inner representation of the file or return if the - // file is closed. - let inner = match opt.as_mut() { - None => return Poll::Ready(Ok(())), - Some(s) => s, - }; - - // Check if the operation has completed. - if let Some(Operation::Flush(res)) = inner.last_op.take() { - return Poll::Ready(res); - } else { - let mut inner = opt.take().unwrap(); - - if let Some(mmap) = &inner.mmap { - match mmap.flush_async() { - Ok(_) => (), - Err(e) => return Poll::Ready(Err(e)), + match self.0.lock() { + Ok(mut state) => { + let state = &mut *state; + loop { + match state { + State::Idle(opt) => { + // Grab a reference to the inner representation of the file or return if the + // file is closed. + let inner = match opt.as_mut() { + None => return Poll::Ready(Ok(())), + Some(s) => s, }; + + // Check if the operation has completed. + if let Some(Operation::Flush(res)) = inner.last_op.take() { + return Poll::Ready(res); + } else { + let mut inner = opt.take().unwrap(); + + if let Some(mmap) = &inner.mmap { + match mmap.flush_async() { + Ok(_) => (), + Err(e) => return Poll::Ready(Err(e)), + }; + } + + // Start the operation asynchronously. + *state = State::Busy(crate::async_lib::spawn_blocking(|| { + let res = inner.tmpfile.flush(); + inner.last_op = Some(Operation::Flush(res)); + State::Idle(Some(inner)) + })); + } } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => { + let next_state = crate::async_lib::unwrap_joinhandle_value( + futures::ready!(Pin::new(task).poll(cx)), + ); - // Start the operation asynchronously. - *state = State::Busy(crate::async_lib::spawn_blocking(|| { - let res = inner.tmpfile.flush(); - inner.last_op = Some(Operation::Flush(res)); - State::Idle(Some(inner)) - })); + update_state(state, next_state); + } } } - // Poll the asynchronous operation the file is currently blocked on. - State::Busy(task) => { - *state = crate::async_lib::unwrap_joinhandle_value(futures::ready!(Pin::new( - task - ) - .poll(cx))) - } } + _ => Poll::Pending, } } @@ -386,6 +409,28 @@ impl AsyncWrite for AsyncWriter { } } +#[cfg(feature = "tokio")] +/// Update the state. +fn update_state( + current_state: &mut State, + next_state: std::result::Result, +) { + match next_state { + Ok(next) => { + *current_state = next; + } + _ => { + *current_state = State::Idle(None); + } + } +} + +#[cfg(not(feature = "tokio"))] +/// Update the state. +fn update_state(current_state: &mut State, next_state: State) { + *current_state = next_state; +} + #[cfg(any(feature = "async-std", feature = "tokio"))] impl AsyncWriter { #[inline] @@ -393,32 +438,37 @@ impl AsyncWriter { self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - let state = &mut *self.0.lock().unwrap(); - - loop { - match state { - State::Idle(opt) => { - // Grab a reference to the inner representation of the file or return if the - // file is closed. - let inner = match opt.take() { - None => return Poll::Ready(Ok(())), - Some(s) => s, - }; - - // Start the operation asynchronously. - *state = State::Busy(crate::async_lib::spawn_blocking(|| { - drop(inner); - State::Idle(None) - })); - } - // Poll the asynchronous operation the file is currently blocked on. - State::Busy(task) => { - *state = crate::async_lib::unwrap_joinhandle_value(futures::ready!(Pin::new( - task - ) - .poll(cx))) + match self.0.lock() { + Ok(mut state) => { + let state = &mut *state; + loop { + match state { + State::Idle(opt) => { + // Grab a reference to the inner representation of the file or return if the + // file is closed. + let inner = match opt.take() { + None => return Poll::Ready(Ok(())), + Some(s) => s, + }; + + // Start the operation asynchronously. + *state = State::Busy(crate::async_lib::spawn_blocking(|| { + drop(inner); + State::Idle(None) + })); + } + // Poll the asynchronous operation the file is currently blocked on. + State::Busy(task) => { + let next_state = crate::async_lib::unwrap_joinhandle_value( + futures::ready!(Pin::new(task).poll(cx)), + ); + + update_state(state, next_state); + } + } } } + _ => Poll::Pending, } } }