index : static-web-server.git

ascending towards madness

author Jose Quintana <joseluisquintana20@gmail.com> 2022-03-07 23:01:16.0 +00:00:00
committer Jose Quintana <joseluisquintana20@gmail.com> 2022-03-07 23:03:17.0 +00:00:00
commit
d33d093a781174beb580bd97fdeadc0dafdde244 [patch]
tree
625ecbfcd22aa539440529d11daec0f8192dd23c
parent
b831bb92e08bffcae0345c0e476280569f727716
download
d33d093a781174beb580bd97fdeadc0dafdde244.tar.gz

refactor: small optimizations on request handler



Diff

 src/cors.rs    | 10 +++++-----
 src/handler.rs | 41 ++++++++++++++++++++---------------------
 src/server.rs  | 23 +++++++++--------------
 tests/cors.rs  | 18 +++++++++---------
 4 files changed, 43 insertions(+), 49 deletions(-)

diff --git a/src/cors.rs b/src/cors.rs
index 6d53e53..24286ad 100644
--- a/src/cors.rs
+++ b/src/cors.rs
@@ -6,7 +6,7 @@ use headers::{
    Origin,
};
use http::header;
use std::{collections::HashSet, convert::TryFrom, sync::Arc};
use std::{collections::HashSet, convert::TryFrom};

/// It defines CORS instance.
#[derive(Clone, Debug)]
@@ -18,7 +18,7 @@ pub struct Cors {
}

/// It builds a new CORS instance.
pub fn new(origins_str: String, headers_str: String) -> Option<Arc<Configured>> {
pub fn new(origins_str: &str, headers_str: &str) -> Option<Configured> {
    let cors = Cors::new();
    let cors = if origins_str.is_empty() {
        None
@@ -154,18 +154,18 @@ impl Cors {
    }

    /// Builds the `Cors` wrapper from the configured settings.
    pub fn build(cors: Option<Cors>) -> Option<Arc<Configured>> {
    pub fn build(cors: Option<Cors>) -> Option<Configured> {
        cors.as_ref()?;
        let cors = cors?;

        let allowed_headers = cors.allowed_headers.iter().cloned().collect();
        let methods_header = cors.allowed_methods.iter().cloned().collect();

        Some(Arc::new(Configured {
        Some(Configured {
            cors,
            allowed_headers,
            methods_header,
        }))
        })
    }
}

diff --git a/src/handler.rs b/src/handler.rs
index 4df4158..bbae8f5 100644
--- a/src/handler.rs
+++ b/src/handler.rs
@@ -8,21 +8,21 @@ use crate::{Error, Result};

/// It defines options for a request handler.
pub struct RequestHandlerOpts {
    pub root_dir: Arc<PathBuf>,
    pub root_dir: PathBuf,
    pub compression: bool,
    pub dir_listing: bool,
    pub dir_listing_order: u8,
    pub cors: Option<Arc<cors::Configured>>,
    pub cors: Option<cors::Configured>,
    pub security_headers: bool,
    pub cache_control_headers: bool,
    pub page404: Arc<str>,
    pub page50x: Arc<str>,
    pub basic_auth: Arc<str>,
    pub page404: String,
    pub page50x: String,
    pub basic_auth: String,
}

/// It defines the main request handler used by the Hyper service request.
pub struct RequestHandler {
    pub opts: RequestHandlerOpts,
    pub opts: Arc<RequestHandlerOpts>,
}

impl RequestHandler {
@@ -35,7 +35,7 @@ impl RequestHandler {
        let headers = req.headers();
        let uri = req.uri();

        let root_dir = self.opts.root_dir.as_ref();
        let root_dir = &self.opts.root_dir;
        let uri_path = uri.path();
        let uri_query = uri.query();
        let dir_listing = self.opts.dir_listing;
@@ -49,14 +49,13 @@ impl RequestHandler {
                return error_page::error_response(
                    method,
                    &StatusCode::METHOD_NOT_ALLOWED,
                    self.opts.page404.as_ref(),
                    self.opts.page50x.as_ref(),
                    &self.opts.page404,
                    &self.opts.page50x,
                );
            }

            // CORS
            if self.opts.cors.is_some() {
                let cors = self.opts.cors.as_ref().unwrap();
            if let Some(cors) = &self.opts.cors {
                match cors.check_request(method, headers) {
                    Ok((headers, state)) => {
                        tracing::debug!("cors state: {:?}", state);
@@ -67,8 +66,8 @@ impl RequestHandler {
                        return error_page::error_response(
                            method,
                            &StatusCode::FORBIDDEN,
                            self.opts.page404.as_ref(),
                            self.opts.page50x.as_ref(),
                            &self.opts.page404,
                            &self.opts.page50x,
                        );
                    }
                };
@@ -82,8 +81,8 @@ impl RequestHandler {
                        let mut resp = error_page::error_response(
                            method,
                            &StatusCode::UNAUTHORIZED,
                            self.opts.page404.as_ref(),
                            self.opts.page50x.as_ref(),
                            &self.opts.page404,
                            &self.opts.page50x,
                        )?;
                        resp.headers_mut().insert(
                            WWW_AUTHENTICATE,
@@ -98,8 +97,8 @@ impl RequestHandler {
                    return error_page::error_response(
                        method,
                        &StatusCode::INTERNAL_SERVER_ERROR,
                        self.opts.page404.as_ref(),
                        self.opts.page50x.as_ref(),
                        &self.opts.page404,
                        &self.opts.page50x,
                    );
                }
            }
@@ -136,8 +135,8 @@ impl RequestHandler {
                                return error_page::error_response(
                                    method,
                                    &StatusCode::INTERNAL_SERVER_ERROR,
                                    self.opts.page404.as_ref(),
                                    self.opts.page50x.as_ref(),
                                    &self.opts.page404,
                                    &self.opts.page50x,
                                );
                            }
                        };
@@ -158,8 +157,8 @@ impl RequestHandler {
                Err(status) => error_page::error_response(
                    method,
                    &status,
                    self.opts.page404.as_ref(),
                    self.opts.page50x.as_ref(),
                    &self.opts.page404,
                    &self.opts.page50x,
                ),
            }
        }
diff --git a/src/server.rs b/src/server.rs
index 8a135a6..3d937c5 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -90,11 +90,10 @@ impl Server {
        // Check for a valid root directory
        let root_dir = helpers::get_valid_dirpath(&opts.root)
            .with_context(|| "root directory was not found or inaccessible".to_string())?;
        let root_dir = Arc::new(root_dir);

        // Custom error pages content
        let page404 = Arc::from(helpers::read_file_content(opts.page404.as_ref()).as_str());
        let page50x = Arc::from(helpers::read_file_content(opts.page50x.as_ref()).as_str());
        let page404 = helpers::read_file_content(&opts.page404);
        let page50x = helpers::read_file_content(&opts.page50x);

        // Number of worker threads option
        let threads = self.threads;
@@ -122,17 +121,16 @@ impl Server {

        // CORS option
        let cors = cors::new(
            opts.cors_allow_origins.trim().to_owned(),
            opts.cors_allow_headers.trim().to_owned(),
            opts.cors_allow_origins.trim(),
            opts.cors_allow_headers.trim(),
        );

        // `Basic` HTTP Authentication Schema option
        let basic_auth = opts.basic_auth.trim();
        let basic_auth = opts.basic_auth.trim().to_owned();
        tracing::info!(
            "basic authentication: enabled={}",
            !self.opts.basic_auth.is_empty()
        );
        let basic_auth = Arc::from(basic_auth);

        // Grace period option
        let grace_period = opts.grace_period;
@@ -140,7 +138,7 @@ impl Server {

        // Create a service router for Hyper
        let router_service = RouterService::new(RequestHandler {
            opts: RequestHandlerOpts {
            opts: Arc::from(RequestHandlerOpts {
                root_dir,
                compression,
                dir_listing,
@@ -151,7 +149,7 @@ impl Server {
                page404,
                page50x,
                basic_auth,
            },
            }),
        });

        // Run the corresponding HTTP Server asynchronously with its given options
@@ -159,9 +157,6 @@ impl Server {
        if opts.http2 {
            // HTTP/2 + TLS

            let cert_path = opts.http2_tls_cert.clone();
            let key_path = opts.http2_tls_key.clone();

            tcp_listener
                .set_nonblocking(true)
                .expect("cannot set non-blocking");
@@ -174,8 +169,8 @@ impl Server {
            incoming.set_nodelay(true);

            let tls = TlsConfigBuilder::new()
                .cert_path(cert_path)
                .key_path(key_path)
                .cert_path(&opts.http2_tls_cert)
                .key_path(&opts.http2_tls_key)
                .build()
                .with_context(|| {
                    "failed to initialize TLS, probably wrong cert/key or file missing".to_string()
diff --git a/tests/cors.rs b/tests/cors.rs
index 0f4db64..a094ade 100644
--- a/tests/cors.rs
+++ b/tests/cors.rs
@@ -11,14 +11,14 @@ mod tests {

    #[tokio::test]
    async fn allow_methods() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("*", "").unwrap();
        let headers = HeaderMap::new();
        let methods = &[Method::GET, Method::HEAD, Method::OPTIONS];
        for method in methods {
            assert!(cors.check_request(method, &headers).is_ok());
        }

        let cors = cors::new("https://localhost".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("https://localhost", "").unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "GET".parse().unwrap());
@@ -29,7 +29,7 @@ mod tests {

    #[test]
    fn disallow_methods() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("*", "").unwrap();
        let headers = HeaderMap::new();
        let methods = [
            Method::CONNECT,
@@ -50,7 +50,7 @@ mod tests {

    #[tokio::test]
    async fn origin_allowed() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("*", "").unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
@@ -67,7 +67,7 @@ mod tests {

    #[tokio::test]
    async fn origin_not_allowed() {
        let cors = cors::new("https://localhost.rs".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("https://localhost.rs", "").unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
@@ -80,7 +80,7 @@ mod tests {

    #[tokio::test]
    async fn method_allowed() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("*", "").unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "GET".parse().unwrap());
@@ -92,7 +92,7 @@ mod tests {

    #[tokio::test]
    async fn method_disallowed() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("*", "").unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "POST".parse().unwrap());
@@ -110,7 +110,7 @@ mod tests {

    #[tokio::test]
    async fn headers_allowed() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("*", "").unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "GET".parse().unwrap());
@@ -127,7 +127,7 @@ mod tests {

    #[tokio::test]
    async fn headers_invalid() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let cors = cors::new("*", "").unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert(