index : static-web-server.git

ascending towards madness

author Jose Quintana <joseluisquintana20@gmail.com> 2022-03-02 2:27:51.0 +00:00:00
committer Jose Quintana <joseluisquintana20@gmail.com> 2022-03-02 2:27:51.0 +00:00:00
commit
da85b16c396347738b7a1d4ea907eb015e1dbf62 [patch]
tree
a0e2171abdb292f5d76fa69c8456f707bcadd79f
parent
8fe6a67a785bda4cb6b8c3b121c18bca0ca780d7
download
da85b16c396347738b7a1d4ea907eb015e1dbf62.tar.gz

feat: cors allow headers and proper cors response handling

- new string-list `--cors-allow-headers` option for headers control
- `OPTIONS` requests support:
  - to identify server allowed request methods
  - for preflighted requests when using CORS

fixes:

- #86

Diff

 src/compression.rs    |   4 +-
 src/config.rs         |   9 ++++-
 src/cors.rs           | 109 +++++++++++++++++++++++++++++++++++++---------
 src/handler.rs        |  38 ++++++++++++++--
 src/server.rs         |   5 +-
 src/static_files.rs   |  16 +++++--
 tests/cors.rs         | 121 +++++++++++++++++++++++++++++++++++++++++++++------
 tests/dir_listing.rs  |   1 +-
 tests/static_files.rs |   1 +-
 9 files changed, 261 insertions(+), 43 deletions(-)

diff --git a/src/compression.rs b/src/compression.rs
index 54a9a58..9ddd677 100644
--- a/src/compression.rs
+++ b/src/compression.rs
@@ -54,8 +54,8 @@ pub fn auto(
    headers: &HeaderMap<HeaderValue>,
    resp: Response<Body>,
) -> Result<Response<Body>> {
    // Skip compression for HEAD request methods
    if method == Method::HEAD {
    // Skip compression for HEAD and OPTIONS request methods
    if method == Method::HEAD || method == Method::OPTIONS {
        return Ok(resp);
    }

diff --git a/src/config.rs b/src/config.rs
index a6bd12c..5126d9b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -73,6 +73,15 @@ pub struct Config {

    #[structopt(
        long,
        short = "j",
        default_value = "origin, content-type",
        env = "SERVER_CORS_ALLOW_HEADERS"
    )]
    /// Specify an optional CORS list of allowed headers separated by comas. Default "origin, content-type". It requires `--cors-allow-origins` to be used along with.
    pub cors_allow_headers: String,

    #[structopt(
        long,
        short = "t",
        parse(try_from_str),
        default_value = "false",
diff --git a/src/cors.rs b/src/cors.rs
index 5bc48fa..89c61d1 100644
--- a/src/cors.rs
+++ b/src/cors.rs
@@ -1,7 +1,10 @@
// CORS handler for incoming requests.
// -> Part of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/filters/cors.rs

use headers::{HeaderName, HeaderValue, Origin};
use headers::{
    AccessControlAllowHeaders, AccessControlAllowMethods, HeaderMapExt, HeaderName, HeaderValue,
    Origin,
};
use http::header;
use std::{collections::HashSet, convert::TryFrom, sync::Arc};

@@ -15,26 +18,47 @@ pub struct Cors {
}

/// It builds a new CORS instance.
pub fn new(origins_str: String) -> Option<Arc<Configured>> {
pub fn new(origins_str: String, headers_str: String) -> Option<Arc<Configured>> {
    let cors = Cors::new();
    let cors = if origins_str.is_empty() {
        None
    } else if origins_str == "*" {
        Some(cors.allow_any_origin().allow_methods(vec!["GET", "HEAD"]))
    } else {
        let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
        if hosts.is_empty() {
            None
        let headers_vec = if headers_str.is_empty() {
            vec!["origin", "content-type"]
        } else {
            headers_str.split(',').map(|s| s.trim()).collect::<Vec<_>>()
        };
        let headers_str = headers_vec.join(",");

        let cors_res = if origins_str == "*" {
            Some(
                cors.allow_any_origin()
                    .allow_headers(headers_vec)
                    .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
            )
        } else {
            Some(cors.allow_origins(hosts).allow_methods(vec!["GET", "HEAD"]))
            let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
            if hosts.is_empty() {
                None
            } else {
                Some(
                    cors.allow_origins(hosts)
                        .allow_headers(headers_vec)
                        .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
                )
            }
        };

        if cors_res.is_some() {
            tracing::info!(
                    "enabled=true, allow_headers=[{}], allow_methods=[GET,HEAD,OPTIONS], allow_origins={}",
                    headers_str,
                    origins_str
                );
        }
        cors_res
    };
    if cors.is_some() {
        tracing::info!(
            "enabled=true, allow_methods=[GET, HEAD], allow_origins={}",
            origins_str
        );
    }

    Cors::build(cors)
}

@@ -109,11 +133,39 @@ impl Cors {
        self
    }

    /// Adds multiple headers to the list of allowed request headers.
    ///
    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
    ///
    /// # Panics
    ///
    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
    pub fn allow_headers<I>(mut self, headers: I) -> Self
    where
        I: IntoIterator,
        HeaderName: TryFrom<I::Item>,
    {
        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
            Ok(h) => h,
            Err(_) => panic!("cors: illegal Header"),
        });
        self.allowed_headers.extend(iter);
        self
    }

    /// Builds the `Cors` wrapper from the configured settings.
    pub fn build(cors: Option<Cors>) -> Option<Arc<Configured>> {
        cors.as_ref()?;
        let cors = cors?;
        Some(Arc::new(Configured { cors }))

        let allowed_headers = cors.allowed_headers.iter().cloned().collect();
        let methods_header = cors.allowed_methods.iter().cloned().collect();

        Some(Arc::new(Configured {
            cors,
            allowed_headers,
            methods_header,
        }))
    }
}

@@ -126,6 +178,8 @@ impl Default for Cors {
#[derive(Clone, Debug)]
pub struct Configured {
    cors: Cors,
    allowed_headers: AccessControlAllowHeaders,
    methods_header: AccessControlAllowMethods,
}

#[derive(Debug)]
@@ -153,7 +207,7 @@ impl Configured {
        &self,
        method: &http::Method,
        headers: &http::HeaderMap,
    ) -> Result<Validated, Forbidden> {
    ) -> Result<(http::HeaderMap, Validated), Forbidden> {
        match (headers.get(header::ORIGIN), method) {
            (Some(origin), &http::Method::OPTIONS) => {
                // OPTIONS requests are preflight CORS requests...
@@ -182,21 +236,29 @@ impl Configured {
                    }
                }

                Ok(Validated::Preflight(origin.clone()))
                let mut headers = http::HeaderMap::new();
                self.append_preflight_headers(&mut headers);
                headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());

                Ok((headers, Validated::Preflight(origin.clone())))
            }
            (Some(origin), _) => {
                // Any other method, simply check for a valid origin...
                tracing::trace!("cors origin header: {:?}", origin);

                if self.is_origin_allowed(origin) {
                    Ok(Validated::Simple(origin.clone()))
                    let mut headers = http::HeaderMap::new();
                    self.append_preflight_headers(&mut headers);
                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());

                    Ok((headers, Validated::Simple(origin.clone())))
                } else {
                    Err(Forbidden::Origin)
                }
            }
            (None, _) => {
                // No `ORIGIN` header means this isn't CORS!
                Ok(Validated::NotCors)
                Ok((http::HeaderMap::new(), Validated::NotCors))
            }
        }
    }
@@ -220,6 +282,15 @@ impl Configured {
            true
        }
    }

    fn append_preflight_headers(&self, headers: &mut http::HeaderMap) {
        headers.typed_insert(self.allowed_headers.clone());
        headers.typed_insert(self.methods_header.clone());

        if let Some(max_age) = self.cors.max_age {
            headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
        }
    }
}

pub trait Seconds {
diff --git a/src/handler.rs b/src/handler.rs
index f1cb3d0..6d3262d 100644
--- a/src/handler.rs
+++ b/src/handler.rs
@@ -1,4 +1,6 @@
use hyper::{header::WWW_AUTHENTICATE, Body, Request, Response, StatusCode};
use headers::{AcceptRanges, HeaderMapExt, HeaderValue};
use http::header::ALLOW;
use hyper::{header::WWW_AUTHENTICATE, Body, Method, Request, Response, StatusCode};
use std::{future::Future, path::PathBuf, sync::Arc};

use crate::{
@@ -41,13 +43,26 @@ impl RequestHandler {
        let dir_listing = self.opts.dir_listing;
        let dir_listing_order = self.opts.dir_listing_order;

        let mut cors_headers: Option<http::HeaderMap> = None;

        async move {
            // Check for disallowed HTTP methods and reject request accordently
            if !(method == Method::GET || method == Method::HEAD || method == Method::OPTIONS) {
                return error_page::error_response(
                    method,
                    &StatusCode::METHOD_NOT_ALLOWED,
                    self.opts.page404.as_ref(),
                    self.opts.page50x.as_ref(),
                );
            }

            // CORS
            if self.opts.cors.is_some() {
                let cors = self.opts.cors.as_ref().unwrap();
                match cors.check_request(method, headers) {
                    Ok(res) => {
                        tracing::debug!("cors ok: {:?}", res);
                    Ok((headers, state)) => {
                        tracing::debug!("cors state: {:?}", state);
                        cors_headers = Some(headers);
                    }
                    Err(err) => {
                        tracing::error!("cors error kind: {:?}", err);
@@ -104,6 +119,13 @@ impl RequestHandler {
            .await
            {
                Ok(mut resp) => {
                    // Append CORS headers if they are present
                    if let Some(cors_headers) = cors_headers {
                        for (k, v) in cors_headers.iter() {
                            resp.headers_mut().insert(k, v.to_owned());
                        }
                    }

                    // Auto compression based on the `Accept-Encoding` header
                    if self.opts.compression {
                        resp = match compression::auto(method, headers, resp) {
@@ -130,6 +152,16 @@ impl RequestHandler {
                        security_headers::append_headers(&mut resp);
                    }

                    // Respond with the permitted communication options
                    if method == Method::OPTIONS {
                        *resp.status_mut() = StatusCode::NO_CONTENT;
                        *resp.body_mut() = Body::empty();
                        resp.headers_mut()
                            .insert(ALLOW, HeaderValue::from_static("OPTIONS, GET, HEAD"));
                        resp.headers_mut().typed_insert(AcceptRanges::bytes());
                        return Ok(resp);
                    }

                    Ok(resp)
                }
                Err(status) => error_page::error_response(
diff --git a/src/server.rs b/src/server.rs
index c49a5fc..8a135a6 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -121,7 +121,10 @@ impl Server {
        tracing::info!("cache control headers: enabled={}", cache_control_headers);

        // CORS option
        let cors = cors::new(opts.cors_allow_origins.trim().to_owned());
        let cors = cors::new(
            opts.cors_allow_origins.trim().to_owned(),
            opts.cors_allow_headers.trim().to_owned(),
        );

        // `Basic` HTTP Authentication Schema option
        let basic_auth = opts.basic_auth.trim();
diff --git a/src/static_files.rs b/src/static_files.rs
index 2433187..54b1cf9 100644
--- a/src/static_files.rs
+++ b/src/static_files.rs
@@ -8,7 +8,7 @@ use headers::{
    AcceptRanges, ContentLength, ContentRange, ContentType, HeaderMap, HeaderMapExt, HeaderValue,
    IfModifiedSince, IfRange, IfUnmodifiedSince, LastModified, Range,
};
use http::header::CONTENT_TYPE;
use http::header::{ALLOW, CONTENT_TYPE};
use humansize::{file_size_opts, FileSize};
use hyper::{Body, Method, Response, StatusCode};
use percent_encoding::percent_decode_str;
@@ -50,11 +50,21 @@ pub async fn handle(
    dir_listing: bool,
    dir_listing_order: u8,
) -> Result<Response<Body>, StatusCode> {
    // Reject requests for non HEAD or GET methods
    if !(method == Method::HEAD || method == Method::GET) {
    // Check for disallowed HTTP methods and reject request accordently
    if !(method == Method::GET || method == Method::HEAD || method == Method::OPTIONS) {
        return Err(StatusCode::METHOD_NOT_ALLOWED);
    }

    // Respond the permitted communication options
    if method == Method::OPTIONS {
        let mut resp = Response::new(Body::empty());
        *resp.status_mut() = StatusCode::NO_CONTENT;
        resp.headers_mut()
            .insert(ALLOW, HeaderValue::from_static("OPTIONS, GET, HEAD"));
        resp.headers_mut().typed_insert(AcceptRanges::bytes());
        return Ok(resp);
    }

    let base = Arc::new(base_path.into());
    let (filepath, meta, auto_index) = path_from_tail(base, uri_path).await?;

diff --git a/tests/cors.rs b/tests/cors.rs
index 0917a03..0f4db64 100644
--- a/tests/cors.rs
+++ b/tests/cors.rs
@@ -11,29 +11,29 @@ mod tests {

    #[tokio::test]
    async fn allow_methods() {
        let cors = cors::new("*".to_owned()).unwrap();
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let headers = HeaderMap::new();
        let methods = &[Method::GET, Method::HEAD];
        let methods = &[Method::GET, Method::HEAD, Method::OPTIONS];
        for method in methods {
            assert!(cors.check_request(method, &headers).is_ok())
            assert!(cors.check_request(method, &headers).is_ok());
        }

        let cors = cors::new("https://localhost".to_owned()).unwrap();
        let cors = cors::new("https://localhost".to_owned(), "".to_owned()).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "GET".parse().unwrap());
        for method in methods {
            assert!(cors.check_request(method, &headers).is_ok())
            assert!(cors.check_request(method, &headers).is_ok());
        }
    }

    #[test]
    fn disallow_methods() {
        let cors = cors::new("*".to_owned()).unwrap();
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let headers = HeaderMap::new();
        let methods = [
            Method::CONNECT,
            Method::DELETE,
            Method::OPTIONS,
            Method::PATCH,
            Method::POST,
            Method::PUT,
@@ -42,31 +42,126 @@ mod tests {
        for method in methods {
            let res = cors.check_request(&method, &headers);
            assert!(res.is_ok());
            assert!(matches!(res.unwrap(), cors::Validated::NotCors));
            let res = res.unwrap();
            assert!(res.0.is_empty());
            assert!(matches!(res.1, cors::Validated::NotCors));
        }
    }

    #[tokio::test]
    async fn origin_allowed() {
        let cors = cors::new("*".to_owned()).unwrap();
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        let methods = [Method::GET, Method::HEAD];
        let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
        for method in methods {
            assert!(cors.check_request(&method, &headers).is_ok())
            let res = cors.check_request(&method, &headers);
            if method == Method::OPTIONS {
                // Forbidden (403) - preflight request missing access-control-request-method header
                assert!(res.is_err())
            } else {
                assert!(res.is_ok())
            }
        }
    }

    #[tokio::test]
    async fn origin_not_allowed() {
        let cors = cors::new("https://localhost.rs".to_owned()).unwrap();
        let cors = cors::new("https://localhost.rs".to_owned(), "".to_owned()).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        let methods = [Method::GET, Method::HEAD];
        let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
        for method in methods {
            let res = cors.check_request(&method, &headers);
            assert!(res.is_err());
            assert!(matches!(res.unwrap_err(), cors::Forbidden::Origin))
        }
    }

    #[tokio::test]
    async fn method_allowed() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "GET".parse().unwrap());
        let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
        for method in methods {
            assert!(cors.check_request(&method, &headers).is_ok())
        }
    }

    #[tokio::test]
    async fn method_disallowed() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "POST".parse().unwrap());
        let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
        for method in methods {
            let res = cors.check_request(&method, &headers);
            if method == Method::OPTIONS {
                // Forbidden (403) - preflight request missing access-control-request-method header
                assert!(res.is_err())
            } else {
                assert!(res.is_ok())
            }
        }
    }

    #[tokio::test]
    async fn headers_allowed() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "GET".parse().unwrap());
        headers.insert(
            "access-control-request-headers",
            "origin,content-type".parse().unwrap(),
        );
        let methods = [Method::OPTIONS];
        for method in methods {
            let res = cors.check_request(&method, &headers);
            assert!(res.is_ok())
        }
    }

    #[tokio::test]
    async fn headers_invalid() {
        let cors = cors::new("*".to_owned(), "".to_owned()).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert(
            "access-control-request-method",
            "GET,HEAD,OPTIONS".parse().unwrap(),
        );
        headers.insert(
            "access-control-request-headers",
            "origin, content-type".parse().unwrap(),
        );
        let methods = [Method::GET, Method::HEAD, Method::OPTIONS];
        for method in &methods {
            let res = cors.check_request(method, &headers);
            if method == Method::OPTIONS {
                assert!(res.is_err())
            } else {
                assert!(res.is_ok())
            }
        }

        let mut headers = HeaderMap::new();
        headers.insert("origin", "https://localhost".parse().unwrap());
        headers.insert("access-control-request-method", "GET".parse().unwrap());
        headers.insert(
            "access-control-request-headers",
            "origin,authorization".parse().unwrap(),
        );
        for method in methods {
            let res = cors.check_request(&method, &headers);
            if method == Method::OPTIONS {
                assert!(res.is_err())
            } else {
                assert!(res.is_ok())
            }
        }
    }
}
diff --git a/tests/dir_listing.rs b/tests/dir_listing.rs
index 8000918..23cdd53 100644
--- a/tests/dir_listing.rs
+++ b/tests/dir_listing.rs
@@ -22,7 +22,6 @@ mod tests {
            Method::DELETE,
            Method::GET,
            Method::HEAD,
            Method::OPTIONS,
            Method::PATCH,
            Method::POST,
            Method::PUT,
diff --git a/tests/static_files.rs b/tests/static_files.rs
index 5b1485a..44e1ad9 100644
--- a/tests/static_files.rs
+++ b/tests/static_files.rs
@@ -388,7 +388,6 @@ mod tests {
            Method::DELETE,
            Method::GET,
            Method::HEAD,
            Method::OPTIONS,
            Method::PATCH,
            Method::POST,
            Method::PUT,