Skip to content

Commit 8c2af0c

Browse files
committed
Move TlsConnector and TlsConnectorWithAlpn into client module
1 parent d3346d3 commit 8c2af0c

File tree

2 files changed

+143
-137
lines changed

2 files changed

+143
-137
lines changed

src/client.rs

Lines changed: 141 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::io::{self, BufRead as _};
21
#[cfg(unix)]
32
use std::os::unix::io::{AsRawFd, RawFd};
43
#[cfg(windows)]
@@ -7,11 +6,150 @@ use std::pin::Pin;
76
#[cfg(feature = "early-data")]
87
use std::task::Waker;
98
use std::task::{Context, Poll};
9+
use std::{
10+
io::{self, BufRead as _},
11+
sync::Arc,
12+
};
1013

11-
use rustls::ClientConnection;
14+
use rustls::{pki_types::ServerName, ClientConfig, ClientConnection};
1215
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
1316

14-
use crate::common::{IoSession, Stream, TlsState};
17+
use crate::{
18+
common::{IoSession, MidHandshake, Stream, TlsState},
19+
Connect,
20+
};
21+
22+
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
23+
#[derive(Clone)]
24+
pub struct TlsConnector {
25+
inner: Arc<ClientConfig>,
26+
#[cfg(feature = "early-data")]
27+
early_data: bool,
28+
}
29+
30+
impl TlsConnector {
31+
/// Enable 0-RTT.
32+
///
33+
/// If you want to use 0-RTT,
34+
/// You must also set `ClientConfig.enable_early_data` to `true`.
35+
#[cfg(feature = "early-data")]
36+
pub fn early_data(mut self, flag: bool) -> TlsConnector {
37+
self.early_data = flag;
38+
self
39+
}
40+
41+
#[inline]
42+
pub fn connect<IO>(&self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
43+
where
44+
IO: AsyncRead + AsyncWrite + Unpin,
45+
{
46+
self.connect_impl(domain, stream, None, |_| ())
47+
}
48+
49+
#[inline]
50+
pub fn connect_with<IO, F>(&self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
51+
where
52+
IO: AsyncRead + AsyncWrite + Unpin,
53+
F: FnOnce(&mut ClientConnection),
54+
{
55+
self.connect_impl(domain, stream, None, f)
56+
}
57+
58+
fn connect_impl<IO, F>(
59+
&self,
60+
domain: ServerName<'static>,
61+
stream: IO,
62+
alpn_protocols: Option<Vec<Vec<u8>>>,
63+
f: F,
64+
) -> Connect<IO>
65+
where
66+
IO: AsyncRead + AsyncWrite + Unpin,
67+
F: FnOnce(&mut ClientConnection),
68+
{
69+
let alpn = alpn_protocols.unwrap_or_else(|| self.inner.alpn_protocols.clone());
70+
let mut session = match ClientConnection::new_with_alpn(self.inner.clone(), domain, alpn) {
71+
Ok(session) => session,
72+
Err(error) => {
73+
return Connect(MidHandshake::Error {
74+
io: stream,
75+
// TODO(eliza): should this really return an `io::Error`?
76+
// Probably not...
77+
error: io::Error::new(io::ErrorKind::Other, error),
78+
});
79+
}
80+
};
81+
f(&mut session);
82+
83+
Connect(MidHandshake::Handshaking(TlsStream {
84+
io: stream,
85+
86+
#[cfg(not(feature = "early-data"))]
87+
state: TlsState::Stream,
88+
89+
#[cfg(feature = "early-data")]
90+
state: if self.early_data && session.early_data().is_some() {
91+
TlsState::EarlyData(0, Vec::new())
92+
} else {
93+
TlsState::Stream
94+
},
95+
96+
need_flush: false,
97+
98+
#[cfg(feature = "early-data")]
99+
early_waker: None,
100+
101+
session,
102+
}))
103+
}
104+
105+
pub fn with_alpn(&self, alpn_protocols: Vec<Vec<u8>>) -> TlsConnectorWithAlpn<'_> {
106+
TlsConnectorWithAlpn {
107+
inner: self,
108+
alpn_protocols,
109+
}
110+
}
111+
112+
/// Get a read-only reference to underlying config
113+
pub fn config(&self) -> &Arc<ClientConfig> {
114+
&self.inner
115+
}
116+
}
117+
118+
impl From<Arc<ClientConfig>> for TlsConnector {
119+
fn from(inner: Arc<ClientConfig>) -> TlsConnector {
120+
TlsConnector {
121+
inner,
122+
#[cfg(feature = "early-data")]
123+
early_data: false,
124+
}
125+
}
126+
}
127+
128+
pub struct TlsConnectorWithAlpn<'c> {
129+
inner: &'c TlsConnector,
130+
alpn_protocols: Vec<Vec<u8>>,
131+
}
132+
133+
impl TlsConnectorWithAlpn<'_> {
134+
#[inline]
135+
pub fn connect<IO>(self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
136+
where
137+
IO: AsyncRead + AsyncWrite + Unpin,
138+
{
139+
self.inner
140+
.connect_impl(domain, stream, Some(self.alpn_protocols), |_| ())
141+
}
142+
143+
#[inline]
144+
pub fn connect_with<IO, F>(self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
145+
where
146+
IO: AsyncRead + AsyncWrite + Unpin,
147+
F: FnOnce(&mut ClientConnection),
148+
{
149+
self.inner
150+
.connect_impl(domain, stream, Some(self.alpn_protocols), f)
151+
}
152+
}
15153

16154
/// A wrapper around an underlying raw stream which implements the TLS or SSL
17155
/// protocol.

src/lib.rs

Lines changed: 2 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,8 @@ use std::task::{Context, Poll};
5050

5151
pub use rustls;
5252

53-
use rustls::pki_types::ServerName;
5453
use rustls::server::AcceptedAlert;
55-
use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
54+
use rustls::{CommonState, ServerConfig, ServerConnection};
5655
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
5756

5857
macro_rules! ready {
@@ -65,148 +64,17 @@ macro_rules! ready {
6564
}
6665

6766
pub mod client;
67+
pub use client::{TlsConnector, TlsConnectorWithAlpn};
6868
mod common;
6969
use common::{MidHandshake, TlsState};
7070
pub mod server;
7171

72-
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
73-
#[derive(Clone)]
74-
pub struct TlsConnector {
75-
inner: Arc<ClientConfig>,
76-
#[cfg(feature = "early-data")]
77-
early_data: bool,
78-
}
79-
80-
impl TlsConnector {
81-
/// Enable 0-RTT.
82-
///
83-
/// If you want to use 0-RTT,
84-
/// You must also set `ClientConfig.enable_early_data` to `true`.
85-
#[cfg(feature = "early-data")]
86-
pub fn early_data(mut self, flag: bool) -> TlsConnector {
87-
self.early_data = flag;
88-
self
89-
}
90-
91-
#[inline]
92-
pub fn connect<IO>(&self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
93-
where
94-
IO: AsyncRead + AsyncWrite + Unpin,
95-
{
96-
self.connect_impl(domain, stream, None, |_| ())
97-
}
98-
99-
#[inline]
100-
pub fn connect_with<IO, F>(&self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
101-
where
102-
IO: AsyncRead + AsyncWrite + Unpin,
103-
F: FnOnce(&mut ClientConnection),
104-
{
105-
self.connect_impl(domain, stream, None, f)
106-
}
107-
108-
fn connect_impl<IO, F>(
109-
&self,
110-
domain: ServerName<'static>,
111-
stream: IO,
112-
alpn_protocols: Option<Vec<Vec<u8>>>,
113-
f: F,
114-
) -> Connect<IO>
115-
where
116-
IO: AsyncRead + AsyncWrite + Unpin,
117-
F: FnOnce(&mut ClientConnection),
118-
{
119-
let alpn = alpn_protocols.unwrap_or_else(|| self.inner.alpn_protocols.clone());
120-
let mut session = match ClientConnection::new_with_alpn(self.inner.clone(), domain, alpn) {
121-
Ok(session) => session,
122-
Err(error) => {
123-
return Connect(MidHandshake::Error {
124-
io: stream,
125-
// TODO(eliza): should this really return an `io::Error`?
126-
// Probably not...
127-
error: io::Error::new(io::ErrorKind::Other, error),
128-
});
129-
}
130-
};
131-
f(&mut session);
132-
133-
Connect(MidHandshake::Handshaking(client::TlsStream {
134-
io: stream,
135-
136-
#[cfg(not(feature = "early-data"))]
137-
state: TlsState::Stream,
138-
139-
#[cfg(feature = "early-data")]
140-
state: if self.early_data && session.early_data().is_some() {
141-
TlsState::EarlyData(0, Vec::new())
142-
} else {
143-
TlsState::Stream
144-
},
145-
146-
need_flush: false,
147-
148-
#[cfg(feature = "early-data")]
149-
early_waker: None,
150-
151-
session,
152-
}))
153-
}
154-
155-
pub fn with_alpn(&self, alpn_protocols: Vec<Vec<u8>>) -> TlsConnectorWithAlpn<'_> {
156-
TlsConnectorWithAlpn {
157-
inner: self,
158-
alpn_protocols,
159-
}
160-
}
161-
162-
/// Get a read-only reference to underlying config
163-
pub fn config(&self) -> &Arc<ClientConfig> {
164-
&self.inner
165-
}
166-
}
167-
168-
pub struct TlsConnectorWithAlpn<'c> {
169-
inner: &'c TlsConnector,
170-
alpn_protocols: Vec<Vec<u8>>,
171-
}
172-
173-
impl TlsConnectorWithAlpn<'_> {
174-
#[inline]
175-
pub fn connect<IO>(self, domain: ServerName<'static>, stream: IO) -> Connect<IO>
176-
where
177-
IO: AsyncRead + AsyncWrite + Unpin,
178-
{
179-
self.inner
180-
.connect_impl(domain, stream, Some(self.alpn_protocols), |_| ())
181-
}
182-
183-
#[inline]
184-
pub fn connect_with<IO, F>(self, domain: ServerName<'static>, stream: IO, f: F) -> Connect<IO>
185-
where
186-
IO: AsyncRead + AsyncWrite + Unpin,
187-
F: FnOnce(&mut ClientConnection),
188-
{
189-
self.inner
190-
.connect_impl(domain, stream, Some(self.alpn_protocols), f)
191-
}
192-
}
193-
19472
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
19573
#[derive(Clone)]
19674
pub struct TlsAcceptor {
19775
inner: Arc<ServerConfig>,
19876
}
19977

200-
impl From<Arc<ClientConfig>> for TlsConnector {
201-
fn from(inner: Arc<ClientConfig>) -> TlsConnector {
202-
TlsConnector {
203-
inner,
204-
#[cfg(feature = "early-data")]
205-
early_data: false,
206-
}
207-
}
208-
}
209-
21078
impl From<Arc<ServerConfig>> for TlsAcceptor {
21179
fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
21280
TlsAcceptor { inner }

0 commit comments

Comments
 (0)