Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose zero-copy-ish headers on ClientRequest #35

Merged
merged 18 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ default = []
deflate = ["flate2"]

[dependencies]
arrayvec = { default-features = false, version = "0.7.1" }
base64 = { default-features = false, features = ["alloc"], version = "0.13" }
bytes = { default-features = false, version = "1.0" }
flate2 = { default-features = false, features = ["zlib"], optional = true, version = "1.0.13" }
Expand Down
10 changes: 10 additions & 0 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ pub enum Error {
Io(io::Error),
/// An HTTP version =/= 1.1 was encountered.
UnsupportedHttpVersion,
/// An incomplete HTTP request.
IncompleteHttpRequest,
/// The value of the `Sec-WebSocket-Key` header is too long
SecWebsocketKeyTooLong,
/// The handshake request was not a GET request.
InvalidRequestMethod,
/// An HTTP header has not been present.
Expand Down Expand Up @@ -160,6 +164,10 @@ impl fmt::Display for Error {
write!(f, "i/o error: {}", e),
Error::UnsupportedHttpVersion =>
f.write_str("http version was not 1.1"),
Error::IncompleteHttpRequest =>
f.write_str("http request was incomplete"),
Error::SecWebsocketKeyTooLong =>
f.write_str("Sec-WebSocket-Key header is too long"),
Error::InvalidRequestMethod =>
f.write_str("handshake was not a GET request"),
Error::HeaderNotFound(name) =>
Expand Down Expand Up @@ -190,6 +198,8 @@ impl std::error::Error for Error {
Error::Http(e) => Some(&**e),
Error::Utf8(e) => Some(e),
Error::UnsupportedHttpVersion
| Error::IncompleteHttpRequest
| Error::SecWebsocketKeyTooLong
| Error::InvalidRequestMethod
| Error::HeaderNotFound(_)
| Error::UnexpectedHeader(_)
Expand Down
98 changes: 73 additions & 25 deletions src/handshake/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
//!
//! [handshake]: https://tools.ietf.org/html/rfc6455#section-4

use bytes::{Buf, BytesMut};
use crate::{Parsing, extension::Extension};
use arrayvec::ArrayVec;
use bytes::BytesMut;
use crate::extension::Extension;
use crate::connection::{self, Mode};
use futures::prelude::*;
use sha1::{Digest, Sha1};
Expand Down Expand Up @@ -43,6 +44,9 @@ pub struct Server<'a, T> {
buffer: BytesMut
}

/// Owned value of the `Sec-WebSocket-Key` header.
pub type WebSocketKey = ArrayVec<u8, 28>;

impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
/// Create a new server handshake.
pub fn new(socket: T) -> Self {
Expand Down Expand Up @@ -83,15 +87,25 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}

/// Await an incoming client handshake request.
pub async fn receive_request(&mut self) -> Result<ClientRequest<'a>, Error> {
pub async fn receive_request<'r>(&'r mut self) -> Result<ClientRequest<'r>, Error> {
self.buffer.clear();

let mut skip = 0;

loop {
crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?;
if let Parsing::Done { value, offset } = self.decode_request()? {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Attempting decode after each read also made lifetimes convoluted since there is a mutable borrow to self.buffer made on every iteration. Opted to do a single parse after we are sure headers are complete in the new code.

self.buffer.advance(offset);
Copy link
Contributor Author

@maciejhirsz maciejhirsz Jun 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mutable borrow of self.buffer prevented us from borrowing stuff in return value. The advance was not necessary since the buffer is never read after the headers (not just that, it should not contain any bytes after the request AFAIU).

return Ok(value)

// We don't expect body, so can search for the CRLF headers tail from
// the end of the buffer.
if self.buffer[skip..].windows(4).rev().find(|w| w == b"\r\n\r\n").is_some() {
break;
}

// Skip bytes that did not contain CRLF in the next iteration
skip = self.buffer.len().saturating_sub(4);
}

self.decode_request()
}

/// Respond to the client.
Expand All @@ -118,32 +132,44 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}

// Decode client handshake request.
fn decode_request(&mut self) -> Result<Parsing<ClientRequest<'a>>, Error> {
fn decode_request(&mut self) -> Result<ClientRequest, Error> {
let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
let mut request = httparse::Request::new(&mut header_buf);

let offset = match request.parse(self.buffer.as_ref()) {
Ok(httparse::Status::Complete(off)) => off,
Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())),
match request.parse(self.buffer.as_ref()) {
Ok(httparse::Status::Complete(_)) => (),
Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest),
Err(e) => return Err(Error::Http(Box::new(e)))
};

if request.method != Some("GET") {
return Err(Error::InvalidRequestMethod)
}
if request.version != Some(1) {
return Err(Error::UnsupportedHttpVersion)
}

// TODO: Host Validation
with_first_header(&request.headers, "Host", |_h| Ok(()))?;
let host = with_first_header(&request.headers, "Host", |h| Ok(header_to_str(h)))?;

expect_ascii_header(request.headers, "Upgrade", "websocket")?;
expect_ascii_header(request.headers, "Connection", "upgrade")?;
expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?;

let origin = request.headers.iter().find_map(|h| {
if h.name.eq_ignore_ascii_case("Origin") {
Some(header_to_str(h.value))
} else {
None
}
});
let headers = RequestHeaders { host, origin };

let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| {
Ok(Vec::from(k))
let mut key = ArrayVec::new();

match key.try_extend_from_slice(k) {
Ok(()) => Ok(key),
Err(_) => Err(Error::SecWebsocketKeyTooLong)
}
})?;

for h in request.headers.iter()
Expand All @@ -161,14 +187,9 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}
}

let mut path = String::new();
if let Some(val) = request.path {
path.push_str(val)
}
let path = request.path.unwrap_or("/");

Ok(Parsing::Done {
value: ClientRequest { ws_key, protocols, path }, offset,
})
Ok(ClientRequest { ws_key, protocols, path, headers })
}

// Encode server handshake response.
Expand Down Expand Up @@ -214,12 +235,26 @@ impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
}
}

fn header_to_str(bytes: &[u8]) -> &str {
str::from_utf8(bytes).unwrap_or("INVALID_UTF8")
}

/// Handshake request received from the client.
#[derive(Debug)]
pub struct ClientRequest<'a> {
ws_key: Vec<u8>,
ws_key: WebSocketKey,
protocols: Vec<&'a str>,
path: String,
path: &'a str,
headers: RequestHeaders<'a>,
}

/// Select HTTP headers sent by the client.
#[derive(Debug)]
pub struct RequestHeaders<'a> {
/// The [`Host`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) header.
pub host: &'a str,
/// The [`Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) header, if provided.
pub origin: Option<&'a str>,
}

impl<'a> ClientRequest<'a> {
Expand All @@ -228,7 +263,7 @@ impl<'a> ClientRequest<'a> {
&self.ws_key
}

pub fn into_key(self) -> Vec<u8> {
pub fn into_key(self) -> WebSocketKey {
self.ws_key
}

Expand All @@ -239,7 +274,11 @@ impl<'a> ClientRequest<'a> {

/// The path the client is requesting.
pub fn path(&self) -> &str {
&self.path
self.path
}

pub fn headers(&self) -> &RequestHeaders {
&self.headers
}
}

Expand Down Expand Up @@ -321,3 +360,12 @@ const STATUSCODES: &[(u16, &str, &str)] = &[
(511, "511", "Network Authentication Required")
];

#[cfg(test)]
mod tests {
use super::WebSocketKey;

#[test]
fn ws_key_stack_size() {
assert_eq!(32, std::mem::size_of::<WebSocketKey>());
}
}