index : static-web-server.git

ascending towards madness

author Syrus Akbary <me@syrusakbary.com> 2023-03-04 6:25:34.0 +00:00:00
committer GitHub <noreply@github.com> 2023-03-04 6:25:34.0 +00:00:00
commit
87a0896bf0fcf1f2bb2d8e9868127484ddba995f [patch]
tree
0d28391d64ce07afe3a2796863d7b94a5e642987
parent
9796d35a8b50e15cbe3abecd8888065a1bdc888d
download
87a0896bf0fcf1f2bb2d8e9868127484ddba995f.tar.gz

feat: add optional feature flag for http2 (#183)

Co-authored-by: John Sharratt's Shared Account <johnathan.sharratt@gmail.com>

Diff

 Cargo.toml           |  7 ++++-
 src/lib.rs           |  1 +-
 src/server.rs        | 82 +++++++++++++++++++++++++++++------------------------
 src/settings/cli.rs  |  3 ++-
 src/settings/file.rs |  3 ++-
 src/settings/mod.rs  |  9 ++++++-
 6 files changed, 67 insertions(+), 38 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 4fb5812..cfbdafd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -27,6 +27,11 @@ include = ["src/**/*", "Cargo.toml", "Cargo.lock"]
name = "static-web-server"
path = "src/bin/server.rs"

[features]
default = ["http2"]
tls = ["tokio-rustls"]
http2 = ["tls"]

[dependencies]
anyhow = "1.0"
async-compression = { version = "0.3", default-features = false, features = ["brotli", "deflate", "gzip", "tokio"] }
@@ -52,7 +57,7 @@ serde_repr = "0.1"
structopt = { version = "0.3", default-features = false }
chrono = { version = "0.4", default-features = false, features = ["std", "clock"] }
tokio = { version = "1", default-features = false, features = ["rt-multi-thread", "macros", "fs", "io-util", "signal"] }
tokio-rustls = { version = "0.23" }
tokio-rustls = { version = "0.23", optional = true }
tokio-util = { version = "0.7", default-features = false, features = ["io"] }
toml = "0.5"
tracing = { version = "0.1", default-features = false, features = ["std"] }
diff --git a/src/lib.rs b/src/lib.rs
index 0c4c385..5f32657 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -30,6 +30,7 @@ pub mod service;
pub mod settings;
pub mod signals;
pub mod static_files;
#[cfg(feature = "tls")]
pub mod tls;
pub mod transport;

diff --git a/src/server.rs b/src/server.rs
index 418e7db..65ed7c4 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,4 +1,6 @@
#[allow(unused_imports)]
use hyper::server::conn::AddrIncoming;
#[allow(unused_imports)]
use hyper::server::Server as HyperServer;
use listenfd::ListenFd;
use std::net::{IpAddr, SocketAddr, TcpListener};
@@ -6,6 +8,7 @@ use std::sync::Arc;
use tokio::sync::oneshot::Receiver;

use crate::handler::{RequestHandler, RequestHandlerOpts};
#[cfg(feature = "tls")]
use crate::tls::{TlsAcceptor, TlsConfigBuilder};
use crate::{cors, helpers, logger, signals, Settings};
use crate::{service::RouterService, Context, Result};
@@ -62,12 +65,15 @@ impl Server {
            .enable_all()
            .build()?
            .block_on(async {
                tracing::trace!("starting web server");
                if let Err(err) = self.start_server(cancel_recv, cancel_fn).await {
                    tracing::error!("server failed to start up: {:?}", err);
                    std::process::exit(1)
                }
            });

        tracing::trace!("runtime initialized");

        Ok(())
    }

@@ -215,7 +221,7 @@ impl Server {
        });

        // Run the corresponding HTTP Server asynchronously with its given options

        #[cfg(feature = "http2")]
        if general.http2 {
            // HTTP/2 + TLS

@@ -277,50 +283,52 @@ impl Server {

            #[cfg(unix)]
            handle.close();
        } else {
            // HTTP/1

            #[cfg(unix)]
            let signals = signals::create_signals()
                .with_context(|| "failed to register termination signals")?;
            #[cfg(unix)]
            let handle = signals.handle();

            tcp_listener
                .set_nonblocking(true)
                .with_context(|| "failed to set TCP non-blocking mode")?;

            let server = HyperServer::from_tcp(tcp_listener)
                .unwrap()
                .tcp_nodelay(true)
                .serve(router_service);
            tracing::warn!("termination signal caught, shutting down the server execution");
            return Ok(());
        }

            #[cfg(unix)]
            let server =
                server.with_graceful_shutdown(signals::wait_for_signals(signals, grace_period));
            #[cfg(windows)]
            let server = server.with_graceful_shutdown(signals::wait_for_ctrl_c(
                _cancel_recv,
                _cancel_fn,
                grace_period,
            ));
        // HTTP/1

        #[cfg(unix)]
        let signals =
            signals::create_signals().with_context(|| "failed to register termination signals")?;
        #[cfg(unix)]
        let handle = signals.handle();

        tcp_listener
            .set_nonblocking(true)
            .with_context(|| "failed to set TCP non-blocking mode")?;

        let server = HyperServer::from_tcp(tcp_listener)
            .unwrap()
            .tcp_nodelay(true)
            .serve(router_service);

        #[cfg(unix)]
        let server =
            server.with_graceful_shutdown(signals::wait_for_signals(signals, grace_period));
        #[cfg(windows)]
        let server = server.with_graceful_shutdown(signals::wait_for_ctrl_c(
            _cancel_recv,
            _cancel_fn,
            grace_period,
        ));

            tracing::info!(
                parent: tracing::info_span!("Server::start_server", ?addr_str, ?threads),
                "listening on http://{}",
                addr_str
            );
        tracing::info!(
            parent: tracing::info_span!("Server::start_server", ?addr_str, ?threads),
            "listening on http://{}",
            addr_str
        );

            tracing::info!("press ctrl+c to shut down the server");
        tracing::info!("press ctrl+c to shut down the server");

            server.await?;
        server.await?;

            #[cfg(unix)]
            handle.close();
        }
        #[cfg(unix)]
        handle.close();

        tracing::warn!("termination signal caught, shutting down the server execution");

        Ok(())
    }
}
diff --git a/src/settings/cli.rs b/src/settings/cli.rs
index 8746958..9b541c1 100644
--- a/src/settings/cli.rs
+++ b/src/settings/cli.rs
@@ -105,14 +105,17 @@ pub struct General {
        default_value = "false",
        env = "SERVER_HTTP2_TLS"
    )]
    #[cfg(feature = "http2")]
    /// Enable HTTP/2 with TLS support.
    pub http2: bool,

    #[structopt(long, required_if("http2", "true"), env = "SERVER_HTTP2_TLS_CERT")]
    #[cfg(feature = "http2")]
    /// Specify the file path to read the certificate.
    pub http2_tls_cert: Option<PathBuf>,

    #[structopt(long, required_if("http2", "true"), env = "SERVER_HTTP2_TLS_KEY")]
    #[cfg(feature = "http2")]
    /// Specify the file path to read the private key.
    pub http2_tls_key: Option<PathBuf>,

diff --git a/src/settings/file.rs b/src/settings/file.rs
index 35b060f..bf103bc 100644
--- a/src/settings/file.rs
+++ b/src/settings/file.rs
@@ -102,8 +102,11 @@ pub struct General {
    pub page50x: Option<PathBuf>,

    // HTTP/2 + TLS
    #[cfg(feature = "http2")]
    pub http2: Option<bool>,
    #[cfg(feature = "http2")]
    pub http2_tls_cert: Option<PathBuf>,
    #[cfg(feature = "http2")]
    pub http2_tls_key: Option<PathBuf>,

    // Security headers
diff --git a/src/settings/mod.rs b/src/settings/mod.rs
index 47483e1..9506f19 100644
--- a/src/settings/mod.rs
+++ b/src/settings/mod.rs
@@ -70,8 +70,11 @@ impl Settings {
        let mut compression_static = opts.compression_static;
        let mut page404 = opts.page404;
        let mut page50x = opts.page50x;
        #[cfg(feature = "http2")]
        let mut http2 = opts.http2;
        #[cfg(feature = "http2")]
        let mut http2_tls_cert = opts.http2_tls_cert;
        #[cfg(feature = "http2")]
        let mut http2_tls_key = opts.http2_tls_key;
        let mut security_headers = opts.security_headers;
        let mut cors_allow_origins = opts.cors_allow_origins;
@@ -140,12 +143,15 @@ impl Settings {
                    if let Some(v) = general.page50x {
                        page50x = v
                    }
                    #[cfg(feature = "http2")]
                    if let Some(v) = general.http2 {
                        http2 = v
                    }
                    #[cfg(feature = "http2")]
                    if let Some(v) = general.http2_tls_cert {
                        http2_tls_cert = Some(v)
                    }
                    #[cfg(feature = "http2")]
                    if let Some(v) = general.http2_tls_key {
                        http2_tls_key = Some(v)
                    }
@@ -307,8 +313,11 @@ impl Settings {
                compression_static,
                page404,
                page50x,
                #[cfg(feature = "http2")]
                http2,
                #[cfg(feature = "http2")]
                http2_tls_cert,
                #[cfg(feature = "http2")]
                http2_tls_key,
                security_headers,
                cors_allow_origins,