feat: cors allow headers and proper cors response handling
- new string-list `--cors-allow-headers` option for headers control
- `OPTIONS` requests support:
- to identify server allowed request methods
- for preflighted requests when using CORS
fixes:
- #86
Diff
src/compression.rs | 4 +-
src/config.rs | 9 ++++-
src/cors.rs | 109 +++++++++++++++++++++++++++++++++++++---------
src/handler.rs | 38 ++++++++++++++--
src/server.rs | 5 +-
src/static_files.rs | 16 +++++--
tests/cors.rs | 121 +++++++++++++++++++++++++++++++++++++++++++++------
tests/dir_listing.rs | 1 +-
tests/static_files.rs | 1 +-
9 files changed, 261 insertions(+), 43 deletions(-)
@@ -54,8 +54,8 @@ pub fn auto(
headers: &HeaderMap<HeaderValue>,
resp: Response<Body>,
) -> Result<Response<Body>> {
if method == Method::HEAD {
if method == Method::HEAD || method == Method::OPTIONS {
return Ok(resp);
}
@@ -73,6 +73,15 @@ pub struct Config {
#[structopt(
long,
short = "j",
default_value = "origin, content-type",
env = "SERVER_CORS_ALLOW_HEADERS"
)]
pub cors_allow_headers: String,
#[structopt(
long,
short = "t",
parse(try_from_str),
default_value = "false",
@@ -1,7 +1,10 @@
use headers::{HeaderName, HeaderValue, Origin};
use headers::{
AccessControlAllowHeaders, AccessControlAllowMethods, HeaderMapExt, HeaderName, HeaderValue,
Origin,
};
use http::header;
use std::{collections::HashSet, convert::TryFrom, sync::Arc};
@@ -15,26 +18,47 @@ pub struct Cors {
}
pub fn new(origins_str: String) -> Option<Arc<Configured>> {
pub fn new(origins_str: String, headers_str: String) -> Option<Arc<Configured>> {
let cors = Cors::new();
let cors = if origins_str.is_empty() {
None
} else if origins_str == "*" {
Some(cors.allow_any_origin().allow_methods(vec!["GET", "HEAD"]))
} else {
let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
if hosts.is_empty() {
None
let headers_vec = if headers_str.is_empty() {
vec!["origin", "content-type"]
} else {
headers_str.split(',').map(|s| s.trim()).collect::<Vec<_>>()
};
let headers_str = headers_vec.join(",");
let cors_res = if origins_str == "*" {
Some(
cors.allow_any_origin()
.allow_headers(headers_vec)
.allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
)
} else {
Some(cors.allow_origins(hosts).allow_methods(vec!["GET", "HEAD"]))
let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
if hosts.is_empty() {
None
} else {
Some(
cors.allow_origins(hosts)
.allow_headers(headers_vec)
.allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
)
}
};
if cors_res.is_some() {
tracing::info!(
"enabled=true, allow_headers=[{}], allow_methods=[GET,HEAD,OPTIONS], allow_origins={}",
headers_str,
origins_str
);
}
cors_res
};
if cors.is_some() {
tracing::info!(
"enabled=true, allow_methods=[GET, HEAD], allow_origins={}",
origins_str
);
}
Cors::build(cors)
}
@@ -109,11 +133,39 @@ impl Cors {
self
}
pub fn allow_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator,
HeaderName: TryFrom<I::Item>,
{
let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
Ok(h) => h,
Err(_) => panic!("cors: illegal Header"),
});
self.allowed_headers.extend(iter);
self
}
pub fn build(cors: Option<Cors>) -> Option<Arc<Configured>> {
cors.as_ref()?;
let cors = cors?;
Some(Arc::new(Configured { cors }))
let allowed_headers = cors.allowed_headers.iter().cloned().collect();
let methods_header = cors.allowed_methods.iter().cloned().collect();
Some(Arc::new(Configured {
cors,
allowed_headers,
methods_header,
}))
}
}
@@ -126,6 +178,8 @@ impl Default for Cors {
#[derive(Clone, Debug)]
pub struct Configured {
cors: Cors,
allowed_headers: AccessControlAllowHeaders,
methods_header: AccessControlAllowMethods,
}
#[derive(Debug)]
@@ -153,7 +207,7 @@ impl Configured {
&self,
method: &http::Method,
headers: &http::HeaderMap,
) -> Result<Validated, Forbidden> {
) -> Result<(http::HeaderMap, Validated), Forbidden> {
match (headers.get(header::ORIGIN), method) {
(Some(origin), &http::Method::OPTIONS) => {
@@ -182,21 +236,29 @@ impl Configured {
}
}
Ok(Validated::Preflight(origin.clone()))
let mut headers = http::HeaderMap::new();
self.append_preflight_headers(&mut headers);
headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
Ok((headers, Validated::Preflight(origin.clone())))
}
(Some(origin), _) => {
tracing::trace!("cors origin header: {:?}", origin);
if self.is_origin_allowed(origin) {
Ok(Validated::Simple(origin.clone()))
let mut headers = http::HeaderMap::new();
self.append_preflight_headers(&mut headers);
headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
Ok((headers, Validated::Simple(origin.clone())))
} else {
Err(Forbidden::Origin)
}
}
(None, _) => {
Ok(Validated::NotCors)
Ok((http::HeaderMap::new(), Validated::NotCors))
}
}
}
@@ -220,6 +282,15 @@ impl Configured {
true
}
}
fn append_preflight_headers(&self, headers: &mut http::HeaderMap) {
headers.typed_insert(self.allowed_headers.clone());
headers.typed_insert(self.methods_header.clone());
if let Some(max_age) = self.cors.max_age {
headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
}
}
}
pub trait Seconds {
@@ -1,4 +1,6 @@
use hyper::{header::WWW_AUTHENTICATE, Body, Request, Response, StatusCode};
use headers::{AcceptRanges, HeaderMapExt, HeaderValue};
use http::header::ALLOW;
use hyper::{header::WWW_AUTHENTICATE, Body, Method, Request, Response, StatusCode};
use std::{future::Future, path::PathBuf, sync::Arc};
use crate::{
@@ -41,13 +43,26 @@ impl RequestHandler {
let dir_listing = self.opts.dir_listing;
let dir_listing_order = self.opts.dir_listing_order;
let mut cors_headers: Option<http::HeaderMap> = None;
async move {
if !(method == Method::GET || method == Method::HEAD || method == Method::OPTIONS) {
return error_page::error_response(
method,
&StatusCode::METHOD_NOT_ALLOWED,
self.opts.page404.as_ref(),
self.opts.page50x.as_ref(),
);
}
if self.opts.cors.is_some() {
let cors = self.opts.cors.as_ref().unwrap();
match cors.check_request(method, headers) {
Ok(res) => {
tracing::debug!("cors ok: {:?}", res);
Ok((headers, state)) => {
tracing::debug!("cors state: {:?}", state);
cors_headers = Some(headers);
}
Err(err) => {
tracing::error!("cors error kind: {:?}", err);
@@ -104,6 +119,13 @@ impl RequestHandler {
.await
{
Ok(mut resp) => {
if let Some(cors_headers) = cors_headers {
for (k, v) in cors_headers.iter() {
resp.headers_mut().insert(k, v.to_owned());
}
}
if self.opts.compression {
resp = match compression::auto(method, headers, resp) {
@@ -130,6 +152,16 @@ impl RequestHandler {
security_headers::append_headers(&mut resp);
}
if method == Method::OPTIONS {
*resp.status_mut() = StatusCode::NO_CONTENT;
*resp.body_mut() = Body::empty();
resp.headers_mut()
.insert(ALLOW, HeaderValue::from_static("OPTIONS, GET, HEAD"));
resp.headers_mut().typed_insert(AcceptRanges::bytes());
return Ok(resp);
}
Ok(resp)
}
Err(status) => error_page::error_response(
@@ -121,7 +121,10 @@ impl Server {
tracing::info!("cache control headers: enabled={}", cache_control_headers);
let cors = cors::new(opts.cors_allow_origins.trim().to_owned());
let cors = cors::new(
opts.cors_allow_origins.trim().to_owned(),
opts.cors_allow_headers.trim().to_owned(),
);
let basic_auth = opts.basic_auth.trim();
@@ -8,7 +8,7 @@ use headers::{
AcceptRanges, ContentLength, ContentRange, ContentType, HeaderMap, HeaderMapExt, HeaderValue,
IfModifiedSince, IfRange, IfUnmodifiedSince, LastModified, Range,
};
use http::header::CONTENT_TYPE;
use http::header::{ALLOW, CONTENT_TYPE};
use humansize::{file_size_opts, FileSize};
use hyper::{Body, Method, Response, StatusCode};
use percent_encoding::percent_decode_str;
@@ -50,11 +50,21 @@ pub async fn handle(
dir_listing: bool,
dir_listing_order: u8,
) -> Result<Response<Body>, StatusCode> {
if !(method == Method::HEAD || method == Method::GET) {
if !(method == Method::GET || method == Method::HEAD || method == Method::OPTIONS) {
return Err(StatusCode::METHOD_NOT_ALLOWED);
}
if method == Method::OPTIONS {
let mut resp = Response::new(Body::empty());
*resp.status_mut() = StatusCode::NO_CONTENT;
resp.headers_mut()
.insert(ALLOW, HeaderValue::from_static("OPTIONS, GET, HEAD"));
resp.headers_mut().typed_insert(AcceptRanges::bytes());
return Ok(resp);
}
let base = Arc::new(base_path.into());
let (filepath, meta, auto_index) = path_from_tail(base, uri_path).await?;
@@ -11,29 +11,29 @@ mod tests {
#[tokio::test]
async fn allow_methods() {
let cors = cors::new("*".to_owned()).unwrap();
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let headers = HeaderMap::new();
let methods = &[Method::GET, Method::HEAD];
let methods = &[Method::GET, Method::HEAD, Method::OPTIONS];
for method in methods {
assert!(cors.check_request(method, &headers).is_ok())
assert!(cors.check_request(method, &headers).is_ok());
}
let cors = cors::new("https://localhost".to_owned()).unwrap();
let cors = cors::new("https://localhost".to_owned(), "".to_owned()).unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "GET".parse().unwrap());
for method in methods {
assert!(cors.check_request(method, &headers).is_ok())
assert!(cors.check_request(method, &headers).is_ok());
}
}
#[test]
fn disallow_methods() {
let cors = cors::new("*".to_owned()).unwrap();
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let headers = HeaderMap::new();
let methods = [
Method::CONNECT,
Method::DELETE,
Method::OPTIONS,
Method::PATCH,
Method::POST,
Method::PUT,
@@ -42,31 +42,126 @@ mod tests {
for method in methods {
let res = cors.check_request(&method, &headers);
assert!(res.is_ok());
assert!(matches!(res.unwrap(), cors::Validated::NotCors));
let res = res.unwrap();
assert!(res.0.is_empty());
assert!(matches!(res.1, cors::Validated::NotCors));
}
}
#[tokio::test]
async fn origin_allowed() {
let cors = cors::new("*".to_owned()).unwrap();
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
let methods = [Method::GET, Method::HEAD];
let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
for method in methods {
assert!(cors.check_request(&method, &headers).is_ok())
let res = cors.check_request(&method, &headers);
if method == Method::OPTIONS {
assert!(res.is_err())
} else {
assert!(res.is_ok())
}
}
}
#[tokio::test]
async fn origin_not_allowed() {
let cors = cors::new("https://localhost.rs".to_owned()).unwrap();
let cors = cors::new("https://localhost.rs".to_owned(), "".to_owned()).unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
let methods = [Method::GET, Method::HEAD];
let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
for method in methods {
let res = cors.check_request(&method, &headers);
assert!(res.is_err());
assert!(matches!(res.unwrap_err(), cors::Forbidden::Origin))
}
}
#[tokio::test]
async fn method_allowed() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "GET".parse().unwrap());
let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
for method in methods {
assert!(cors.check_request(&method, &headers).is_ok())
}
}
#[tokio::test]
async fn method_disallowed() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "POST".parse().unwrap());
let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
for method in methods {
let res = cors.check_request(&method, &headers);
if method == Method::OPTIONS {
assert!(res.is_err())
} else {
assert!(res.is_ok())
}
}
}
#[tokio::test]
async fn headers_allowed() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "GET".parse().unwrap());
headers.insert(
"access-control-request-headers",
"origin,content-type".parse().unwrap(),
);
let methods = [Method::OPTIONS];
for method in methods {
let res = cors.check_request(&method, &headers);
assert!(res.is_ok())
}
}
#[tokio::test]
async fn headers_invalid() {
let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert(
"access-control-request-method",
"GET,HEAD,OPTIONS".parse().unwrap(),
);
headers.insert(
"access-control-request-headers",
"origin, content-type".parse().unwrap(),
);
let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
for method in &methods {
let res = cors.check_request(method, &headers);
if method == Method::OPTIONS {
assert!(res.is_err())
} else {
assert!(res.is_ok())
}
}
let mut headers = HeaderMap::new();
headers.insert("origin", "https://localhost".parse().unwrap());
headers.insert("access-control-request-method", "GET".parse().unwrap());
headers.insert(
"access-control-request-headers",
"origin,authorization".parse().unwrap(),
);
for method in methods {
let res = cors.check_request(&method, &headers);
if method == Method::OPTIONS {
assert!(res.is_err())
} else {
assert!(res.is_ok())
}
}
}
}
@@ -22,7 +22,6 @@ mod tests {
Method::DELETE,
Method::GET,
Method::HEAD,
Method::OPTIONS,
Method::PATCH,
Method::POST,
Method::PUT,
@@ -388,7 +388,6 @@ mod tests {
Method::DELETE,
Method::GET,
Method::HEAD,
Method::OPTIONS,
Method::PATCH,
Method::POST,
Method::PUT,