From d33d093a781174beb580bd97fdeadc0dafdde244 Mon Sep 17 00:00:00 2001 From: Jose Quintana Date: Tue, 8 Mar 2022 00:01:16 +0100 Subject: [PATCH] refactor: small optimizations on request handler --- src/cors.rs | 10 +++++----- src/handler.rs | 41 ++++++++++++++++++++--------------------- src/server.rs | 23 +++++++++-------------- tests/cors.rs | 18 +++++++++--------- 4 files changed, 43 insertions(+), 49 deletions(-) diff --git a/src/cors.rs b/src/cors.rs index 6d53e53..24286ad 100644 --- a/src/cors.rs +++ b/src/cors.rs @@ -6,7 +6,7 @@ use headers::{ Origin, }; use http::header; -use std::{collections::HashSet, convert::TryFrom, sync::Arc}; +use std::{collections::HashSet, convert::TryFrom}; /// It defines CORS instance. #[derive(Clone, Debug)] @@ -18,7 +18,7 @@ pub struct Cors { } /// It builds a new CORS instance. -pub fn new(origins_str: String, headers_str: String) -> Option> { +pub fn new(origins_str: &str, headers_str: &str) -> Option { let cors = Cors::new(); let cors = if origins_str.is_empty() { None @@ -154,18 +154,18 @@ impl Cors { } /// Builds the `Cors` wrapper from the configured settings. - pub fn build(cors: Option) -> Option> { + pub fn build(cors: Option) -> Option { cors.as_ref()?; let cors = cors?; let allowed_headers = cors.allowed_headers.iter().cloned().collect(); let methods_header = cors.allowed_methods.iter().cloned().collect(); - Some(Arc::new(Configured { + Some(Configured { cors, allowed_headers, methods_header, - })) + }) } } diff --git a/src/handler.rs b/src/handler.rs index 4df4158..bbae8f5 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -8,21 +8,21 @@ use crate::{Error, Result}; /// It defines options for a request handler. pub struct RequestHandlerOpts { - pub root_dir: Arc, + pub root_dir: PathBuf, pub compression: bool, pub dir_listing: bool, pub dir_listing_order: u8, - pub cors: Option>, + pub cors: Option, pub security_headers: bool, pub cache_control_headers: bool, - pub page404: Arc, - pub page50x: Arc, - pub basic_auth: Arc, + pub page404: String, + pub page50x: String, + pub basic_auth: String, } /// It defines the main request handler used by the Hyper service request. pub struct RequestHandler { - pub opts: RequestHandlerOpts, + pub opts: Arc, } impl RequestHandler { @@ -35,7 +35,7 @@ impl RequestHandler { let headers = req.headers(); let uri = req.uri(); - let root_dir = self.opts.root_dir.as_ref(); + let root_dir = &self.opts.root_dir; let uri_path = uri.path(); let uri_query = uri.query(); let dir_listing = self.opts.dir_listing; @@ -49,14 +49,13 @@ impl RequestHandler { return error_page::error_response( method, &StatusCode::METHOD_NOT_ALLOWED, - self.opts.page404.as_ref(), - self.opts.page50x.as_ref(), + &self.opts.page404, + &self.opts.page50x, ); } // CORS - if self.opts.cors.is_some() { - let cors = self.opts.cors.as_ref().unwrap(); + if let Some(cors) = &self.opts.cors { match cors.check_request(method, headers) { Ok((headers, state)) => { tracing::debug!("cors state: {:?}", state); @@ -67,8 +66,8 @@ impl RequestHandler { return error_page::error_response( method, &StatusCode::FORBIDDEN, - self.opts.page404.as_ref(), - self.opts.page50x.as_ref(), + &self.opts.page404, + &self.opts.page50x, ); } }; @@ -82,8 +81,8 @@ impl RequestHandler { let mut resp = error_page::error_response( method, &StatusCode::UNAUTHORIZED, - self.opts.page404.as_ref(), - self.opts.page50x.as_ref(), + &self.opts.page404, + &self.opts.page50x, )?; resp.headers_mut().insert( WWW_AUTHENTICATE, @@ -98,8 +97,8 @@ impl RequestHandler { return error_page::error_response( method, &StatusCode::INTERNAL_SERVER_ERROR, - self.opts.page404.as_ref(), - self.opts.page50x.as_ref(), + &self.opts.page404, + &self.opts.page50x, ); } } @@ -136,8 +135,8 @@ impl RequestHandler { return error_page::error_response( method, &StatusCode::INTERNAL_SERVER_ERROR, - self.opts.page404.as_ref(), - self.opts.page50x.as_ref(), + &self.opts.page404, + &self.opts.page50x, ); } }; @@ -158,8 +157,8 @@ impl RequestHandler { Err(status) => error_page::error_response( method, &status, - self.opts.page404.as_ref(), - self.opts.page50x.as_ref(), + &self.opts.page404, + &self.opts.page50x, ), } } diff --git a/src/server.rs b/src/server.rs index 8a135a6..3d937c5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -90,11 +90,10 @@ impl Server { // Check for a valid root directory let root_dir = helpers::get_valid_dirpath(&opts.root) .with_context(|| "root directory was not found or inaccessible".to_string())?; - let root_dir = Arc::new(root_dir); // Custom error pages content - let page404 = Arc::from(helpers::read_file_content(opts.page404.as_ref()).as_str()); - let page50x = Arc::from(helpers::read_file_content(opts.page50x.as_ref()).as_str()); + let page404 = helpers::read_file_content(&opts.page404); + let page50x = helpers::read_file_content(&opts.page50x); // Number of worker threads option let threads = self.threads; @@ -122,17 +121,16 @@ impl Server { // CORS option let cors = cors::new( - opts.cors_allow_origins.trim().to_owned(), - opts.cors_allow_headers.trim().to_owned(), + opts.cors_allow_origins.trim(), + opts.cors_allow_headers.trim(), ); // `Basic` HTTP Authentication Schema option - let basic_auth = opts.basic_auth.trim(); + let basic_auth = opts.basic_auth.trim().to_owned(); tracing::info!( "basic authentication: enabled={}", !self.opts.basic_auth.is_empty() ); - let basic_auth = Arc::from(basic_auth); // Grace period option let grace_period = opts.grace_period; @@ -140,7 +138,7 @@ impl Server { // Create a service router for Hyper let router_service = RouterService::new(RequestHandler { - opts: RequestHandlerOpts { + opts: Arc::from(RequestHandlerOpts { root_dir, compression, dir_listing, @@ -151,7 +149,7 @@ impl Server { page404, page50x, basic_auth, - }, + }), }); // Run the corresponding HTTP Server asynchronously with its given options @@ -159,9 +157,6 @@ impl Server { if opts.http2 { // HTTP/2 + TLS - let cert_path = opts.http2_tls_cert.clone(); - let key_path = opts.http2_tls_key.clone(); - tcp_listener .set_nonblocking(true) .expect("cannot set non-blocking"); @@ -174,8 +169,8 @@ impl Server { incoming.set_nodelay(true); let tls = TlsConfigBuilder::new() - .cert_path(cert_path) - .key_path(key_path) + .cert_path(&opts.http2_tls_cert) + .key_path(&opts.http2_tls_key) .build() .with_context(|| { "failed to initialize TLS, probably wrong cert/key or file missing".to_string() diff --git a/tests/cors.rs b/tests/cors.rs index 0f4db64..a094ade 100644 --- a/tests/cors.rs +++ b/tests/cors.rs @@ -11,14 +11,14 @@ mod tests { #[tokio::test] async fn allow_methods() { - let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("*", "").unwrap(); let headers = HeaderMap::new(); let methods = &[Method::GET, Method::HEAD, Method::OPTIONS]; for method in methods { assert!(cors.check_request(method, &headers).is_ok()); } - let cors = cors::new("https://localhost".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("https://localhost", "").unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); headers.insert("access-control-request-method", "GET".parse().unwrap()); @@ -29,7 +29,7 @@ mod tests { #[test] fn disallow_methods() { - let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("*", "").unwrap(); let headers = HeaderMap::new(); let methods = [ Method::CONNECT, @@ -50,7 +50,7 @@ mod tests { #[tokio::test] async fn origin_allowed() { - let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("*", "").unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); let methods = [Method::GET, Method::HEAD, Method::OPTIONS]; @@ -67,7 +67,7 @@ mod tests { #[tokio::test] async fn origin_not_allowed() { - let cors = cors::new("https://localhost.rs".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("https://localhost.rs", "").unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); let methods = [Method::GET, Method::HEAD, Method::OPTIONS]; @@ -80,7 +80,7 @@ mod tests { #[tokio::test] async fn method_allowed() { - let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("*", "").unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); headers.insert("access-control-request-method", "GET".parse().unwrap()); @@ -92,7 +92,7 @@ mod tests { #[tokio::test] async fn method_disallowed() { - let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("*", "").unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); headers.insert("access-control-request-method", "POST".parse().unwrap()); @@ -110,7 +110,7 @@ mod tests { #[tokio::test] async fn headers_allowed() { - let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("*", "").unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); headers.insert("access-control-request-method", "GET".parse().unwrap()); @@ -127,7 +127,7 @@ mod tests { #[tokio::test] async fn headers_invalid() { - let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let cors = cors::new("*", "").unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); headers.insert( -- libgit2 1.7.2