Skip to content

Commit

Permalink
Expose zero-copy-ish headers on ClientRequest (#35)
Browse files Browse the repository at this point in the history
* Expose zero-copy-ish headers on ClientRequest

* Minor cleanup

* Cleaner way to check for CRLF header tail

* Keep exposed headers as &[u8]

* Make that Sec-WebSocket-Key `ArrayVec` is `WebSocketKey`

* More pedantic `Sec-WebSocket-Key` handling

* Comment on backtracking 4 bytes

Co-authored-by: David <[email protected]>

* Limit headers to 8kb

* Clarify ClientRequest::key() comment

* Apply suggestions from code review

Return copied `RequestHeaders`.

Co-authored-by: Pierre Krieger <[email protected]>

* Use `try_from` instead of manually checking length

* Document ClientRequest::headers

* Unnecessary lifetime

* Pedantic camel case

* Fix doc comment

* Fix cargo doc warning

* .find().is_some() -> .any()

* Fix one clippy nit

Co-authored-by: David <[email protected]>
Co-authored-by: Pierre Krieger <[email protected]>
  • Loading branch information
3 people authored Jun 11, 2021
1 parent 38f5254 commit b18992e
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 46 deletions.
6 changes: 3 additions & 3 deletions examples/autobahn_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

use futures::io::{BufReader, BufWriter};
use soketto::{BoxedError, connection, handshake};
use tokio::{net::{TcpListener, TcpStream}};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
#[tokio::main]
Expand All @@ -27,9 +27,9 @@ async fn main() -> Result<(), BoxedError> {
let mut server = new_server(socket?);
let key = {
let req = server.receive_request().await?;
req.into_key()
req.key()
};
let accept = handshake::server::Response::Accept { key: &key, protocol: None };
let accept = handshake::server::Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
let (mut sender, mut receiver) = server.into_builder().finish();
let mut message = Vec::new();
Expand Down
23 changes: 23 additions & 0 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ pub enum Error {
Io(io::Error),
/// An HTTP version =/= 1.1 was encountered.
UnsupportedHttpVersion,
/// An incomplete HTTP request.
IncompleteHttpRequest,
/// The value of the `Sec-WebSocket-Key` header is of unexpected length.
SecWebSocketKeyInvalidLength(usize),
/// The handshake request was not a GET request.
InvalidRequestMethod,
/// An HTTP header has not been present.
Expand Down Expand Up @@ -160,6 +164,10 @@ impl fmt::Display for Error {
write!(f, "i/o error: {}", e),
Error::UnsupportedHttpVersion =>
f.write_str("http version was not 1.1"),
Error::IncompleteHttpRequest =>
f.write_str("http request was incomplete"),
Error::SecWebSocketKeyInvalidLength(len) =>
write!(f, "Sec-WebSocket-Key header was {} bytes longth, expected 24", len),
Error::InvalidRequestMethod =>
f.write_str("handshake was not a GET request"),
Error::HeaderNotFound(name) =>
Expand Down Expand Up @@ -190,6 +198,8 @@ impl std::error::Error for Error {
Error::Http(e) => Some(&**e),
Error::Utf8(e) => Some(e),
Error::UnsupportedHttpVersion
| Error::IncompleteHttpRequest
| Error::SecWebSocketKeyInvalidLength(_)
| Error::InvalidRequestMethod
| Error::HeaderNotFound(_)
| Error::UnexpectedHeader(_)
Expand All @@ -213,6 +223,19 @@ impl From<str::Utf8Error> for Error {
}
}

/// Owned value of the `Sec-WebSocket-Key` header.
///
/// Per [RFC 6455](https://datatracker.ietf.org/doc/html/rfc6455#section-4.1):
///
/// ```text
/// (...) The value of this header field MUST be a
/// nonce consisting of a randomly selected 16-byte value that has
/// been base64-encoded (see Section 4 of [RFC4648]). (...)
/// ```
///
/// Base64 encoding of the nonce produces 24 ASCII bytes, padding included.
pub type WebSocketKey = [u8; 24];

#[cfg(test)]
mod tests {
use super::expect_ascii_header;
Expand Down
15 changes: 6 additions & 9 deletions src/handshake/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use futures::prelude::*;
use sha1::{Digest, Sha1};
use std::{mem, str};
use super::{
WebSocketKey,
Error,
KEY,
MAX_NUM_HEADERS,
Expand All @@ -42,9 +43,7 @@ pub struct Client<'a, T> {
/// The HTTP origin header.
origin: Option<&'a str>,
/// A buffer holding the base-64 encoded request nonce.
nonce: [u8; 32],
/// The offset into the nonce buffer.
nonce_offset: usize,
nonce: WebSocketKey,
/// The protocols to include in the handshake.
protocols: Vec<&'a str>,
/// The extensions the client wishes to include in the request.
Expand All @@ -61,8 +60,7 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
host,
resource,
origin: None,
nonce: [0; 32],
nonce_offset: 0,
nonce: [0; 24],
protocols: Vec::new(),
extensions: Vec::new(),
buffer: BytesMut::new()
Expand Down Expand Up @@ -136,15 +134,15 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
/// Encode the client handshake as a request, ready to be sent to the server.
fn encode_request(&mut self) {
let nonce: [u8; 16] = rand::random();
self.nonce_offset = base64::encode_config_slice(&nonce, base64::STANDARD, &mut self.nonce);
base64::encode_config_slice(&nonce, base64::STANDARD, &mut self.nonce);
self.buffer.extend_from_slice(b"GET ");
self.buffer.extend_from_slice(self.resource.as_bytes());
self.buffer.extend_from_slice(b" HTTP/1.1");
self.buffer.extend_from_slice(b"\r\nHost: ");
self.buffer.extend_from_slice(self.host.as_bytes());
self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade");
self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: ");
self.buffer.extend_from_slice(&self.nonce[.. self.nonce_offset]);
self.buffer.extend_from_slice(&self.nonce);
if let Some(o) = &self.origin {
self.buffer.extend_from_slice(b"\r\nOrigin: ");
self.buffer.extend_from_slice(o.as_bytes())
Expand Down Expand Up @@ -194,10 +192,9 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
expect_ascii_header(response.headers, "Upgrade", "websocket")?;
expect_ascii_header(response.headers, "Connection", "upgrade")?;

let nonce = &self.nonce[.. self.nonce_offset];
with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| {
let mut digest = Sha1::new();
digest.update(nonce);
digest.update(&self.nonce);
digest.update(KEY);
let ours = base64::encode(&digest.finalize());
if ours.as_bytes() != theirs {
Expand Down
101 changes: 69 additions & 32 deletions src/handshake/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
//!
//! [handshake]: https://tools.ietf.org/html/rfc6455#section-4
use bytes::{Buf, BytesMut};
use crate::{Parsing, extension::Extension};
use bytes::BytesMut;
use crate::extension::Extension;
use crate::connection::{self, Mode};
use futures::prelude::*;
use sha1::{Digest, Sha1};
use std::{mem, str};
use super::{
WebSocketKey,
Error,
KEY,
MAX_NUM_HEADERS,
Expand All @@ -28,6 +29,8 @@ use super::{
with_first_header
};

// Most HTTP servers default to 8KB limit on headers
const MAX_HEADERS_SIZE: usize = 8 * 1024;
const BLOCK_SIZE: usize = 8 * 1024;
const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION");

Expand Down Expand Up @@ -83,15 +86,35 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}

/// Await an incoming client handshake request.
pub async fn receive_request(&mut self) -> Result<ClientRequest<'a>, Error> {
pub async fn receive_request(&mut self) -> Result<ClientRequest<'_>, Error> {
self.buffer.clear();

let mut skip = 0;

loop {
crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?;
if let Parsing::Done { value, offset } = self.decode_request()? {
self.buffer.advance(offset);
return Ok(value)

let limit = std::cmp::min(self.buffer.len(), MAX_HEADERS_SIZE);

// We don't expect body, so can search for the CRLF headers tail from
// the end of the buffer.
if self.buffer[skip..limit].windows(4).rev().any(|w| w == b"\r\n\r\n") {
break;
}

// Give up if we've reached the limit. We could emit a specific error here,
// but httparse will produce meaningful error for us regardless.
if limit == MAX_HEADERS_SIZE {
break;
}

// Skip bytes that did not contain CRLF in the next iteration.
// If we only read a partial CRLF sequence, we would miss it if we skipped the full buffer
// length, hence backing off the full 4 bytes.
skip = self.buffer.len().saturating_sub(4);
}

self.decode_request()
}

/// Respond to the client.
Expand All @@ -118,32 +141,41 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}

// Decode client handshake request.
fn decode_request(&mut self) -> Result<Parsing<ClientRequest<'a>>, Error> {
fn decode_request(&mut self) -> Result<ClientRequest, Error> {
let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
let mut request = httparse::Request::new(&mut header_buf);

let offset = match request.parse(self.buffer.as_ref()) {
Ok(httparse::Status::Complete(off)) => off,
Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())),
match request.parse(self.buffer.as_ref()) {
Ok(httparse::Status::Complete(_)) => (),
Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest),
Err(e) => return Err(Error::Http(Box::new(e)))
};

if request.method != Some("GET") {
return Err(Error::InvalidRequestMethod)
}
if request.version != Some(1) {
return Err(Error::UnsupportedHttpVersion)
}

// TODO: Host Validation
with_first_header(&request.headers, "Host", |_h| Ok(()))?;
let host = with_first_header(&request.headers, "Host", Ok)?;

expect_ascii_header(request.headers, "Upgrade", "websocket")?;
expect_ascii_header(request.headers, "Connection", "upgrade")?;
expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?;

let origin = request.headers.iter().find_map(|h| {
if h.name.eq_ignore_ascii_case("Origin") {
Some(h.value)
} else {
None
}
});
let headers = RequestHeaders { host, origin };

let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| {
Ok(Vec::from(k))
use std::convert::TryFrom;

WebSocketKey::try_from(k).map_err(|_| Error::SecWebSocketKeyInvalidLength(k.len()))
})?;

for h in request.headers.iter()
Expand All @@ -161,14 +193,9 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}
}

let mut path = String::new();
if let Some(val) = request.path {
path.push_str(val)
}
let path = request.path.unwrap_or("/");

Ok(Parsing::Done {
value: ClientRequest { ws_key, protocols, path }, offset,
})
Ok(ClientRequest { ws_key, protocols, path, headers })
}

// Encode server handshake response.
Expand Down Expand Up @@ -217,18 +244,24 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
/// Handshake request received from the client.
#[derive(Debug)]
pub struct ClientRequest<'a> {
ws_key: Vec<u8>,
ws_key: WebSocketKey,
protocols: Vec<&'a str>,
path: String,
path: &'a str,
headers: RequestHeaders<'a>,
}

impl<'a> ClientRequest<'a> {
/// A reference to the nonce.
pub fn key(&self) -> &[u8] {
&self.ws_key
}
/// Select HTTP headers sent by the client.
#[derive(Debug, Copy, Clone)]
pub struct RequestHeaders<'a> {
/// The [`Host`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) header.
pub host: &'a [u8],
/// The [`Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) header, if provided.
pub origin: Option<&'a [u8]>,
}

pub fn into_key(self) -> Vec<u8> {
impl<'a> ClientRequest<'a> {
/// The `Sec-WebSocket-Key` header nonce value.
pub fn key(&self) -> WebSocketKey {
self.ws_key
}

Expand All @@ -239,7 +272,12 @@ impl<'a> ClientRequest<'a> {

/// The path the client is requesting.
pub fn path(&self) -> &str {
&self.path
self.path
}

/// Select HTTP headers sent by the client.
pub fn headers(&self) -> RequestHeaders {
self.headers
}
}

Expand All @@ -248,7 +286,7 @@ impl<'a> ClientRequest<'a> {
pub enum Response<'a> {
/// The server accepts the handshake request.
Accept {
key: &'a [u8],
key: WebSocketKey,
protocol: Option<&'a str>
},
/// The server rejects the handshake request.
Expand Down Expand Up @@ -320,4 +358,3 @@ const STATUSCODES: &[(u16, &str, &str)] = &[
(510, "510", "Not Extended"),
(511, "511", "Network Authentication Required")
];

4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@
//!
//! let websocket_key = {
//! let req = server.receive_request().await?;
//! req.into_key()
//! req.key()
//! };
//!
//! // Here we accept the client unconditionally.
//! let accept = Response::Accept { key: &websocket_key, protocol: None };
//! let accept = Response::Accept { key: websocket_key, protocol: None };
//! server.send_response(&accept).await?;
//!
//! // And we can finally transition to a websocket connection.
Expand Down

0 comments on commit b18992e

Please sign in to comment.