Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose zero-copy-ish headers on ClientRequest #35

Merged
merged 18 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/autobahn_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

use futures::io::{BufReader, BufWriter};
use soketto::{BoxedError, connection, handshake};
use tokio::{net::{TcpListener, TcpStream}};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
#[tokio::main]
Expand All @@ -27,9 +27,9 @@ async fn main() -> Result<(), BoxedError> {
let mut server = new_server(socket?);
let key = {
let req = server.receive_request().await?;
req.into_key()
req.key()
};
let accept = handshake::server::Response::Accept { key: &key, protocol: None };
let accept = handshake::server::Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
let (mut sender, mut receiver) = server.into_builder().finish();
let mut message = Vec::new();
Expand Down
23 changes: 23 additions & 0 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ pub enum Error {
Io(io::Error),
/// An HTTP version =/= 1.1 was encountered.
UnsupportedHttpVersion,
/// An incomplete HTTP request.
IncompleteHttpRequest,
/// The value of the `Sec-WebSocket-Key` header is of unexpected length.
SecWebSocketKeyInvalidLength(usize),
/// The handshake request was not a GET request.
InvalidRequestMethod,
/// An HTTP header has not been present.
Expand Down Expand Up @@ -160,6 +164,10 @@ impl fmt::Display for Error {
write!(f, "i/o error: {}", e),
Error::UnsupportedHttpVersion =>
f.write_str("http version was not 1.1"),
Error::IncompleteHttpRequest =>
f.write_str("http request was incomplete"),
Error::SecWebSocketKeyInvalidLength(len) =>
write!(f, "Sec-WebSocket-Key header was {} bytes longth, expected 24", len),
Error::InvalidRequestMethod =>
f.write_str("handshake was not a GET request"),
Error::HeaderNotFound(name) =>
Expand Down Expand Up @@ -190,6 +198,8 @@ impl std::error::Error for Error {
Error::Http(e) => Some(&**e),
Error::Utf8(e) => Some(e),
Error::UnsupportedHttpVersion
| Error::IncompleteHttpRequest
| Error::SecWebSocketKeyInvalidLength(_)
| Error::InvalidRequestMethod
| Error::HeaderNotFound(_)
| Error::UnexpectedHeader(_)
Expand All @@ -213,6 +223,19 @@ impl From<str::Utf8Error> for Error {
}
}

/// Owned value of the `Sec-WebSocket-Key` header.
///
/// Per [RFC 6455](https://datatracker.ietf.org/doc/html/rfc6455#section-4.1):
///
/// ```text
/// (...) The value of this header field MUST be a
/// nonce consisting of a randomly selected 16-byte value that has
/// been base64-encoded (see Section 4 of [RFC4648]). (...)
/// ```
///
/// Base64 encoding of the nonce produces 24 ASCII bytes, padding included.
pub type WebSocketKey = [u8; 24];

#[cfg(test)]
mod tests {
use super::expect_ascii_header;
Expand Down
15 changes: 6 additions & 9 deletions src/handshake/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use futures::prelude::*;
use sha1::{Digest, Sha1};
use std::{mem, str};
use super::{
WebSocketKey,
Error,
KEY,
MAX_NUM_HEADERS,
Expand All @@ -42,9 +43,7 @@ pub struct Client<'a, T> {
/// The HTTP origin header.
origin: Option<&'a str>,
/// A buffer holding the base-64 encoded request nonce.
nonce: [u8; 32],
/// The offset into the nonce buffer.
nonce_offset: usize,
nonce: WebSocketKey,
/// The protocols to include in the handshake.
protocols: Vec<&'a str>,
/// The extensions the client wishes to include in the request.
Expand All @@ -61,8 +60,7 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
host,
resource,
origin: None,
nonce: [0; 32],
nonce_offset: 0,
nonce: [0; 24],
protocols: Vec::new(),
extensions: Vec::new(),
buffer: BytesMut::new()
Expand Down Expand Up @@ -136,15 +134,15 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
/// Encode the client handshake as a request, ready to be sent to the server.
fn encode_request(&mut self) {
let nonce: [u8; 16] = rand::random();
self.nonce_offset = base64::encode_config_slice(&nonce, base64::STANDARD, &mut self.nonce);
base64::encode_config_slice(&nonce, base64::STANDARD, &mut self.nonce);
self.buffer.extend_from_slice(b"GET ");
self.buffer.extend_from_slice(self.resource.as_bytes());
self.buffer.extend_from_slice(b" HTTP/1.1");
self.buffer.extend_from_slice(b"\r\nHost: ");
self.buffer.extend_from_slice(self.host.as_bytes());
self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade");
self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: ");
self.buffer.extend_from_slice(&self.nonce[.. self.nonce_offset]);
self.buffer.extend_from_slice(&self.nonce);
if let Some(o) = &self.origin {
self.buffer.extend_from_slice(b"\r\nOrigin: ");
self.buffer.extend_from_slice(o.as_bytes())
Expand Down Expand Up @@ -194,10 +192,9 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
expect_ascii_header(response.headers, "Upgrade", "websocket")?;
expect_ascii_header(response.headers, "Connection", "upgrade")?;

let nonce = &self.nonce[.. self.nonce_offset];
with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| {
let mut digest = Sha1::new();
digest.update(nonce);
digest.update(&self.nonce);
digest.update(KEY);
let ours = base64::encode(&digest.finalize());
if ours.as_bytes() != theirs {
Expand Down
101 changes: 69 additions & 32 deletions src/handshake/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
//!
//! [handshake]: https://tools.ietf.org/html/rfc6455#section-4

use bytes::{Buf, BytesMut};
use crate::{Parsing, extension::Extension};
use bytes::BytesMut;
use crate::extension::Extension;
use crate::connection::{self, Mode};
use futures::prelude::*;
use sha1::{Digest, Sha1};
use std::{mem, str};
use super::{
WebSocketKey,
Error,
KEY,
MAX_NUM_HEADERS,
Expand All @@ -28,6 +29,8 @@ use super::{
with_first_header
};

// Most HTTP servers default to 8KB limit on headers
const MAX_HEADERS_SIZE: usize = 8 * 1024;
const BLOCK_SIZE: usize = 8 * 1024;
const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION");

Expand Down Expand Up @@ -83,15 +86,35 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}

/// Await an incoming client handshake request.
pub async fn receive_request(&mut self) -> Result<ClientRequest<'a>, Error> {
pub async fn receive_request(&mut self) -> Result<ClientRequest<'_>, Error> {
self.buffer.clear();

let mut skip = 0;

loop {
crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?;
if let Parsing::Done { value, offset } = self.decode_request()? {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attempting decode after each read also made lifetimes convoluted since there is a mutable borrow to self.buffer made on every iteration. Opted to do a single parse after we are sure headers are complete in the new code.

self.buffer.advance(offset);
Copy link
Contributor Author

@maciejhirsz maciejhirsz Jun 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mutable borrow of self.buffer prevented us from borrowing stuff in return value. The advance was not necessary since the buffer is never read after the headers (not just that, it should not contain any bytes after the request AFAIU).

return Ok(value)

let limit = std::cmp::min(self.buffer.len(), MAX_HEADERS_SIZE);

// We don't expect body, so can search for the CRLF headers tail from
// the end of the buffer.
if self.buffer[skip..limit].windows(4).rev().any(|w| w == b"\r\n\r\n") {
break;
}

// Give up if we've reached the limit. We could emit a specific error here,
// but httparse will produce meaningful error for us regardless.
if limit == MAX_HEADERS_SIZE {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, this will also emit an error if self.buffer.len() > MAX_HEADER_SIZE?

it's truncated below by std::cmp::min a couple lines above

break;
}

// Skip bytes that did not contain CRLF in the next iteration.
// If we only read a partial CRLF sequence, we would miss it if we skipped the full buffer
// length, hence backing off the full 4 bytes.
skip = self.buffer.len().saturating_sub(4);
}

self.decode_request()
}

/// Respond to the client.
Expand All @@ -118,32 +141,41 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}

// Decode client handshake request.
fn decode_request(&mut self) -> Result<Parsing<ClientRequest<'a>>, Error> {
fn decode_request(&mut self) -> Result<ClientRequest, Error> {
let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
let mut request = httparse::Request::new(&mut header_buf);

let offset = match request.parse(self.buffer.as_ref()) {
Ok(httparse::Status::Complete(off)) => off,
Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())),
match request.parse(self.buffer.as_ref()) {
Ok(httparse::Status::Complete(_)) => (),
Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest),
Err(e) => return Err(Error::Http(Box::new(e)))
};

if request.method != Some("GET") {
return Err(Error::InvalidRequestMethod)
}
if request.version != Some(1) {
return Err(Error::UnsupportedHttpVersion)
}

// TODO: Host Validation
with_first_header(&request.headers, "Host", |_h| Ok(()))?;
let host = with_first_header(&request.headers, "Host", |h| Ok(h))?;

expect_ascii_header(request.headers, "Upgrade", "websocket")?;
expect_ascii_header(request.headers, "Connection", "upgrade")?;
expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?;

let origin = request.headers.iter().find_map(|h| {
if h.name.eq_ignore_ascii_case("Origin") {
Some(h.value)
} else {
None
}
});
let headers = RequestHeaders { host, origin };

let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| {
Ok(Vec::from(k))
use std::convert::TryFrom;

WebSocketKey::try_from(k).map_err(|_| Error::SecWebSocketKeyInvalidLength(k.len()))
})?;

for h in request.headers.iter()
Expand All @@ -161,14 +193,9 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}
}

let mut path = String::new();
if let Some(val) = request.path {
path.push_str(val)
}
let path = request.path.unwrap_or("/");

Ok(Parsing::Done {
value: ClientRequest { ws_key, protocols, path }, offset,
})
Ok(ClientRequest { ws_key, protocols, path, headers })
}

// Encode server handshake response.
Expand Down Expand Up @@ -217,18 +244,24 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
/// Handshake request received from the client.
#[derive(Debug)]
pub struct ClientRequest<'a> {
ws_key: Vec<u8>,
ws_key: WebSocketKey,
protocols: Vec<&'a str>,
path: String,
path: &'a str,
headers: RequestHeaders<'a>,
}

impl<'a> ClientRequest<'a> {
/// A reference to the nonce.
pub fn key(&self) -> &[u8] {
&self.ws_key
}
/// Select HTTP headers sent by the client.
#[derive(Debug, Copy, Clone)]
pub struct RequestHeaders<'a> {
/// The [`Host`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) header.
pub host: &'a [u8],
/// The [`Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) header, if provided.
pub origin: Option<&'a [u8]>,
}

pub fn into_key(self) -> Vec<u8> {
impl<'a> ClientRequest<'a> {
/// The `Sec-WebSocket-Key` header nonce value.
pub fn key(&self) -> WebSocketKey {
self.ws_key
}

Expand All @@ -239,7 +272,12 @@ impl<'a> ClientRequest<'a> {

/// The path the client is requesting.
pub fn path(&self) -> &str {
&self.path
self.path
}

/// Select HTTP headers sent by the client.
pub fn headers(&self) -> RequestHeaders {
self.headers
}
}

Expand All @@ -248,7 +286,7 @@ impl<'a> ClientRequest<'a> {
pub enum Response<'a> {
/// The server accepts the handshake request.
Accept {
key: &'a [u8],
key: WebSocketKey,
protocol: Option<&'a str>
},
/// The server rejects the handshake request.
Expand Down Expand Up @@ -320,4 +358,3 @@ const STATUSCODES: &[(u16, &str, &str)] = &[
(510, "510", "Not Extended"),
(511, "511", "Network Authentication Required")
];

4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@
//!
//! let websocket_key = {
//! let req = server.receive_request().await?;
//! req.into_key()
//! req.key()
//! };
//!
//! // Here we accept the client unconditionally.
//! let accept = Response::Accept { key: &websocket_key, protocol: None };
//! let accept = Response::Accept { key: websocket_key, protocol: None };
//! server.send_response(&accept).await?;
//!
//! // And we can finally transition to a websocket connection.
Expand Down