refactor: small optimizations on request handler
Diff
src/cors.rs | 10 +++++-----
src/handler.rs | 41 ++++++++++++++++++++---------------------
src/server.rs | 23 +++++++++--------------
tests/cors.rs | 18 +++++++++---------
4 files changed, 43 insertions(+), 49 deletions(-)
@@ -6,7 +6,7 @@ use headers::{
Origin,
};
use http::header;
use std::{collections::HashSet, convert::TryFrom, sync::Arc};
use std::{collections::HashSet, convert::TryFrom};
#[derive(Clone, Debug)]
@@ -18,7 +18,7 @@ pub struct Cors {
}
pub fn new(origins_str: String, headers_str: String) -> Option<Arc<Configured>> {
pub fn new(origins_str: &str, headers_str: &str) -> Option<Configured> {
let cors = Cors::new();
let cors = if origins_str.is_empty() {
None
@@ -154,18 +154,18 @@ impl Cors {
}
pub fn build(cors: Option<Cors>) -> Option<Arc<Configured>> {
pub fn build(cors: Option<Cors>) -> Option<Configured> {
cors.as_ref()?;
let cors = cors?;
let allowed_headers = cors.allowed_headers.iter().cloned().collect();
let methods_header = cors.allowed_methods.iter().cloned().collect();
Some(Arc::new(Configured {
Some(Configured {
cors,
allowed_headers,
methods_header,
}))
})
}
}
@@ -8,21 +8,21 @@ use crate::{Error, Result};
pub struct RequestHandlerOpts {
pub root_dir: Arc<PathBuf>,
pub root_dir: PathBuf,
pub compression: bool,
pub dir_listing: bool,
pub dir_listing_order: u8,
pub cors: Option<Arc<cors::Configured>>,
pub cors: Option<cors::Configured>,
pub security_headers: bool,
pub cache_control_headers: bool,
pub page404: Arc<str>,
pub page50x: Arc<str>,
pub basic_auth: Arc<str>,
pub page404: String,
pub page50x: String,
pub basic_auth: String,
}
pub struct RequestHandler {
pub opts: RequestHandlerOpts,
pub opts: Arc<RequestHandlerOpts>,
}
impl RequestHandler {
@@ -35,7 +35,7 @@ impl RequestHandler {
let headers = req.headers();
let uri = req.uri();
let root_dir = self.opts.root_dir.as_ref();
let root_dir = &self.opts.root_dir;
let uri_path = uri.path();
let uri_query = uri.query();
let dir_listing = self.opts.dir_listing;
@@ -49,14 +49,13 @@ impl RequestHandler {
return error_page::error_response(
method,
&StatusCode::METHOD_NOT_ALLOWED,
self.opts.page404.as_ref(),
self.opts.page50x.as_ref(),
&self.opts.page404,
&self.opts.page50x,
);
}
if self.opts.cors.is_some() {
let cors = self.opts.cors.as_ref().unwrap();
if let Some(cors) = &self.opts.cors {
match cors.check_request(method, headers) {
Ok((headers, state)) => {
tracing::debug!("cors state: {:?}", state);
@@ -67,8 +66,8 @@ impl RequestHandler {
return error_page::error_response(
method,
&StatusCode::FORBIDDEN,
self.opts.page404.as_ref(),
self.opts.page50x.as_ref(),
&self.opts.page404,
&self.opts.page50x,
);
}
};
@@ -82,8 +81,8 @@ impl RequestHandler {
let mut resp = error_page::error_response(
method,
&StatusCode::UNAUTHORIZED,
self.opts.page404.as_ref(),
self.opts.page50x.as_ref(),
&self.opts.page404,
&self.opts.page50x,
)?;
resp.headers_mut().insert(
WWW_AUTHENTICATE,
@@ -98,8 +97,8 @@ impl RequestHandler {
return error_page::error_response(
method,
&StatusCode::INTERNAL_SERVER_ERROR,
self.opts.page404.as_ref(),
self.opts.page50x.as_ref(),
&self.opts.page404,
&self.opts.page50x,
);
}
}
@@ -136,8 +135,8 @@ impl RequestHandler {
return error_page::error_response(
method,
&StatusCode::INTERNAL_SERVER_ERROR,
self.opts.page404.as_ref(),
self.opts.page50x.as_ref(),
&self.opts.page404,
&self.opts.page50x,
);
}
};
@@ -158,8 +157,8 @@ impl RequestHandler {
Err(status) => error_page::error_response(
method,
&status,
self.opts.page404.as_ref(),
self.opts.page50x.as_ref(),
&self.opts.page404,
&self.opts.page50x,
),
}
}
@@ -90,11 +90,10 @@ impl Server {
let root_dir = helpers::get_valid_dirpath(&opts.root)
.with_context(|| "root directory was not found or inaccessible".to_string())?;
let root_dir = Arc::new(root_dir);
let page404 = Arc::from(helpers::read_file_content(opts.page404.as_ref()).as_str());
let page50x = Arc::from(helpers::read_file_content(opts.page50x.as_ref()).as_str());
let page404 = helpers::read_file_content(&opts.page404);
let page50x = helpers::read_file_content(&opts.page50x);
let threads = self.threads;
@@ -122,17 +121,16 @@ impl Server {
let cors = cors::new(
opts.cors_allow_origins.trim().to_owned(),
opts.cors_allow_headers.trim().to_owned(),
opts.cors_allow_origins.trim(),
opts.cors_allow_headers.trim(),
);
let basic_auth = opts.basic_auth.trim();
let basic_auth = opts.basic_auth.trim().to_owned();
tracing::info!(
"basic authentication: enabled={}",
!self.opts.basic_auth.is_empty()
);
let basic_auth = Arc::from(basic_auth);
let grace_period = opts.grace_period;
@@ -140,7 +138,7 @@ impl Server {
let router_service = RouterService::new(RequestHandler {
opts: RequestHandlerOpts {
opts: Arc::from(RequestHandlerOpts {
root_dir,
compression,
dir_listing,
@@ -151,7 +149,7 @@ impl Server {
page404,
page50x,
basic_auth,
},
}),
});
@@ -159,9 +157,6 @@ impl Server {
if opts.http2 {
let cert_path = opts.http2_tls_cert.clone();
let key_path = opts.http2_tls_key.clone();
tcp_listener
.set_nonblocking(true)
.expect("cannot set non-blocking");
@@ -174,8 +169,8 @@ impl Server {
incoming.set_nodelay(true);
let tls = TlsConfigBuilder::new()
.cert_path(cert_path)
.key_path(key_path)
.cert_path(&opts.http2_tls_cert)
.key_path(&opts.http2_tls_key)
.build()
.with_context(|| {
"failed to initialize TLS, probably wrong cert/key or file missing".to_string()
@@ -11,14 +11,14 @@ mod tests {
#[tokio::test]
async fn allow_methods() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("*", "").unwrap();
let headers = HeaderMap::new();
let methods = &[Method::GET, Method::HEAD, Method::OPTIONS];
for method in methods {
assert!(cors.check_request(method, &headers).is_ok());
}
let cors = cors::new("https://localhost".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("https://localhost", "").unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "GET".parse().unwrap());
@@ -29,7 +29,7 @@ mod tests {
#[test]
fn disallow_methods() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("*", "").unwrap();
let headers = HeaderMap::new();
let methods = [
Method::CONNECT,
@@ -50,7 +50,7 @@ mod tests {
#[tokio::test]
async fn origin_allowed() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("*", "").unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
@@ -67,7 +67,7 @@ mod tests {
#[tokio::test]
async fn origin_not_allowed() {
let cors = cors::new("https://localhost.rs".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("https://localhost.rs", "").unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
@@ -80,7 +80,7 @@ mod tests {
#[tokio::test]
async fn method_allowed() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("*", "").unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "GET".parse().unwrap());
@@ -92,7 +92,7 @@ mod tests {
#[tokio::test]
async fn method_disallowed() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("*", "").unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "POST".parse().unwrap());
@@ -110,7 +110,7 @@ mod tests {
#[tokio::test]
async fn headers_allowed() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("*", "").unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "GET".parse().unwrap());
@@ -127,7 +127,7 @@ mod tests {
#[tokio::test]
async fn headers_invalid() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let cors = cors::new("*", "").unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert(