Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rustls to version 0.23 #59

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "async-tls"
version = "0.13.0"
version = "0.14.0"
authors = [
"The async-rs developers",
"Florian Gilcher <[email protected]>",
Expand All @@ -23,11 +23,11 @@ appveyor = { repository = "async-std/async-tls" }
[dependencies]
futures-io = "0.3.5"
futures-core = "0.3.5"
rustls = "0.21"
rustls-pemfile = "1.0"
rustls = "0.23.21"
rustls-pemfile = "2.2"
# webpki = { version = "0.22.0", optional = true }
rustls-webpki = { version = "0.101.4", optional = true }
webpki-roots = { version = "0.22.3", optional = true }
rustls-webpki = { version = "0.102", optional = true }
webpki-roots = { version = "0.26", optional = true }

[features]
default = ["client", "server"]
Expand All @@ -36,7 +36,6 @@ early-data = []
server = []

[dev-dependencies]
lazy_static = "1"
futures-executor = "0.3.5"
futures-util = { version = "0.3.5", features = ["io"] }
async-std = { version = "1.11", features = ["unstable"] }
Expand Down
10 changes: 2 additions & 8 deletions src/common/tls_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,10 @@ impl TlsState {
}

pub(crate) fn writeable(&self) -> bool {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => false,
_ => true,
}
!matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown)
}

pub(crate) fn readable(self) -> bool {
match self {
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
_ => true,
}
!matches!(self, TlsState::ReadShutdown | TlsState::FullyShutdown)
}
}
21 changes: 8 additions & 13 deletions src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use crate::common::tls_state::TlsState;
use crate::client;

use futures_io::{AsyncRead, AsyncWrite};
use rustls::{ClientConfig, ClientConnection, OwnedTrustAnchor, RootCertStore, ServerName};
use rustls::pki_types::ServerName;
use rustls::{ClientConfig, ClientConnection, RootCertStore};
use std::convert::TryFrom;
use std::future::Future;
use std::io;
Expand Down Expand Up @@ -64,16 +65,10 @@ impl From<ClientConfig> for TlsConnector {

impl Default for TlsConnector {
fn default() -> Self {
let mut root_certs = RootCertStore::empty();
root_certs.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let root_certs = RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
};
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_certs)
.with_no_client_auth();
Arc::new(config).into()
Expand Down Expand Up @@ -103,7 +98,7 @@ impl TlsConnector {
/// The function will return a `Connect` Future, representing the connecting part of a Tls
/// handshake. It will resolve when the handshake is over.
#[inline]
pub fn connect<'a, IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO>
pub fn connect<IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
Expand All @@ -112,12 +107,12 @@ impl TlsConnector {

// NOTE: Currently private, exposing ClientConnection exposes rusttls
// Early data should be exposed differently
fn connect_with<'a, IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
fn connect_with<IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(&mut ClientConnection),
{
let domain = match ServerName::try_from(domain.as_ref()) {
let domain = match ServerName::try_from(domain.as_ref().to_owned()) {
Ok(domain) => domain,
Err(_) => {
return Connect(ConnectInner::Error(Some(io::Error::new(
Expand Down
10 changes: 5 additions & 5 deletions src/rusttls/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> Stream<'a, IO> {
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
impl<T: AsyncRead + Unpin> Read for Reader<'_, '_, T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_read(self.cx, buf) {
Poll::Ready(result) => result,
Expand Down Expand Up @@ -253,7 +253,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> Stream<'a, IO> {
}
}

impl<'a, IO: AsyncRead + AsyncWrite + Unpin> WriteTls<IO> for Stream<'a, IO> {
impl<IO: AsyncRead + AsyncWrite + Unpin> WriteTls<IO> for Stream<'_, IO> {
fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize> {
// TODO writev

Expand All @@ -262,7 +262,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> WriteTls<IO> for Stream<'a, IO> {
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
impl<T: AsyncWrite + Unpin> Write for Writer<'_, '_, T> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_write(self.cx, buf) {
Poll::Ready(result) => result,
Expand All @@ -283,7 +283,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> WriteTls<IO> for Stream<'a, IO> {
}
}

impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<'a, IO> {
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<'_, IO> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
Expand Down Expand Up @@ -312,7 +312,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<'a, IO> {
}
}

impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<'a, IO> {
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<'_, IO> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let this = self.get_mut();

Expand Down
28 changes: 16 additions & 12 deletions src/rusttls/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use futures_io::{AsyncRead, AsyncWrite};
use futures_util::io::{AsyncReadExt, AsyncWriteExt};
use futures_util::task::{noop_waker_ref, Context};
use futures_util::{future, ready};
use rustls::pki_types::{PrivateKeyDer, ServerName};
use rustls::{
Certificate, ClientConfig, ClientConnection, ConnectionCommon, PrivateKey, RootCertStore,
ServerConfig, ServerConnection, ServerName,
ClientConfig, ClientConnection, ConnectionCommon, RootCertStore,
ServerConfig, ServerConnection,
};
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::convert::TryFrom;
Expand All @@ -17,7 +18,7 @@ use std::task::Poll;

struct Good<'a, D>(&'a mut ConnectionCommon<D>);

impl<'a, D> AsyncRead for Good<'a, D> {
impl<D> AsyncRead for Good<'_, D> {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
Expand All @@ -27,7 +28,7 @@ impl<'a, D> AsyncRead for Good<'a, D> {
}
}

impl<'a, D> AsyncWrite for Good<'a, D> {
impl<D> AsyncWrite for Good<'_, D> {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
Expand Down Expand Up @@ -223,24 +224,27 @@ fn make_pair() -> (ServerConnection, ClientConnection) {
const CHAIN: &str = include_str!("../../tests/end.chain");
const RSA: &str = include_str!("../../tests/end.rsa");

let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let cert = cert.into_iter().map(Certificate).collect();
let mut keys = pkcs8_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let key = PrivateKey(keys.pop().unwrap());
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.collect::<Result<Vec<_>,_>>()
.unwrap();
let mut keys = pkcs8_private_keys(&mut BufReader::new(Cursor::new(RSA)))
.collect::<Result<Vec<_>,_>>()
.unwrap();
let key = PrivateKeyDer::Pkcs8(keys.pop().unwrap());
let sconfig = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, key)
.unwrap();
let server = ServerConnection::new(Arc::new(sconfig));

let domain = ServerName::try_from("localhost").unwrap();
let mut root_store = RootCertStore::empty();
let chain = certs(&mut BufReader::new(Cursor::new(CHAIN))).unwrap();
let (added, ignored) = root_store.add_parsable_certificates(&chain);
let chain = certs(&mut BufReader::new(Cursor::new(CHAIN)))
.collect::<Result<Vec<_>,_>>()
.unwrap();
let (added, ignored) = root_store.add_parsable_certificates(chain);
assert!(added >= 1 && ignored == 0);
let cconfig = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let client = ClientConnection::new(Arc::new(cconfig), domain);
Expand Down
102 changes: 52 additions & 50 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,65 @@ use async_std::net::{TcpListener, TcpStream};
use async_std::prelude::*;
use async_std::task;
use async_tls::{TlsAcceptor, TlsConnector};
use lazy_static::lazy_static;
use rustls::{Certificate, ClientConfig, PrivateKey, RootCertStore, ServerConfig};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::{ClientConfig, RootCertStore, ServerConfig};
use rustls_pemfile::{certs, pkcs8_private_keys};
use std::io::{BufReader, Cursor};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::LazyLock;

const CERT: &str = include_str!("end.cert");
const CHAIN: &str = include_str!("end.chain");
const RSA: &str = include_str!("end.rsa");

lazy_static! {
static ref TEST_SERVER: (SocketAddr, &'static str, Vec<Vec<u8>>) = {
let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap();
let cert = cert.into_iter().map(Certificate).collect();
let chain = certs(&mut BufReader::new(Cursor::new(CHAIN))).unwrap();
let mut keys = pkcs8_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let key = PrivateKey(keys.pop().unwrap());
let sconfig = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, key)
.unwrap();
let acceptor = TlsAcceptor::from(Arc::new(sconfig));

let (send, recv) = bounded(1);

task::spawn(async move {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(&addr).await?;

send.send(listener.local_addr()?).await.unwrap();

let mut incoming = listener.incoming();
while let Some(stream) = incoming.next().await {
let acceptor = acceptor.clone();
task::spawn(async move {
use futures_util::io::AsyncReadExt;
let stream = acceptor.accept(stream?).await?;
let (mut reader, mut writer) = stream.split();
io::copy(&mut reader, &mut writer).await?;
Ok(()) as io::Result<()>
});
}

Ok(()) as io::Result<()>
});

let addr = task::block_on(async move { recv.recv().await.unwrap() });
(addr, "localhost", chain)
};
}

fn start_server() -> &'static (SocketAddr, &'static str, Vec<Vec<u8>>) {
&*TEST_SERVER
static TEST_SERVER: LazyLock<(SocketAddr, &'static str, Vec<CertificateDer<'_>>)> = LazyLock::new(|| {
let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.collect::<Result<Vec<_>,_>>()
.unwrap();
let chain = certs(&mut BufReader::new(Cursor::new(CHAIN)))
.collect::<Result<Vec<_>,_>>()
.unwrap();
let mut keys = pkcs8_private_keys(&mut BufReader::new(Cursor::new(RSA)))
.map(|res| res.map(PrivateKeyDer::Pkcs8))
.collect::<Result<Vec<_>,_>>()
.unwrap();
let key = keys.pop().unwrap();
let sconfig = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert, key)
.unwrap();
let acceptor = TlsAcceptor::from(Arc::new(sconfig));

let (send, recv) = bounded(1);

task::spawn(async move {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(&addr).await?;

send.send(listener.local_addr()?).await.unwrap();

let mut incoming = listener.incoming();
while let Some(stream) = incoming.next().await {
let acceptor = acceptor.clone();
task::spawn(async move {
use futures_util::io::AsyncReadExt;
let stream = acceptor.accept(stream?).await?;
let (mut reader, mut writer) = stream.split();
io::copy(&mut reader, &mut writer).await?;
Ok(()) as io::Result<()>
});
}

Ok(()) as io::Result<()>
});

let addr = task::block_on(async { recv.recv().await.unwrap() });
(addr, "localhost", chain)
});

fn start_server() -> &'static (SocketAddr, &'static str, Vec<CertificateDer<'static>>) {
&TEST_SERVER
}

async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>) -> io::Result<()> {
Expand All @@ -82,10 +86,9 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>)
fn pass() {
let (addr, domain, chain) = start_server();
let mut root_store = RootCertStore::empty();
let (added, ignored) = root_store.add_parsable_certificates(&chain);
let (added, ignored) = root_store.add_parsable_certificates(chain.clone());
assert!(added >= 1 && ignored == 0);
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
task::block_on(start_client(*addr, domain, Arc::new(config))).unwrap();
Expand All @@ -95,10 +98,9 @@ fn pass() {
fn fail() {
let (addr, domain, chain) = start_server();
let mut root_store = RootCertStore::empty();
let (added, ignored) = root_store.add_parsable_certificates(&chain);
let (added, ignored) = root_store.add_parsable_certificates(chain.clone());
assert!(added >= 1 && ignored == 0);
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let config = Arc::new(config);
Expand Down