Skip to content

Commit 286e1fa

Browse files
Licenserdjccpu
authored
Add TLSAcceptor and Builder (#186)
Signed-off-by: Heinz N. Gies <[email protected]> Co-authored-by: Dirkjan Ochtman <[email protected]> Co-authored-by: Daniel McCarney <[email protected]>
1 parent a0dd811 commit 286e1fa

File tree

7 files changed

+319
-142
lines changed

7 files changed

+319
-142
lines changed

Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ rustls = { version = "0.21.0", default-features = false }
1919
tokio = "1.0"
2020
tokio-rustls = { version = "0.24.0", default-features = false }
2121
webpki-roots = { version = "0.23", optional = true }
22+
futures-util = { version = "0.3" }
2223

2324
[dev-dependencies]
2425
futures-util = { version = "0.3.1", default-features = false }
@@ -28,7 +29,8 @@ rustls-pemfile = "1.0.0"
2829
tokio = { version = "1.0", features = ["io-std", "macros", "net", "rt-multi-thread"] }
2930

3031
[features]
31-
default = ["native-tokio", "http1", "tls12", "logging"]
32+
default = ["native-tokio", "http1", "tls12", "logging", "acceptor"]
33+
acceptor = ["hyper/server", "tokio-runtime"]
3234
http1 = ["hyper/http1"]
3335
http2 = ["hyper/http2"]
3436
webpki-tokio = ["tokio-runtime", "webpki-roots"]
@@ -45,7 +47,7 @@ required-features = ["native-tokio", "http1"]
4547
[[example]]
4648
name = "server"
4749
path = "examples/server.rs"
48-
required-features = ["tokio-runtime"]
50+
required-features = ["tokio-runtime", "acceptor"]
4951

5052
[package.metadata.docs.rs]
5153
all-features = true

examples/server.rs

Lines changed: 18 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,16 @@
44
//! Certificate and private key are hardcoded to sample files.
55
//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2,
66
//! otherwise HTTP/1.1 will be used.
7-
use core::task::{Context, Poll};
8-
use futures_util::ready;
9-
use hyper::server::accept::Accept;
10-
use hyper::server::conn::{AddrIncoming, AddrStream};
7+
8+
#![cfg(feature = "acceptor")]
9+
10+
use std::vec::Vec;
11+
use std::{env, fs, io};
12+
13+
use hyper::server::conn::AddrIncoming;
1114
use hyper::service::{make_service_fn, service_fn};
1215
use hyper::{Body, Method, Request, Response, Server, StatusCode};
13-
use std::future::Future;
14-
use std::pin::Pin;
15-
use std::sync::Arc;
16-
use std::vec::Vec;
17-
use std::{env, fs, io, sync};
18-
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19-
use tokio_rustls::rustls::ServerConfig;
16+
use hyper_rustls::TlsAcceptor;
2017

2118
fn main() {
2219
// Serve an echo service over HTTPS, with proper error handling.
@@ -39,139 +36,28 @@ async fn run_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
3936
};
4037
let addr = format!("127.0.0.1:{}", port).parse()?;
4138

39+
// Load public certificate.
40+
let certs = load_certs("examples/sample.pem")?;
41+
// Load private key.
42+
let key = load_private_key("examples/sample.rsa")?;
4243
// Build TLS configuration.
43-
let tls_cfg = {
44-
// Load public certificate.
45-
let certs = load_certs("examples/sample.pem")?;
46-
// Load private key.
47-
let key = load_private_key("examples/sample.rsa")?;
48-
// Do not use client certificate authentication.
49-
let mut cfg = rustls::ServerConfig::builder()
50-
.with_safe_defaults()
51-
.with_no_client_auth()
52-
.with_single_cert(certs, key)
53-
.map_err(|e| error(format!("{}", e)))?;
54-
// Configure ALPN to accept HTTP/2, HTTP/1.1, and HTTP/1.0 in that order.
55-
cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
56-
sync::Arc::new(cfg)
57-
};
5844

5945
// Create a TCP listener via tokio.
6046
let incoming = AddrIncoming::bind(&addr)?;
47+
let acceptor = TlsAcceptor::builder()
48+
.with_single_cert(certs, key)
49+
.map_err(|e| error(format!("{}", e)))?
50+
.with_all_versions_alpn()
51+
.with_incoming(incoming);
6152
let service = make_service_fn(|_| async { Ok::<_, io::Error>(service_fn(echo)) });
62-
let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming)).serve(service);
53+
let server = Server::builder(acceptor).serve(service);
6354

6455
// Run the future, keep going until an error occurs.
6556
println!("Starting to serve on https://{}.", addr);
6657
server.await?;
6758
Ok(())
6859
}
6960

70-
enum State {
71-
Handshaking(tokio_rustls::Accept<AddrStream>),
72-
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
73-
}
74-
75-
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
76-
// so we have to TlsAcceptor::accept and handshake to have access to it
77-
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
78-
pub struct TlsStream {
79-
state: State,
80-
}
81-
82-
impl TlsStream {
83-
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
84-
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
85-
TlsStream {
86-
state: State::Handshaking(accept),
87-
}
88-
}
89-
}
90-
91-
impl AsyncRead for TlsStream {
92-
fn poll_read(
93-
self: Pin<&mut Self>,
94-
cx: &mut Context,
95-
buf: &mut ReadBuf,
96-
) -> Poll<io::Result<()>> {
97-
let pin = self.get_mut();
98-
match pin.state {
99-
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
100-
Ok(mut stream) => {
101-
let result = Pin::new(&mut stream).poll_read(cx, buf);
102-
pin.state = State::Streaming(stream);
103-
result
104-
}
105-
Err(err) => Poll::Ready(Err(err)),
106-
},
107-
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
108-
}
109-
}
110-
}
111-
112-
impl AsyncWrite for TlsStream {
113-
fn poll_write(
114-
self: Pin<&mut Self>,
115-
cx: &mut Context<'_>,
116-
buf: &[u8],
117-
) -> Poll<io::Result<usize>> {
118-
let pin = self.get_mut();
119-
match pin.state {
120-
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
121-
Ok(mut stream) => {
122-
let result = Pin::new(&mut stream).poll_write(cx, buf);
123-
pin.state = State::Streaming(stream);
124-
result
125-
}
126-
Err(err) => Poll::Ready(Err(err)),
127-
},
128-
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
129-
}
130-
}
131-
132-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
133-
match self.state {
134-
State::Handshaking(_) => Poll::Ready(Ok(())),
135-
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
136-
}
137-
}
138-
139-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
140-
match self.state {
141-
State::Handshaking(_) => Poll::Ready(Ok(())),
142-
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
143-
}
144-
}
145-
}
146-
147-
pub struct TlsAcceptor {
148-
config: Arc<ServerConfig>,
149-
incoming: AddrIncoming,
150-
}
151-
152-
impl TlsAcceptor {
153-
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
154-
TlsAcceptor { config, incoming }
155-
}
156-
}
157-
158-
impl Accept for TlsAcceptor {
159-
type Conn = TlsStream;
160-
type Error = io::Error;
161-
162-
fn poll_accept(
163-
self: Pin<&mut Self>,
164-
cx: &mut Context<'_>,
165-
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
166-
let pin = self.get_mut();
167-
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
168-
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
169-
Some(Err(e)) => Poll::Ready(Some(Err(e))),
170-
None => Poll::Ready(None),
171-
}
172-
}
173-
}
174-
17561
// Custom echo service, handling two different routes and a
17662
// catch-all 404 responder.
17763
async fn echo(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {

src/acceptor.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use core::task::{Context, Poll};
2+
use std::future::Future;
3+
use std::io;
4+
use std::pin::Pin;
5+
use std::sync::Arc;
6+
7+
use futures_util::ready;
8+
use hyper::server::{
9+
accept::Accept,
10+
conn::{AddrIncoming, AddrStream},
11+
};
12+
use rustls::ServerConfig;
13+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14+
15+
mod builder;
16+
pub use builder::AcceptorBuilder;
17+
use builder::WantsTlsConfig;
18+
19+
enum State {
20+
Handshaking(tokio_rustls::Accept<AddrStream>),
21+
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
22+
}
23+
24+
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
25+
// so we have to TlsAcceptor::accept and handshake to have access to it
26+
// TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first
27+
pub struct TlsStream {
28+
state: State,
29+
}
30+
31+
impl TlsStream {
32+
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
33+
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
34+
TlsStream {
35+
state: State::Handshaking(accept),
36+
}
37+
}
38+
}
39+
40+
impl AsyncRead for TlsStream {
41+
fn poll_read(
42+
self: Pin<&mut Self>,
43+
cx: &mut Context,
44+
buf: &mut ReadBuf,
45+
) -> Poll<io::Result<()>> {
46+
let pin = self.get_mut();
47+
match pin.state {
48+
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
49+
Ok(mut stream) => {
50+
let result = Pin::new(&mut stream).poll_read(cx, buf);
51+
pin.state = State::Streaming(stream);
52+
result
53+
}
54+
Err(err) => Poll::Ready(Err(err)),
55+
},
56+
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
57+
}
58+
}
59+
}
60+
61+
impl AsyncWrite for TlsStream {
62+
fn poll_write(
63+
self: Pin<&mut Self>,
64+
cx: &mut Context<'_>,
65+
buf: &[u8],
66+
) -> Poll<io::Result<usize>> {
67+
let pin = self.get_mut();
68+
match pin.state {
69+
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
70+
Ok(mut stream) => {
71+
let result = Pin::new(&mut stream).poll_write(cx, buf);
72+
pin.state = State::Streaming(stream);
73+
result
74+
}
75+
Err(err) => Poll::Ready(Err(err)),
76+
},
77+
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
78+
}
79+
}
80+
81+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
82+
match self.state {
83+
State::Handshaking(_) => Poll::Ready(Ok(())),
84+
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
85+
}
86+
}
87+
88+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89+
match self.state {
90+
State::Handshaking(_) => Poll::Ready(Ok(())),
91+
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
92+
}
93+
}
94+
}
95+
96+
/// A TLS acceptor that can be used with hyper servers.
97+
pub struct TlsAcceptor {
98+
config: Arc<ServerConfig>,
99+
incoming: AddrIncoming,
100+
}
101+
102+
/// An Acceptor for the `https` scheme.
103+
impl TlsAcceptor {
104+
/// Provides a builder for a `TlsAcceptor`.
105+
pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
106+
AcceptorBuilder::new()
107+
}
108+
/// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
109+
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
110+
TlsAcceptor { config, incoming }
111+
}
112+
}
113+
114+
impl<C, I> From<(C, I)> for TlsAcceptor
115+
where
116+
C: Into<Arc<ServerConfig>>,
117+
I: Into<AddrIncoming>,
118+
{
119+
fn from((config, incoming): (C, I)) -> TlsAcceptor {
120+
TlsAcceptor::new(config.into(), incoming.into())
121+
}
122+
}
123+
124+
impl Accept for TlsAcceptor {
125+
type Conn = TlsStream;
126+
type Error = io::Error;
127+
128+
fn poll_accept(
129+
self: Pin<&mut Self>,
130+
cx: &mut Context<'_>,
131+
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
132+
let pin = self.get_mut();
133+
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
134+
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
135+
Some(Err(e)) => Poll::Ready(Some(Err(e))),
136+
None => Poll::Ready(None),
137+
}
138+
}
139+
}

0 commit comments

Comments
 (0)