From 624670f1e8a96f4a217366dbfc86021ab71f7e88 Mon Sep 17 00:00:00 2001 From: Jose Quintana <1700322+joseluisq@users.noreply.github.com> Date: Wed, 2 Mar 2022 23:49:50 +0100 Subject: [PATCH] Merge pull request #87 from joseluisq/feature/cors_allow_headers_and_proper_handling feat: CORS `access-control-request-headers` support and `OPTIONS` requests --- src/compression.rs | 4 ++-- src/config.rs | 9 +++++++++ src/cors.rs | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------- src/handler.rs | 29 ++++++++++++++++++++++++++--- src/server.rs | 5 ++++- src/static_files.rs | 18 ++++++++++++++++-- tests/cors.rs | 121 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------- tests/dir_listing.rs | 1 - tests/static_files.rs | 1 - 9 files changed, 255 insertions(+), 42 deletions(-) diff --git a/src/compression.rs b/src/compression.rs index 54a9a58..9ddd677 100644 --- a/src/compression.rs +++ b/src/compression.rs @@ -54,8 +54,8 @@ pub fn auto( headers: &HeaderMap, resp: Response, ) -> Result> { - // Skip compression for HEAD request methods - if method == Method::HEAD { + // Skip compression for HEAD and OPTIONS request methods + if method == Method::HEAD || method == Method::OPTIONS { return Ok(resp); } diff --git a/src/config.rs b/src/config.rs index a6bd12c..5126d9b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -73,6 +73,15 @@ pub struct Config { #[structopt( long, + short = "j", + default_value = "origin, content-type", + env = "SERVER_CORS_ALLOW_HEADERS" + )] + /// Specify an optional CORS list of allowed headers separated by comas. Default "origin, content-type". It requires `--cors-allow-origins` to be used along with. + pub cors_allow_headers: String, + + #[structopt( + long, short = "t", parse(try_from_str), default_value = "false", diff --git a/src/cors.rs b/src/cors.rs index 5bc48fa..6d53e53 100644 --- a/src/cors.rs +++ b/src/cors.rs @@ -1,7 +1,10 @@ // CORS handler for incoming requests. // -> Part of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/filters/cors.rs -use headers::{HeaderName, HeaderValue, Origin}; +use headers::{ + AccessControlAllowHeaders, AccessControlAllowMethods, HeaderMapExt, HeaderName, HeaderValue, + Origin, +}; use http::header; use std::{collections::HashSet, convert::TryFrom, sync::Arc}; @@ -15,26 +18,47 @@ pub struct Cors { } /// It builds a new CORS instance. -pub fn new(origins_str: String) -> Option> { +pub fn new(origins_str: String, headers_str: String) -> Option> { let cors = Cors::new(); let cors = if origins_str.is_empty() { None - } else if origins_str == "*" { - Some(cors.allow_any_origin().allow_methods(vec!["GET", "HEAD"])) } else { - let hosts = origins_str.split(',').map(|s| s.trim()).collect::>(); - if hosts.is_empty() { - None + let headers_vec = if headers_str.is_empty() { + vec!["origin", "content-type"] + } else { + headers_str.split(',').map(|s| s.trim()).collect::>() + }; + let headers_str = headers_vec.join(","); + + let cors_res = if origins_str == "*" { + Some( + cors.allow_any_origin() + .allow_headers(headers_vec) + .allow_methods(vec!["GET", "HEAD", "OPTIONS"]), + ) } else { - Some(cors.allow_origins(hosts).allow_methods(vec!["GET", "HEAD"])) + let hosts = origins_str.split(',').map(|s| s.trim()).collect::>(); + if hosts.is_empty() { + None + } else { + Some( + cors.allow_origins(hosts) + .allow_headers(headers_vec) + .allow_methods(vec!["GET", "HEAD", "OPTIONS"]), + ) + } + }; + + if cors_res.is_some() { + tracing::info!( + "enabled=true, allow_methods=[GET,HEAD,OPTIONS], allow_origins={}, allow_headers=[{}]", + origins_str, + headers_str + ); } + cors_res }; - if cors.is_some() { - tracing::info!( - "enabled=true, allow_methods=[GET, HEAD], allow_origins={}", - origins_str - ); - } + Cors::build(cors) } @@ -109,11 +133,39 @@ impl Cors { self } + /// Adds multiple headers to the list of allowed request headers. + /// + /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`. + /// + /// # Panics + /// + /// Panics if any of the headers are not a valid `http::header::HeaderName`. + pub fn allow_headers(mut self, headers: I) -> Self + where + I: IntoIterator, + HeaderName: TryFrom, + { + let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) { + Ok(h) => h, + Err(_) => panic!("cors: illegal Header"), + }); + self.allowed_headers.extend(iter); + self + } + /// Builds the `Cors` wrapper from the configured settings. pub fn build(cors: Option) -> Option> { cors.as_ref()?; let cors = cors?; - Some(Arc::new(Configured { cors })) + + let allowed_headers = cors.allowed_headers.iter().cloned().collect(); + let methods_header = cors.allowed_methods.iter().cloned().collect(); + + Some(Arc::new(Configured { + cors, + allowed_headers, + methods_header, + })) } } @@ -126,6 +178,8 @@ impl Default for Cors { #[derive(Clone, Debug)] pub struct Configured { cors: Cors, + allowed_headers: AccessControlAllowHeaders, + methods_header: AccessControlAllowMethods, } #[derive(Debug)] @@ -153,7 +207,7 @@ impl Configured { &self, method: &http::Method, headers: &http::HeaderMap, - ) -> Result { + ) -> Result<(http::HeaderMap, Validated), Forbidden> { match (headers.get(header::ORIGIN), method) { (Some(origin), &http::Method::OPTIONS) => { // OPTIONS requests are preflight CORS requests... @@ -182,21 +236,29 @@ impl Configured { } } - Ok(Validated::Preflight(origin.clone())) + let mut headers = http::HeaderMap::new(); + self.append_preflight_headers(&mut headers); + headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into()); + + Ok((headers, Validated::Preflight(origin.clone()))) } (Some(origin), _) => { // Any other method, simply check for a valid origin... tracing::trace!("cors origin header: {:?}", origin); if self.is_origin_allowed(origin) { - Ok(Validated::Simple(origin.clone())) + let mut headers = http::HeaderMap::new(); + self.append_preflight_headers(&mut headers); + headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into()); + + Ok((headers, Validated::Simple(origin.clone()))) } else { Err(Forbidden::Origin) } } (None, _) => { // No `ORIGIN` header means this isn't CORS! - Ok(Validated::NotCors) + Ok((http::HeaderMap::new(), Validated::NotCors)) } } } @@ -220,6 +282,15 @@ impl Configured { true } } + + fn append_preflight_headers(&self, headers: &mut http::HeaderMap) { + headers.typed_insert(self.allowed_headers.clone()); + headers.typed_insert(self.methods_header.clone()); + + if let Some(max_age) = self.cors.max_age { + headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into()); + } + } } pub trait Seconds { diff --git a/src/handler.rs b/src/handler.rs index f1cb3d0..4df4158 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,4 +1,4 @@ -use hyper::{header::WWW_AUTHENTICATE, Body, Request, Response, StatusCode}; +use hyper::{header::WWW_AUTHENTICATE, Body, Method, Request, Response, StatusCode}; use std::{future::Future, path::PathBuf, sync::Arc}; use crate::{ @@ -41,13 +41,26 @@ impl RequestHandler { let dir_listing = self.opts.dir_listing; let dir_listing_order = self.opts.dir_listing_order; + let mut cors_headers: Option = None; + async move { + // Check for disallowed HTTP methods and reject request accordently + if !(method == Method::GET || method == Method::HEAD || method == Method::OPTIONS) { + return error_page::error_response( + method, + &StatusCode::METHOD_NOT_ALLOWED, + self.opts.page404.as_ref(), + self.opts.page50x.as_ref(), + ); + } + // CORS if self.opts.cors.is_some() { let cors = self.opts.cors.as_ref().unwrap(); match cors.check_request(method, headers) { - Ok(res) => { - tracing::debug!("cors ok: {:?}", res); + Ok((headers, state)) => { + tracing::debug!("cors state: {:?}", state); + cors_headers = Some(headers); } Err(err) => { tracing::error!("cors error kind: {:?}", err); @@ -104,6 +117,16 @@ impl RequestHandler { .await { Ok(mut resp) => { + // Append CORS headers if they are present + if let Some(cors_headers) = cors_headers { + if !cors_headers.is_empty() { + for (k, v) in cors_headers.iter() { + resp.headers_mut().insert(k, v.to_owned()); + } + resp.headers_mut().remove(http::header::ALLOW); + } + } + // Auto compression based on the `Accept-Encoding` header if self.opts.compression { resp = match compression::auto(method, headers, resp) { diff --git a/src/server.rs b/src/server.rs index c49a5fc..8a135a6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -121,7 +121,10 @@ impl Server { tracing::info!("cache control headers: enabled={}", cache_control_headers); // CORS option - let cors = cors::new(opts.cors_allow_origins.trim().to_owned()); + let cors = cors::new( + opts.cors_allow_origins.trim().to_owned(), + opts.cors_allow_headers.trim().to_owned(), + ); // `Basic` HTTP Authentication Schema option let basic_auth = opts.basic_auth.trim(); diff --git a/src/static_files.rs b/src/static_files.rs index 2433187..ef4d3d9 100644 --- a/src/static_files.rs +++ b/src/static_files.rs @@ -50,8 +50,8 @@ pub async fn handle( dir_listing: bool, dir_listing_order: u8, ) -> Result, StatusCode> { - // Reject requests for non HEAD or GET methods - if !(method == Method::HEAD || method == Method::GET) { + // Check for disallowed HTTP methods and reject request accordently + if !(method == Method::GET || method == Method::HEAD || method == Method::OPTIONS) { return Err(StatusCode::METHOD_NOT_ALLOWED); } @@ -79,6 +79,20 @@ pub async fn handle( return Ok(resp); } + // Respond the permitted communication options + if method == Method::OPTIONS { + let mut resp = Response::new(Body::empty()); + *resp.status_mut() = StatusCode::NO_CONTENT; + resp.headers_mut() + .typed_insert(headers::Allow::from_iter(vec![ + Method::OPTIONS, + Method::HEAD, + Method::GET, + ])); + resp.headers_mut().typed_insert(AcceptRanges::bytes()); + return Ok(resp); + } + // Directory listing // 1. Check if "directory listing" feature is enabled // if current path is a valid directory and diff --git a/tests/cors.rs b/tests/cors.rs index 0917a03..0f4db64 100644 --- a/tests/cors.rs +++ b/tests/cors.rs @@ -11,29 +11,29 @@ mod tests { #[tokio::test] async fn allow_methods() { - let cors = cors::new("*".to_owned()).unwrap(); + let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); let headers = HeaderMap::new(); - let methods = &[Method::GET, Method::HEAD]; + let methods = &[Method::GET, Method::HEAD, Method::OPTIONS]; for method in methods { - assert!(cors.check_request(method, &headers).is_ok()) + assert!(cors.check_request(method, &headers).is_ok()); } - let cors = cors::new("https://localhost".to_owned()).unwrap(); + let cors = cors::new("https://localhost".to_owned(), "".to_owned()).unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); + headers.insert("access-control-request-method", "GET".parse().unwrap()); for method in methods { - assert!(cors.check_request(method, &headers).is_ok()) + assert!(cors.check_request(method, &headers).is_ok()); } } #[test] fn disallow_methods() { - let cors = cors::new("*".to_owned()).unwrap(); + let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); let headers = HeaderMap::new(); let methods = [ Method::CONNECT, Method::DELETE, - Method::OPTIONS, Method::PATCH, Method::POST, Method::PUT, @@ -42,31 +42,126 @@ mod tests { for method in methods { let res = cors.check_request(&method, &headers); assert!(res.is_ok()); - assert!(matches!(res.unwrap(), cors::Validated::NotCors)); + let res = res.unwrap(); + assert!(res.0.is_empty()); + assert!(matches!(res.1, cors::Validated::NotCors)); } } #[tokio::test] async fn origin_allowed() { - let cors = cors::new("*".to_owned()).unwrap(); + let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); - let methods = [Method::GET, Method::HEAD]; + let methods = [Method::GET, Method::HEAD, Method::OPTIONS]; for method in methods { - assert!(cors.check_request(&method, &headers).is_ok()) + let res = cors.check_request(&method, &headers); + if method == Method::OPTIONS { + // Forbidden (403) - preflight request missing access-control-request-method header + assert!(res.is_err()) + } else { + assert!(res.is_ok()) + } } } #[tokio::test] async fn origin_not_allowed() { - let cors = cors::new("https://localhost.rs".to_owned()).unwrap(); + let cors = cors::new("https://localhost.rs".to_owned(), "".to_owned()).unwrap(); let mut headers = HeaderMap::new(); headers.insert("origin", "https://localhost".parse().unwrap()); - let methods = [Method::GET, Method::HEAD]; + let methods = [Method::GET, Method::HEAD, Method::OPTIONS]; for method in methods { let res = cors.check_request(&method, &headers); assert!(res.is_err()); assert!(matches!(res.unwrap_err(), cors::Forbidden::Origin)) } } + + #[tokio::test] + async fn method_allowed() { + let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("origin", "https://localhost".parse().unwrap()); + headers.insert("access-control-request-method", "GET".parse().unwrap()); + let methods = [Method::GET, Method::HEAD, Method::OPTIONS]; + for method in methods { + assert!(cors.check_request(&method, &headers).is_ok()) + } + } + + #[tokio::test] + async fn method_disallowed() { + let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("origin", "https://localhost".parse().unwrap()); + headers.insert("access-control-request-method", "POST".parse().unwrap()); + let methods = [Method::GET, Method::HEAD, Method::OPTIONS]; + for method in methods { + let res = cors.check_request(&method, &headers); + if method == Method::OPTIONS { + // Forbidden (403) - preflight request missing access-control-request-method header + assert!(res.is_err()) + } else { + assert!(res.is_ok()) + } + } + } + + #[tokio::test] + async fn headers_allowed() { + let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("origin", "https://localhost".parse().unwrap()); + headers.insert("access-control-request-method", "GET".parse().unwrap()); + headers.insert( + "access-control-request-headers", + "origin,content-type".parse().unwrap(), + ); + let methods = [Method::OPTIONS]; + for method in methods { + let res = cors.check_request(&method, &headers); + assert!(res.is_ok()) + } + } + + #[tokio::test] + async fn headers_invalid() { + let cors = cors::new("*".to_owned(), "".to_owned()).unwrap(); + let mut headers = HeaderMap::new(); + headers.insert("origin", "https://localhost".parse().unwrap()); + headers.insert( + "access-control-request-method", + "GET,HEAD,OPTIONS".parse().unwrap(), + ); + headers.insert( + "access-control-request-headers", + "origin, content-type".parse().unwrap(), + ); + let methods = [Method::GET, Method::HEAD, Method::OPTIONS]; + for method in &methods { + let res = cors.check_request(method, &headers); + if method == Method::OPTIONS { + assert!(res.is_err()) + } else { + assert!(res.is_ok()) + } + } + + let mut headers = HeaderMap::new(); + headers.insert("origin", "https://localhost".parse().unwrap()); + headers.insert("access-control-request-method", "GET".parse().unwrap()); + headers.insert( + "access-control-request-headers", + "origin,authorization".parse().unwrap(), + ); + for method in methods { + let res = cors.check_request(&method, &headers); + if method == Method::OPTIONS { + assert!(res.is_err()) + } else { + assert!(res.is_ok()) + } + } + } } diff --git a/tests/dir_listing.rs b/tests/dir_listing.rs index 8000918..23cdd53 100644 --- a/tests/dir_listing.rs +++ b/tests/dir_listing.rs @@ -22,7 +22,6 @@ mod tests { Method::DELETE, Method::GET, Method::HEAD, - Method::OPTIONS, Method::PATCH, Method::POST, Method::PUT, diff --git a/tests/static_files.rs b/tests/static_files.rs index 5b1485a..44e1ad9 100644 --- a/tests/static_files.rs +++ b/tests/static_files.rs @@ -388,7 +388,6 @@ mod tests { Method::DELETE, Method::GET, Method::HEAD, - Method::OPTIONS, Method::PATCH, Method::POST, Method::PUT, -- libgit2 1.7.2