Skip to content

Commit

Permalink
Support non-contiguous put payloads (apache#5514)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Apr 4, 2024
1 parent 5a0baf1 commit 29d3393
Show file tree
Hide file tree
Showing 22 changed files with 498 additions and 244 deletions.
15 changes: 0 additions & 15 deletions object_store/src/aws/checksum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
// under the License.

use crate::config::Parse;
use ring::digest::{self, digest as ring_digest};
use std::str::FromStr;

#[allow(non_camel_case_types)]
Expand All @@ -27,20 +26,6 @@ pub enum Checksum {
SHA256,
}

impl Checksum {
pub(super) fn digest(&self, bytes: &[u8]) -> Vec<u8> {
match self {
Self::SHA256 => ring_digest(&digest::SHA256, bytes).as_ref().to_owned(),
}
}

pub(super) fn header_name(&self) -> &'static str {
match self {
Self::SHA256 => "x-amz-checksum-sha256",
}
}
}

impl std::fmt::Display for Checksum {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self {
Expand Down
62 changes: 30 additions & 32 deletions object_store/src/aws/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ use crate::client::GetOptionsExt;
use crate::multipart::PartId;
use crate::path::DELIMITER;
use crate::{
ClientOptions, GetOptions, ListResult, MultipartId, Path, PutResult, Result, RetryConfig,
ClientOptions, GetOptions, ListResult, MultipartId, Path, PutPayload, PutResult, Result,
RetryConfig,
};
use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
Expand All @@ -51,11 +52,14 @@ use reqwest::{
header::{CONTENT_LENGTH, CONTENT_TYPE},
Client as ReqwestClient, Method, RequestBuilder, Response,
};
use ring::digest;
use ring::digest::Context;
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};
use std::sync::Arc;

const VERSION_HEADER: &str = "x-amz-version-id";
const SHA256_CHECKSUM: &str = "x-amz-checksum-sha256";

/// A specialized `Error` for object store-related errors
#[derive(Debug, Snafu)]
Expand Down Expand Up @@ -266,7 +270,8 @@ pub(crate) struct Request<'a> {
path: &'a Path,
config: &'a S3Config,
builder: RequestBuilder,
payload_sha256: Option<Vec<u8>>,
payload_sha256: Option<digest::Digest>,
payload: Option<PutPayload>,
use_session_creds: bool,
}

Expand Down Expand Up @@ -295,10 +300,12 @@ impl<'a> Request<'a> {
},
};

let sha = self.payload_sha256.as_ref().map(|x| x.as_ref());

let path = self.path.as_ref();
self.builder
.with_aws_sigv4(credential.authorizer(), self.payload_sha256.as_deref())
.send_retry(&self.config.retry_config)
.with_aws_sigv4(credential.authorizer(), sha)
.send_retry_payload(&self.config.retry_config, self.payload)
.await
.context(RetrySnafu { path })
}
Expand Down Expand Up @@ -327,37 +334,35 @@ impl S3Client {
pub fn put_request<'a>(
&'a self,
path: &'a Path,
bytes: Bytes,
payload: PutPayload,
with_encryption_headers: bool,
) -> Request<'a> {
let url = self.config.path_url(path);
let mut builder = self.client.request(Method::PUT, url);
if with_encryption_headers {
builder = builder.headers(self.config.encryption_headers.clone().into());
}
let mut payload_sha256 = None;

if let Some(checksum) = self.config.checksum {
let digest = checksum.digest(&bytes);
builder = builder.header(checksum.header_name(), BASE64_STANDARD.encode(&digest));
if checksum == Checksum::SHA256 {
payload_sha256 = Some(digest);
}
}
let mut sha256 = Context::new(&digest::SHA256);
payload.iter().for_each(|x| sha256.update(x));
let payload_sha256 = sha256.finish();

builder = match bytes.is_empty() {
true => builder.header(CONTENT_LENGTH, 0), // Handle empty uploads (#4514)
false => builder.body(bytes),
};
if let Some(Checksum::SHA256) = self.config.checksum {
builder = builder.header(
"x-amz-checksum-sha256",
BASE64_STANDARD.encode(payload_sha256),
)
}

if let Some(value) = self.config.client_options.get_content_type(path) {
builder = builder.header(CONTENT_TYPE, value);
}

Request {
path,
builder,
payload_sha256,
builder: builder.header(CONTENT_LENGTH, payload.content_length()),
payload: Some(payload),
payload_sha256: Some(payload_sha256),
config: &self.config,
use_session_creds: true,
}
Expand Down Expand Up @@ -439,16 +444,8 @@ impl S3Client {

let mut builder = self.client.request(Method::POST, url);

// Compute checksum - S3 *requires* this for DeleteObjects requests, so we default to
// their algorithm if the user hasn't specified one.
let checksum = self.config.checksum.unwrap_or(Checksum::SHA256);
let digest = checksum.digest(&body);
builder = builder.header(checksum.header_name(), BASE64_STANDARD.encode(&digest));
let payload_sha256 = if checksum == Checksum::SHA256 {
Some(digest)
} else {
None
};
let digest = digest::digest(&digest::SHA256, &body);
builder = builder.header(SHA256_CHECKSUM, BASE64_STANDARD.encode(digest));

// S3 *requires* DeleteObjects to include a Content-MD5 header:
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html
Expand All @@ -461,7 +458,7 @@ impl S3Client {
let response = builder
.header(CONTENT_TYPE, "application/xml")
.body(body)
.with_aws_sigv4(credential.authorizer(), payload_sha256.as_deref())
.with_aws_sigv4(credential.authorizer(), Some(digest.as_ref()))
.send_retry(&self.config.retry_config)
.await
.context(DeleteObjectsRequestSnafu {})?
Expand Down Expand Up @@ -508,6 +505,7 @@ impl S3Client {
builder,
path: from,
config: &self.config,
payload: None,
payload_sha256: None,
use_session_creds: false,
}
Expand Down Expand Up @@ -540,7 +538,7 @@ impl S3Client {
path: &Path,
upload_id: &MultipartId,
part_idx: usize,
data: Bytes,
data: PutPayload,
) -> Result<PartId> {
let part = (part_idx + 1).to_string();

Expand All @@ -564,7 +562,7 @@ impl S3Client {
// If no parts were uploaded, upload an empty part
// otherwise the completion request will fail
let part = self
.put_part(location, &upload_id.to_string(), 0, Bytes::new())
.put_part(location, &upload_id.to_string(), 0, PutPayload::default())
.await?;
vec![part]
} else {
Expand Down
21 changes: 12 additions & 9 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
//! [automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/
use async_trait::async_trait;
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt};
use reqwest::header::{HeaderName, IF_MATCH, IF_NONE_MATCH};
Expand All @@ -46,7 +45,7 @@ use crate::signer::Signer;
use crate::util::STRICT_ENCODE_SET;
use crate::{
Error, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, ObjectMeta,
ObjectStore, Path, PutMode, PutOptions, PutResult, Result, UploadPart,
ObjectStore, Path, PutMode, PutOptions, PutPayload, PutResult, Result, UploadPart,
};

static TAGS_HEADER: HeaderName = HeaderName::from_static("x-amz-tagging");
Expand Down Expand Up @@ -151,8 +150,13 @@ impl Signer for AmazonS3 {

#[async_trait]
impl ObjectStore for AmazonS3 {
async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result<PutResult> {
let mut request = self.client.put_request(location, bytes, true);
async fn put_opts(
&self,
location: &Path,
payload: PutPayload,
opts: PutOptions,
) -> Result<PutResult> {
let mut request = self.client.put_request(location, payload, true);
let tags = opts.tags.encoded();
if !tags.is_empty() && !self.client.config.disable_tagging {
request = request.header(&TAGS_HEADER, tags);
Expand Down Expand Up @@ -316,7 +320,7 @@ struct UploadState {

#[async_trait]
impl MultipartUpload for S3MultiPartUpload {
fn put_part(&mut self, data: Bytes) -> UploadPart {
fn put_part(&mut self, data: PutPayload) -> UploadPart {
let idx = self.part_idx;
self.part_idx += 1;
let state = Arc::clone(&self.state);
Expand Down Expand Up @@ -358,7 +362,7 @@ impl MultipartStore for AmazonS3 {
path: &Path,
id: &MultipartId,
part_idx: usize,
data: Bytes,
data: PutPayload,
) -> Result<PartId> {
self.client.put_part(path, id, part_idx, data).await
}
Expand All @@ -381,7 +385,6 @@ impl MultipartStore for AmazonS3 {
mod tests {
use super::*;
use crate::{client::get::GetClient, tests::*};
use bytes::Bytes;
use hyper::HeaderMap;

const NON_EXISTENT_NAME: &str = "nonexistentname";
Expand Down Expand Up @@ -470,7 +473,7 @@ mod tests {
let integration = config.build().unwrap();

let location = Path::from_iter([NON_EXISTENT_NAME]);
let data = Bytes::from("arbitrary data");
let data = PutPayload::from("arbitrary data");

let err = integration.put(&location, data).await.unwrap_err();
assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err);
Expand Down Expand Up @@ -527,7 +530,7 @@ mod tests {
async fn s3_encryption(store: &AmazonS3) {
crate::test_util::maybe_skip_integration!();

let data = Bytes::from(vec![3u8; 1024]);
let data = PutPayload::from(vec![3u8; 1024]);

let encryption_headers: HeaderMap = store.client.config.encryption_headers.clone().into();
let expected_encryption =
Expand Down
33 changes: 21 additions & 12 deletions object_store/src/azure/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ use crate::multipart::PartId;
use crate::path::DELIMITER;
use crate::util::{deserialize_rfc1123, GetRange};
use crate::{
ClientOptions, GetOptions, ListResult, ObjectMeta, Path, PutMode, PutOptions, PutResult,
Result, RetryConfig,
ClientOptions, GetOptions, ListResult, ObjectMeta, Path, PutMode, PutOptions, PutPayload,
PutResult, Result, RetryConfig,
};
use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
Expand Down Expand Up @@ -171,6 +171,7 @@ impl AzureConfig {
struct PutRequest<'a> {
path: &'a Path,
config: &'a AzureConfig,
payload: PutPayload,
builder: RequestBuilder,
}

Expand All @@ -189,8 +190,9 @@ impl<'a> PutRequest<'a> {
let credential = self.config.get_credential().await?;
let response = self
.builder
.header(CONTENT_LENGTH, self.payload.content_length())
.with_azure_authorization(&credential, &self.config.account)
.send_retry(&self.config.retry_config)
.send_retry_payload(&self.config.retry_config, Some(self.payload))
.await
.context(PutRequestSnafu {
path: self.path.as_ref(),
Expand Down Expand Up @@ -222,7 +224,7 @@ impl AzureClient {
self.config.get_credential().await
}

fn put_request<'a>(&'a self, path: &'a Path, bytes: Bytes) -> PutRequest<'a> {
fn put_request<'a>(&'a self, path: &'a Path, payload: PutPayload) -> PutRequest<'a> {
let url = self.config.path_url(path);

let mut builder = self.client.request(Method::PUT, url);
Expand All @@ -231,20 +233,22 @@ impl AzureClient {
builder = builder.header(CONTENT_TYPE, value);
}

builder = builder
.header(CONTENT_LENGTH, HeaderValue::from(bytes.len()))
.body(bytes);

PutRequest {
path,
builder,
payload,
config: &self.config,
}
}

/// Make an Azure PUT request <https://docs.microsoft.com/en-us/rest/api/storageservices/put-blob>
pub async fn put_blob(&self, path: &Path, bytes: Bytes, opts: PutOptions) -> Result<PutResult> {
let builder = self.put_request(path, bytes);
pub async fn put_blob(
&self,
path: &Path,
payload: PutPayload,
opts: PutOptions,
) -> Result<PutResult> {
let builder = self.put_request(path, payload);

let builder = match &opts.mode {
PutMode::Overwrite => builder,
Expand All @@ -265,11 +269,16 @@ impl AzureClient {
}

/// PUT a block <https://learn.microsoft.com/en-us/rest/api/storageservices/put-block>
pub async fn put_block(&self, path: &Path, part_idx: usize, data: Bytes) -> Result<PartId> {
pub async fn put_block(
&self,
path: &Path,
part_idx: usize,
payload: PutPayload,
) -> Result<PartId> {
let content_id = format!("{part_idx:20}");
let block_id = BASE64_STANDARD.encode(&content_id);

self.put_request(path, data)
self.put_request(path, payload)
.query(&[("comp", "block"), ("blockid", &block_id)])
.send()
.await?;
Expand Down
Loading

0 comments on commit 29d3393

Please sign in to comment.