Skip to content

Commit

Permalink
Remove unsafe methods and clean up MaybeOwned enums
Browse files Browse the repository at this point in the history
  • Loading branch information
DouglasDwyer committed Jun 18, 2024
1 parent 90a526d commit 5b3ce01
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 78 deletions.
158 changes: 84 additions & 74 deletions src/stream/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
//!
//! They are mostly thin wrappers around `zstd_safe::{DCtx, CCtx}`.
use std::io;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer, WriteBuf};

Expand Down Expand Up @@ -134,7 +132,7 @@ pub struct Status {

/// An in-memory decoder for streams of data.
pub struct Decoder<'a> {
context: MaybeOwned<'a, zstd_safe::DCtx<'a>>,
context: MaybeOwnedDCtx<'a>,
}

impl Decoder<'static> {
Expand All @@ -150,14 +148,18 @@ impl Decoder<'static> {
context
.load_dictionary(dictionary)
.map_err(map_error_code)?;
Ok(Decoder { context: MaybeOwned::owned(context) })
Ok(Decoder {
context: MaybeOwnedDCtx::Owned(context),
})
}
}

impl<'a> Decoder<'a> {
/// Creates a new decoder which employs the provided context for deserialization.
pub fn with_context<'b: 'a>(context: &'a mut zstd_safe::DCtx<'b>) -> Self {
Self { context: MaybeOwned::borrowed(context) }
pub fn with_context(context: &'a mut zstd_safe::DCtx<'static>) -> Self {
Self {
context: MaybeOwnedDCtx::Borrowed(context),
}
}

/// Creates a new decoder, using an existing `DecoderDictionary`.
Expand All @@ -171,14 +173,18 @@ impl<'a> Decoder<'a> {
context
.ref_ddict(dictionary.as_ddict())
.map_err(map_error_code)?;
Ok(Decoder { context: MaybeOwned::owned(context) })
Ok(Decoder {
context: MaybeOwnedDCtx::Owned(context),
})
}

/// Sets a decompression parameter for this decoder.
pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
self.context
.set_parameter(parameter)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => x.set_parameter(parameter),
MaybeOwnedDCtx::Borrowed(x) => x.set_parameter(parameter),
}
.map_err(map_error_code)?;
Ok(())
}
}
Expand All @@ -189,9 +195,11 @@ impl Operation for Decoder<'_> {
input: &mut InBuffer<'_>,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context
.decompress_stream(output, input)
.map_err(map_error_code)
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => x.decompress_stream(output, input),
MaybeOwnedDCtx::Borrowed(x) => x.decompress_stream(output, input),
}
.map_err(map_error_code)
}

fn flush<C: WriteBuf + ?Sized>(
Expand All @@ -212,9 +220,15 @@ impl Operation for Decoder<'_> {
}

fn reinit(&mut self) -> io::Result<()> {
self.context
.reset(zstd_safe::ResetDirective::SessionOnly)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
MaybeOwnedDCtx::Borrowed(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
}
.map_err(map_error_code)?;
Ok(())
}

Expand All @@ -236,7 +250,7 @@ impl Operation for Decoder<'_> {

/// An in-memory encoder for streams of data.
pub struct Encoder<'a> {
context: MaybeOwned<'a, zstd_safe::CCtx<'a>>,
context: MaybeOwnedCCtx<'a>,
}

impl Encoder<'static> {
Expand All @@ -257,14 +271,18 @@ impl Encoder<'static> {
.load_dictionary(dictionary)
.map_err(map_error_code)?;

Ok(Encoder { context: MaybeOwned::owned(context) })
Ok(Encoder {
context: MaybeOwnedCCtx::Owned(context),
})
}
}

impl<'a> Encoder<'a> {
/// Creates a new encoder that uses the provided context for serialization.
pub fn with_context<'b: 'a>(context: &'a mut zstd_safe::CCtx<'b>) -> Self {
Self { context: MaybeOwned::borrowed(context) }
pub fn with_context(context: &'a mut zstd_safe::CCtx<'static>) -> Self {
Self {
context: MaybeOwnedCCtx::Borrowed(context),
}
}

/// Creates a new encoder using an existing `EncoderDictionary`.
Expand All @@ -278,14 +296,18 @@ impl<'a> Encoder<'a> {
context
.ref_cdict(dictionary.as_cdict())
.map_err(map_error_code)?;
Ok(Encoder { context: MaybeOwned::owned(context) })
Ok(Encoder {
context: MaybeOwnedCCtx::Owned(context),
})
}

/// Sets a compression parameter for this encoder.
pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
self.context
.set_parameter(parameter)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.set_parameter(parameter),
MaybeOwnedCCtx::Borrowed(x) => x.set_parameter(parameter),
}
.map_err(map_error_code)?;
Ok(())
}

Expand All @@ -301,9 +323,15 @@ impl<'a> Encoder<'a> {
&mut self,
pledged_src_size: Option<u64>,
) -> io::Result<()> {
self.context
.set_pledged_src_size(pledged_src_size)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => {
x.set_pledged_src_size(pledged_src_size)
}
MaybeOwnedCCtx::Borrowed(x) => {
x.set_pledged_src_size(pledged_src_size)
}
}
.map_err(map_error_code)?;
Ok(())
}
}
Expand All @@ -314,77 +342,59 @@ impl<'a> Operation for Encoder<'a> {
input: &mut InBuffer<'_>,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context
.compress_stream(output, input)
.map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.compress_stream(output, input),
MaybeOwnedCCtx::Borrowed(x) => x.compress_stream(output, input),
}
.map_err(map_error_code)
}

fn flush<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context.flush_stream(output).map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.flush_stream(output),
MaybeOwnedCCtx::Borrowed(x) => x.flush_stream(output),
}
.map_err(map_error_code)
}

fn finish<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
_finished_frame: bool,
) -> io::Result<usize> {
self.context.end_stream(output).map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.end_stream(output),
MaybeOwnedCCtx::Borrowed(x) => x.end_stream(output),
}
.map_err(map_error_code)
}

fn reinit(&mut self) -> io::Result<()> {
self.context
.reset(zstd_safe::ResetDirective::SessionOnly)
.map_err(map_error_code)?;
Ok(())
}
}

struct MaybeOwned<'a, T>(MaybeOwnedInner<'a, T>);

impl<'a, T> MaybeOwned<'a, T> {
pub fn owned(value: T) -> Self {
Self(MaybeOwnedInner::Owned(value))
}

pub fn borrowed(value: &'a mut T) -> Self {
Self(MaybeOwnedInner::Borrowed((value as *mut T) as *mut _, PhantomData))
}
}

impl<'a, T> Deref for MaybeOwned<'a, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
unsafe {
match &self.0 {
MaybeOwnedInner::Owned(x) => x,
MaybeOwnedInner::Borrowed(x, _) => &*(*x as *mut _)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
}
}
}

impl<'a, T> DerefMut for MaybeOwned<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe {
match &mut self.0 {
MaybeOwnedInner::Owned(x) => x,
MaybeOwnedInner::Borrowed(x, _) => &mut *(*x as *mut _)
MaybeOwnedCCtx::Borrowed(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
}
.map_err(map_error_code)?;
Ok(())
}
}

enum MaybeOwnedInner<'a, T> {
Owned(T),
Borrowed(*mut (), PhantomData<&'a ()>)
enum MaybeOwnedCCtx<'a> {
Owned(zstd_safe::CCtx<'a>),
Borrowed(&'a mut zstd_safe::CCtx<'static>),
}

unsafe impl<'a, T: Send> Send for MaybeOwned<'a, T> {}
unsafe impl<'a, T: Sync> Sync for MaybeOwned<'a, T> {}
enum MaybeOwnedDCtx<'a> {
Owned(zstd_safe::DCtx<'a>),
Borrowed(&'a mut zstd_safe::DCtx<'static>),
}

#[cfg(test)]
mod tests {
Expand Down
12 changes: 10 additions & 2 deletions src/stream/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,16 @@ impl<R: BufRead> Decoder<'static, R> {
}
impl<'a, R: BufRead> Decoder<'a, R> {
/// Creates a new decoder which employs the provided context for deserialization.
pub fn with_context<'b: 'a>(reader: R, context: &'a mut zstd_safe::DCtx<'b>) -> Self {
Self { reader: zio::Reader::new(reader, raw::Decoder::with_context(context)) }
pub fn with_context(
reader: R,
context: &'a mut zstd_safe::DCtx<'static>,
) -> Self {
Self {
reader: zio::Reader::new(
reader,
raw::Decoder::with_context(context),
),
}
}

/// Sets this `Decoder` to stop after the first frame.
Expand Down
12 changes: 10 additions & 2 deletions src/stream/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,16 @@ impl<W: Write> Encoder<'static, W> {

impl<'a, W: Write> Encoder<'a, W> {
/// Creates an encoder that uses the provided context to compress a stream.
pub fn with_context<'b: 'a>(writer: W, context: &'a mut zstd_safe::CCtx<'b>) -> Self {
Self { writer: zio::Writer::new(writer, raw::Encoder::with_context(context)) }
pub fn with_context(
writer: W,
context: &'a mut zstd_safe::CCtx<'static>,
) -> Self {
Self {
writer: zio::Writer::new(
writer,
raw::Encoder::with_context(context),
),
}
}

/// Creates a new encoder, using an existing prepared `EncoderDictionary`.
Expand Down

0 comments on commit 5b3ce01

Please sign in to comment.