index : static-web-server.git

ascending towards madness

author Jose Quintana <joseluisquintana20@gmail.com> 2021-05-08 13:34:43.0 +00:00:00
committer Jose Quintana <joseluisquintana20@gmail.com> 2021-05-08 13:34:43.0 +00:00:00
commit
dc74a9a14329191cbd710e83bfab571cf6d63e42 [patch]
tree
c7638735f7a7d4b486f1c6a89fb04b3fd54af030
parent
8e2b0ac5c0aee1f7cc1d05506baae15db809d550
download
dc74a9a14329191cbd710e83bfab571cf6d63e42.tar.gz

feat: http/2 + tls support



Diff

 Cargo.lock                   | 193 +++++++++++++++++++++-
 Cargo.toml                   |   3 +-
 src/lib.rs                   |   2 +-
 src/server.rs                |  97 +++++++---
 src/tls.rs                   | 414 ++++++++++++++++++++++++++++++++++++++++++++-
 src/transport.rs             |  56 ++++++-
 tests/tls/local.dev_cert.pem |  26 +++-
 tests/tls/local.dev_key.pem  |  28 +++-
 8 files changed, 794 insertions(+), 25 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 2521eff..7b28003 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -99,6 +99,12 @@ dependencies = [
]

[[package]]
name = "bumpalo"
version = "3.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63396b8a4b9de3f4fdfb320ab6080762242f66a8ef174c49d8e19b674db4cdbe"

[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -280,6 +286,31 @@ dependencies = [
]

[[package]]
name = "h2"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "825343c4eef0b63f541f8903f395dc5beb362a979b5799a84062527ef1e37726"
dependencies = [
 "bytes",
 "fnv",
 "futures-core",
 "futures-sink",
 "futures-util",
 "http",
 "indexmap",
 "slab",
 "tokio",
 "tokio-util",
 "tracing",
]

[[package]]
name = "hashbrown"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04"

[[package]]
name = "headers"
version = "0.3.4"
source = "git+https://github.com/joseluisq/hyper-headers.git?branch=headers_encoding#ca704fcb605adf33f327d0f5a41d5072606058a1"
@@ -365,6 +396,7 @@ dependencies = [
 "futures-channel",
 "futures-core",
 "futures-util",
 "h2",
 "http",
 "http-body",
 "httparse",
@@ -379,6 +411,16 @@ dependencies = [
]

[[package]]
name = "indexmap"
version = "1.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3"
dependencies = [
 "autocfg",
 "hashbrown",
]

[[package]]
name = "itertools"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -415,6 +457,15 @@ dependencies = [
]

[[package]]
name = "js-sys"
version = "0.3.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d99f9e3e84b8f67f846ef5b4cbbc3b1c29f6c759fcbce6f01aa0e73d932a24c"
dependencies = [
 "wasm-bindgen",
]

[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -667,12 +718,50 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"

[[package]]
name = "ring"
version = "0.16.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc"
dependencies = [
 "cc",
 "libc",
 "once_cell",
 "spin",
 "untrusted",
 "web-sys",
 "winapi",
]

[[package]]
name = "rustls"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7"
dependencies = [
 "base64",
 "log",
 "ring",
 "sct",
 "webpki",
]

[[package]]
name = "ryu"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"

[[package]]
name = "sct"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce"
dependencies = [
 "ring",
 "untrusted",
]

[[package]]
name = "serde"
version = "1.0.125"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -722,6 +811,12 @@ dependencies = [
]

[[package]]
name = "slab"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f173ac3d1a7e3b28003f40de0b5ce7fe2710f9b9dc3fc38664cebee46b3b6527"

[[package]]
name = "smallvec"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -738,6 +833,12 @@ dependencies = [
]

[[package]]
name = "spin"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"

[[package]]
name = "static-web-server"
version = "2.0.0-beta.3"
dependencies = [
@@ -758,6 +859,7 @@ dependencies = [
 "signal",
 "structopt",
 "tokio",
 "tokio-rustls",
 "tokio-util",
 "tracing",
 "tracing-subscriber",
@@ -855,6 +957,17 @@ dependencies = [
]

[[package]]
name = "tokio-rustls"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6"
dependencies = [
 "rustls",
 "tokio",
 "webpki",
]

[[package]]
name = "tokio-util"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -989,6 +1102,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"

[[package]]
name = "untrusted"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"

[[package]]
name = "version_check"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1017,6 +1136,80 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f"

[[package]]
name = "wasm-bindgen"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83240549659d187488f91f33c0f8547cbfef0b2088bc470c116d1d260ef623d9"
dependencies = [
 "cfg-if 1.0.0",
 "wasm-bindgen-macro",
]

[[package]]
name = "wasm-bindgen-backend"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae70622411ca953215ca6d06d3ebeb1e915f0f6613e3b495122878d7ebec7dae"
dependencies = [
 "bumpalo",
 "lazy_static",
 "log",
 "proc-macro2",
 "quote",
 "syn",
 "wasm-bindgen-shared",
]

[[package]]
name = "wasm-bindgen-macro"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e734d91443f177bfdb41969de821e15c516931c3c3db3d318fa1b68975d0f6f"
dependencies = [
 "quote",
 "wasm-bindgen-macro-support",
]

[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d53739ff08c8a68b0fdbcd54c372b8ab800b1449ab3c9d706503bc7dd1621b2c"
dependencies = [
 "proc-macro2",
 "quote",
 "syn",
 "wasm-bindgen-backend",
 "wasm-bindgen-shared",
]

[[package]]
name = "wasm-bindgen-shared"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9a543ae66aa233d14bb765ed9af4a33e81b8b58d1584cf1b47ff8cd0b9e4489"

[[package]]
name = "web-sys"
version = "0.3.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a905d57e488fec8861446d3393670fb50d27a262344013181c2cdf9fff5481be"
dependencies = [
 "js-sys",
 "wasm-bindgen",
]

[[package]]
name = "webpki"
version = "0.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea"
dependencies = [
 "ring",
 "untrusted",
]

[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index e6a1cc1..6a887d9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,7 +25,7 @@ name = "static-web-server"
path = "src/bin/server.rs"

[dependencies]
hyper = { version = "0.14", features = ["stream", "http1", "tcp", "server"] }
hyper = { version = "0.14", features = ["stream", "http1", "http2", "tcp", "server"] }
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util"], default-features = false }
futures = { version = "0.3", default-features = false }
async-compression = { version = "0.3", features = ["brotli", "deflate", "gzip", "tokio"] }
@@ -42,6 +42,7 @@ structopt = { version = "0.3", default-features = false }
num_cpus = { version = "1.13" }
once_cell = "1.7"
pin-project = "1.0"
tokio-rustls = { version = "0.22" }

[target.'cfg(not(windows))'.dependencies.nix]
version = "0.14"
diff --git a/src/lib.rs b/src/lib.rs
index f80d600..3ccd4e5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,6 +12,8 @@ pub mod logger;
pub mod server;
pub mod signals;
pub mod static_files;
pub mod tls;
pub mod transport;

#[macro_use]
pub mod error;
diff --git a/src/server.rs b/src/server.rs
index 03a70b4..25c47e9 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,14 +1,15 @@
use hyper::server::conn::AddrIncoming;
use hyper::server::Server as HyperServer;
use hyper::service::{make_service_fn, service_fn};
use std::{
    net::{IpAddr, SocketAddr},
    sync::Arc,
};
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use structopt::StructOpt;

use crate::{config::Config, error_page};
use crate::{error, helpers, logger, Result};
use crate::{handler, static_files::ArcPath};
use crate::config::Config;
use crate::static_files::ArcPath;
use crate::tls::{TlsAcceptor, TlsConfigBuilder};
use crate::Result;
use crate::{error, error_page, handler, helpers, logger};

/// Define a multi-thread HTTP or HTTP/2 web server.
pub struct Server {
@@ -75,25 +76,73 @@ impl Server {

        // TODO: CORS support

        // TODO: HTTP/2 + TLS

        // Spawn a new Tokio asynchronous server task with its given options
        tokio::task::spawn(async move {
            let span = tracing::info_span!("Server::run", ?addr, threads = ?self.threads);
            tracing::info!(parent: &span, "listening on http://{}", addr);

            let make_service = make_service_fn(move |_| {
                let root_dir = root_dir.clone();
                async move {
                    Ok::<_, error::Error>(service_fn(move |req| {
                        let root_dir = root_dir.clone();
                        async move { handler::handle_request(root_dir.as_ref(), &req).await }
                    }))
                }
        let threads = self.threads;

        if opts.http2 {
            // HTTP/2 + TLS

            let cert_path = opts.http2_tls_cert.clone();
            let key_path = opts.http2_tls_key.clone();

            tokio::task::spawn(async move {
                let make_service = make_service_fn(move |_| {
                    let root_dir = root_dir.clone();
                    async move {
                        Ok::<_, error::Error>(service_fn(move |req| {
                            let root_dir = root_dir.clone();
                            async move { handler::handle_request(root_dir.as_ref(), &req).await }
                        }))
                    }
                });

                let mut incoming = AddrIncoming::bind(&addr)?;
                incoming.set_nodelay(true);

                let tls = TlsConfigBuilder::new()
                    .cert_path(cert_path)
                    .key_path(key_path)
                    .build()
                    .unwrap();

                let server =
                    HyperServer::builder(TlsAcceptor::new(tls, incoming)).serve(make_service);

                tracing::info!(
                    parent: tracing::info_span!("Server::start_server", ?addr, ?threads),
                    "listening on https://{}",
                    addr
                );

                server.await
            });

            HyperServer::bind(&addr).serve(make_service).await
        });
        } else {
            // HTTP/1

            tokio::task::spawn(async move {
                let make_service = make_service_fn(move |_| {
                    let root_dir = root_dir.clone();
                    async move {
                        Ok::<_, error::Error>(service_fn(move |req| {
                            let root_dir = root_dir.clone();
                            async move { handler::handle_request(root_dir.as_ref(), &req).await }
                        }))
                    }
                });

                let server = HyperServer::bind(&addr)
                    .tcp_nodelay(true)
                    .serve(make_service);

                tracing::info!(
                    parent: tracing::info_span!("Server::start_server", ?addr, ?threads),
                    "listening on http://{}",
                    addr
                );

                server.await
            });
        }

        handle_signals();

diff --git a/src/tls.rs b/src/tls.rs
new file mode 100644
index 0000000..d8e7df0
--- /dev/null
+++ b/src/tls.rs
@@ -0,0 +1,414 @@
// Handles requests over TLS
// -> Most of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/tls.rs

use std::fs::File;
use std::future::Future;
use std::io::{self, BufReader, Cursor, Read};
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use futures::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};

use crate::transport::Transport;
use tokio_rustls::rustls::{
    AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, NoClientAuth,
    RootCertStore, ServerConfig, TLSError,
};

/// Represents errors that can occur building the TlsConfig
#[derive(Debug)]
pub enum TlsConfigError {
    Io(io::Error),
    /// An Error parsing the Certificate
    CertParseError,
    /// An Error parsing a Pkcs8 key
    Pkcs8ParseError,
    /// An Error parsing a Rsa key
    RsaParseError,
    /// An error from an empty key
    EmptyKey,
    /// An error from an invalid key
    InvalidKey(TLSError),
}

impl std::fmt::Display for TlsConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsConfigError::Io(err) => err.fmt(f),
            TlsConfigError::CertParseError => write!(f, "certificate parse error"),
            TlsConfigError::Pkcs8ParseError => write!(f, "pkcs8 parse error"),
            TlsConfigError::RsaParseError => write!(f, "rsa parse error"),
            TlsConfigError::EmptyKey => write!(f, "key contains no private key"),
            TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {}", err),
        }
    }
}

impl std::error::Error for TlsConfigError {}

/// Tls client authentication configuration.
pub enum TlsClientAuth {
    /// No client auth.
    Off,
    /// Allow any anonymous or authenticated client.
    Optional(Box<dyn Read + Send + Sync>),
    /// Allow any authenticated client.
    Required(Box<dyn Read + Send + Sync>),
}

/// Builder to set the configuration for the Tls server.
pub struct TlsConfigBuilder {
    cert: Box<dyn Read + Send + Sync>,
    key: Box<dyn Read + Send + Sync>,
    client_auth: TlsClientAuth,
    ocsp_resp: Vec<u8>,
}

impl std::fmt::Debug for TlsConfigBuilder {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        f.debug_struct("TlsConfigBuilder").finish()
    }
}

impl TlsConfigBuilder {
    /// Create a new TlsConfigBuilder
    pub fn new() -> TlsConfigBuilder {
        TlsConfigBuilder {
            key: Box::new(io::empty()),
            cert: Box::new(io::empty()),
            client_auth: TlsClientAuth::Off,
            ocsp_resp: Vec::new(),
        }
    }

    /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open
    pub fn key_path(mut self, path: impl AsRef<Path>) -> Self {
        self.key = Box::new(LazyFile {
            path: path.as_ref().into(),
            file: None,
        });
        self
    }

    /// sets the Tls key via bytes slice
    pub fn key(mut self, key: &[u8]) -> Self {
        self.key = Box::new(Cursor::new(Vec::from(key)));
        self
    }

    /// Specify the file path for the TLS certificate to use.
    pub fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
        self.cert = Box::new(LazyFile {
            path: path.as_ref().into(),
            file: None,
        });
        self
    }

    /// sets the Tls certificate via bytes slice
    pub fn cert(mut self, cert: &[u8]) -> Self {
        self.cert = Box::new(Cursor::new(Vec::from(cert)));
        self
    }

    /// Sets the trust anchor for optional Tls client authentication via file path.
    ///
    /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
    /// of the `client_auth_` methods, then client authentication is disabled by default.
    pub fn client_auth_optional_path(mut self, path: impl AsRef<Path>) -> Self {
        let file = Box::new(LazyFile {
            path: path.as_ref().into(),
            file: None,
        });
        self.client_auth = TlsClientAuth::Optional(file);
        self
    }

    /// Sets the trust anchor for optional Tls client authentication via bytes slice.
    ///
    /// Anonymous and authenticated clients will be accepted. If no trust anchor is provided by any
    /// of the `client_auth_` methods, then client authentication is disabled by default.
    pub fn client_auth_optional(mut self, trust_anchor: &[u8]) -> Self {
        let cursor = Box::new(Cursor::new(Vec::from(trust_anchor)));
        self.client_auth = TlsClientAuth::Optional(cursor);
        self
    }

    /// Sets the trust anchor for required Tls client authentication via file path.
    ///
    /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
    /// `client_auth_` methods, then client authentication is disabled by default.
    pub fn client_auth_required_path(mut self, path: impl AsRef<Path>) -> Self {
        let file = Box::new(LazyFile {
            path: path.as_ref().into(),
            file: None,
        });
        self.client_auth = TlsClientAuth::Required(file);
        self
    }

    /// Sets the trust anchor for required Tls client authentication via bytes slice.
    ///
    /// Only authenticated clients will be accepted. If no trust anchor is provided by any of the
    /// `client_auth_` methods, then client authentication is disabled by default.
    pub fn client_auth_required(mut self, trust_anchor: &[u8]) -> Self {
        let cursor = Box::new(Cursor::new(Vec::from(trust_anchor)));
        self.client_auth = TlsClientAuth::Required(cursor);
        self
    }

    /// sets the DER-encoded OCSP response
    pub fn ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self {
        self.ocsp_resp = Vec::from(ocsp_resp);
        self
    }

    pub fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
        let mut cert_rdr = BufReader::new(self.cert);
        let cert = tokio_rustls::rustls::internal::pemfile::certs(&mut cert_rdr)
            .map_err(|()| TlsConfigError::CertParseError)?;

        let key = {
            // convert it to Vec<u8> to allow reading it again if key is RSA
            let mut key_vec = Vec::new();
            self.key
                .read_to_end(&mut key_vec)
                .map_err(TlsConfigError::Io)?;

            if key_vec.is_empty() {
                return Err(TlsConfigError::EmptyKey);
            }

            let mut pkcs8 = tokio_rustls::rustls::internal::pemfile::pkcs8_private_keys(
                &mut key_vec.as_slice(),
            )
            .map_err(|()| TlsConfigError::Pkcs8ParseError)?;

            if !pkcs8.is_empty() {
                pkcs8.remove(0)
            } else {
                let mut rsa = tokio_rustls::rustls::internal::pemfile::rsa_private_keys(
                    &mut key_vec.as_slice(),
                )
                .map_err(|()| TlsConfigError::RsaParseError)?;

                if !rsa.is_empty() {
                    rsa.remove(0)
                } else {
                    return Err(TlsConfigError::EmptyKey);
                }
            }
        };

        fn read_trust_anchor(
            trust_anchor: Box<dyn Read + Send + Sync>,
        ) -> Result<RootCertStore, TlsConfigError> {
            let mut reader = BufReader::new(trust_anchor);
            let mut store = RootCertStore::empty();
            if let Ok((0, _)) | Err(()) = store.add_pem_file(&mut reader) {
                Err(TlsConfigError::CertParseError)
            } else {
                Ok(store)
            }
        }

        let client_auth = match self.client_auth {
            TlsClientAuth::Off => NoClientAuth::new(),
            TlsClientAuth::Optional(trust_anchor) => {
                AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?)
            }
            TlsClientAuth::Required(trust_anchor) => {
                AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?)
            }
        };

        let mut config = ServerConfig::new(client_auth);
        config
            .set_single_cert_with_ocsp_and_sct(cert, key, self.ocsp_resp, Vec::new())
            .map_err(TlsConfigError::InvalidKey)?;
        config.set_protocols(&["h2".into(), "http/1.1".into()]);
        Ok(config)
    }
}

impl Default for TlsConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}

struct LazyFile {
    path: PathBuf,
    file: Option<File>,
}

impl LazyFile {
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.file.is_none() {
            self.file = Some(File::open(&self.path)?);
        }

        self.file.as_mut().unwrap().read(buf)
    }
}

impl Read for LazyFile {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.lazy_read(buf).map_err(|err| {
            let kind = err.kind();
            io::Error::new(
                kind,
                format!("error reading file ({:?}): {}", self.path.display(), err),
            )
        })
    }
}

impl Transport for TlsStream {
    fn remote_addr(&self) -> Option<SocketAddr> {
        Some(self.remote_addr)
    }
}

enum State {
    Handshaking(tokio_rustls::Accept<AddrStream>),
    Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
pub struct TlsStream {
    state: State,
    remote_addr: SocketAddr,
}

impl TlsStream {
    fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
        let remote_addr = stream.remote_addr();
        let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
        TlsStream {
            state: State::Handshaking(accept),
            remote_addr,
        }
    }
}

impl AsyncRead for TlsStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        let pin = self.get_mut();
        match pin.state {
            State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
                Ok(mut stream) => {
                    let result = Pin::new(&mut stream).poll_read(cx, buf);
                    pin.state = State::Streaming(stream);
                    result
                }
                Err(err) => Poll::Ready(Err(err)),
            },
            State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
        }
    }
}

impl AsyncWrite for TlsStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let pin = self.get_mut();
        match pin.state {
            State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
                Ok(mut stream) => {
                    let result = Pin::new(&mut stream).poll_write(cx, buf);
                    pin.state = State::Streaming(stream);
                    result
                }
                Err(err) => Poll::Ready(Err(err)),
            },
            State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.state {
            State::Handshaking(_) => Poll::Ready(Ok(())),
            State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.state {
            State::Handshaking(_) => Poll::Ready(Ok(())),
            State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}

pub struct TlsAcceptor {
    config: Arc<ServerConfig>,
    incoming: AddrIncoming,
}

impl TlsAcceptor {
    pub fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
        TlsAcceptor {
            config: Arc::new(config),
            incoming,
        }
    }
}

impl Accept for TlsAcceptor {
    type Conn = TlsStream;
    type Error = io::Error;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let pin = self.get_mut();
        match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
            Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn file_cert_key() {
        TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pem")
            .key_path("tests/tls/local.dev_key.pem")
            .build()
            .unwrap();
    }

    #[test]
    fn bytes_cert_key() {
        let cert = include_str!("../tests/tls/local.dev_cert.pem");
        let key = include_str!("../tests/tls/local.dev_key.pem");

        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }
}
diff --git a/src/transport.rs b/src/transport.rs
new file mode 100644
index 0000000..7f52ed7
--- /dev/null
+++ b/src/transport.rs
@@ -0,0 +1,56 @@
// Handles requests over TLS
// -> Most of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/transport.rs

use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};

use hyper::server::conn::AddrStream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pub trait Transport: AsyncRead + AsyncWrite {
    fn remote_addr(&self) -> Option<SocketAddr>;
}

impl Transport for AddrStream {
    fn remote_addr(&self) -> Option<SocketAddr> {
        Some(self.remote_addr())
    }
}

pub struct LiftIo<T>(pub T);

impl<T: AsyncRead + Unpin> AsyncRead for LiftIo<T> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().0).poll_read(cx, buf)
    }
}

impl<T: AsyncWrite + Unpin> AsyncWrite for LiftIo<T> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.get_mut().0).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
    }
}

impl<T: AsyncRead + AsyncWrite + Unpin> Transport for LiftIo<T> {
    fn remote_addr(&self) -> Option<SocketAddr> {
        None
    }
}
diff --git a/tests/tls/local.dev_cert.pem b/tests/tls/local.dev_cert.pem
new file mode 100644
index 0000000..dbfa7eb
--- /dev/null
+++ b/tests/tls/local.dev_cert.pem
@@ -0,0 +1,26 @@
-----BEGIN CERTIFICATE-----
MIIEUDCCArigAwIBAgIRAJ+wfoSa2gGM7bMna/sxtB4wDQYJKoZIhvcNAQELBQAw
YTEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMRswGQYDVQQLDBJqb3Nl
bHVpc3FAcXVpbnRhbmExIjAgBgNVBAMMGW1rY2VydCBqb3NlbHVpc3FAcXVpbnRh
bmEwHhcNMTkwODI1MjIzNDM4WhcNMjkwODI1MjIzNDM4WjBGMScwJQYDVQQKEx5t
a2NlcnQgZGV2ZWxvcG1lbnQgY2VydGlmaWNhdGUxGzAZBgNVBAsMEmpvc2VsdWlz
cUBxdWludGFuYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAL7Lg4jL
kuabbfC2Qbv2iU+fCPKMht9LUT16VBh0cOxqpd75aTj6qikaDmYQZYJcAJYD2Hfh
fgP6dsT6/VRw7oWWYD/h9f7cz9xKjLRl/jBN1ob7VMbzJTFiJ4ajMZI5g/Yy6azC
/HEAlFGkXWfwblJPQdZHoQLksTSaHS5NR7RnmFMkgYxyaqIpkXNqUtyc+f5nUW6t
1VRoVBfG6V+LFY4IRYXoYehI5q+uK6w6jNEDHnDUTLagFc+D2UgMXQtG7TtvHAQz
jjTzpmb4pwmemkdc1xJlRa/1UdsPYHffjE2vUm6xrVJ07zvcxkS9gLwXKLLzuHnU
I2brgY0DdzFx3s0CAwEAAaOBnTCBmjAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAww
CgYIKwYBBQUHAwEwDAYDVR0TAQH/BAIwADAfBgNVHSMEGDAWgBTmeQONv1LFhIi6
WK47Dmc46TuFBDBEBgNVHREEPTA7gglsb2NhbC5kZXaCCyoubG9jYWwuZGV2ggls
b2NhbGhvc3SHBH8AAAGHEAAAAAAAAAAAAAAAAAAAAAEwDQYJKoZIhvcNAQELBQAD
ggGBADlgyQy/bwIekxRITUXnArLO9//C+I9hDOVs4GnY6nZJ0XwUOfWHq90U92Yh
YmCcQOBstYclBL9KzVHAOLa0erTEqbh1+2ZRrY8vzAf7RGwaZsE4uj6bB3KdOa00
zvkyHNYJnvL1xdOJAWckbaMgnBJwEGQGA9Bk5urozDYhbwIZS5PKXGPcLeiHIvn5
taC4x0fsCk4QkkPhOk92NjUD5t70vGQ5ty69fD11p1GOrC0szHZjnEdeW7SfPtsY
5qES+U9ppbJFeaFK/hhlRSdXjqk4a/P/HdM52QDvkrujk3DJYmNSQGdCa3fxiAnK
ivEBoYVIyVKRrCKNhyw8D4uWEUrMbsoo9/joAJYFOPHeYhSmkxA9HN0GvGBQ1MH4
zPd9B+hw90f8YokfGOH3dQiHAvvUyb1//uYN1FOlp/a9cTx0Y8oXTZuTvRL/259+
NjAizN+fctVbGloPEvTlxPkqveLNmzzJBk1bbj+Gt6tPqXN+DecNQMsMRzJ3HFOk
4EcwBQ==
-----END CERTIFICATE-----
diff --git a/tests/tls/local.dev_key.pem b/tests/tls/local.dev_key.pem
new file mode 100644
index 0000000..d924d4d
--- /dev/null
+++ b/tests/tls/local.dev_key.pem
@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQC+y4OIy5Lmm23w
tkG79olPnwjyjIbfS1E9elQYdHDsaqXe+Wk4+qopGg5mEGWCXACWA9h34X4D+nbE
+v1UcO6FlmA/4fX+3M/cSoy0Zf4wTdaG+1TG8yUxYieGozGSOYP2MumswvxxAJRR
pF1n8G5ST0HWR6EC5LE0mh0uTUe0Z5hTJIGMcmqiKZFzalLcnPn+Z1FurdVUaFQX
xulfixWOCEWF6GHoSOavriusOozRAx5w1Ey2oBXPg9lIDF0LRu07bxwEM44086Zm
+KcJnppHXNcSZUWv9VHbD2B334xNr1Jusa1SdO873MZEvYC8Fyiy87h51CNm64GN
A3cxcd7NAgMBAAECggEBALSchOR/CY3hvt4qOenMBMnpm5e3rYk9jCctYORRfgBf
KKv94Dy/FUuZTd4SUXVo0GkyNL2vKRJtC/eGPT+tNC4jXvO6XJspvl8j9zRihJCH
brgSvXsj+qZX62DJpYhth90M7yXK4xu51629cWqOMHEcdA97eRD7GkDYTx1grKs5
7ykYki3NNGQFDncSmQz/ZjHs/W44byVKdKVLUHWeexfkOFZ4tmr4gDcLG+M6f6m3
TTDOIdh9FvpNBOyg+GDWgJbn1nw6PYF3c2cOMQopRwAQKuHfVwpbF+zzxvtcCTkF
GmsprSdLTeXY4v2RT+kla9Hmgot1XIPY6iMvXUkkhwUCgYEA07BuPYWTxY0gfNo3
CrTNhhGyW4IA8wjwA57ao71Eg7vzhTub+sMZXCMMpFibIGD3pEcW8hG4ke5ghH3n
4jxNBCtFX0q3OHAbBtStX03iggsDoy8piYLxjHrRp+pxEDncwqFaIhWhR0S5wi/M
u2+hE0A9pWAhc/y+DWnoUZvFTL8CgYEA5rtyavAGMA0m8hN9uwMtmc7gSFp0oo8a
mm1pDFe8Z8Mv/SG16pYAuM3wUa+KqfdRXOf6vHvI4OmX1PuiqocECf6acOJu9lzg
bU0WTwoweusGISY5oYUzQ98lkbOVpGR5+1kslACVQmzvX8+EHqFIbdS2de29TGux
vj9drfYX23MCgYEAobKGwp+h/KiMRFI68QaiZuJlpthq+Tm+fEV/JMuR5j5PCVo7
DxSv7l0nbvHvrI/lGarjsAwxO+cl+o5h7cG54pFa8CsWQRoAyvrxY3cOqd7X7HI9
/Df1YiT+uJCvxIEuS80MGDUFeHbanaX9cL8X/qh3bjc71mkckwpu1sdxsekCgYEA
tYjnpeGBTM8cRDw3oSsH9srAxcx9leS3zqakjvR8pLr6h9O9GHu6x6woF2zg0Ydn
uYw/R4qw6tx+/DCbtEWUVPS/uG8/VJCQdw6+raNbr2o4oV4826s8QXtRSMidxQDU
xIBNxYiL5v5ke+J+lcbZgKhqgnBxjq3w47lhUFyeOqcCgYEAy3leJK+E9xtXcCGO
WszUz9LU6UqrTqtiFP29ZLmd+VG17/1bt8az6wChMqUJfHa7i8DuG1hANNmiJimE
lpTJY6rxzoq8wkaUi7MqZnkACeWLnLT9i5BZDPTwsNSxkKezYg7j3zoQVj4d86PN
A/DYsS6Gzzwo60cYfO/Kcfwb6vE=
-----END PRIVATE KEY-----