From 8c4ce94abb8d77b89a19575d260d4747a386c0e5 Mon Sep 17 00:00:00 2001 From: Jose Quintana Date: Fri, 22 Jan 2021 22:40:43 +0100 Subject: [PATCH] feat: cors support --- src/bin/server.rs | 243 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------------------------------------------------------------------------- src/core/cache.rs | 2 +- src/core/config.rs | 19 ++++++++++++++----- src/core/cors.rs | 36 ++++++++++++++++++++++++++++++++++++ src/core/mod.rs | 1 + 5 files changed, 201 insertions(+), 100 deletions(-) create mode 100644 src/core/cors.rs diff --git a/src/bin/server.rs b/src/bin/server.rs index 03a4ed4..7c3aa44 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -7,7 +7,6 @@ static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; extern crate static_web_server; use structopt::StructOpt; -use tracing::warn; use warp::Filter; use self::static_web_server::core::*; @@ -16,115 +15,171 @@ use self::static_web_server::core::*; async fn server(opts: config::Options) -> Result { logger::init(&opts.log_level)?; + let host = opts.host.parse::()?; + let port = opts.port; + // Check a valid root directory let root_dir = helpers::get_valid_dirpath(opts.root)?; - // Read custom error pages content + // Custom error pages content let page404 = helpers::read_file_content(opts.page404.as_ref()); let page50x = helpers::read_file_content(opts.page50x.as_ref()); - - // Public HEAD endpoint let page404_a = page404.clone(); let page50x_a = page50x.clone(); - let public_head = warp::head().and( - warp::fs::dir(root_dir.clone()) - .map(cache::control_headers) - .with(warp::trace::request()) - .recover(move |rej| { - let page404_a = page404_a.clone(); - let page50x_a = page50x_a.clone(); - async move { rejection::handle_rejection(page404_a, page50x_a, rej).await } - }), - ); - // Public GET endpoint (default) - let page404_b = page404.clone(); - let page50x_b = page50x.clone(); - let public_get_default = warp::get().and( - warp::fs::dir(root_dir.clone()) - .map(cache::control_headers) - .with(warp::trace::request()) - .recover(move |rej| { - let page404_b = page404_b.clone(); - let page50x_b = page50x_b.clone(); - async move { rejection::handle_rejection(page404_b, page50x_b, rej).await } - }), - ); + // CORS support + let (cors_filter, cors_allowed_origins) = + cors::get_opt_cors_filter(opts.cors_allow_origins.as_ref()); + + // Base fs directory filter + let base_dir_filter = warp::fs::dir(root_dir.clone()) + .map(cache::control_headers) + .with(warp::trace::request()) + .recover(move |rej| { + let page404_a = page404_a.clone(); + let page50x_a = page50x_a.clone(); + async move { rejection::handle_rejection(page404_a, page50x_a, rej).await } + }); - let host = opts.host.parse::()?; - let port = opts.port; + // Public HEAD endpoint + let public_head = warp::head().and(base_dir_filter.clone()); + + // Public GET endpoint (default) + let public_get_default = warp::get().and(base_dir_filter.clone()); // Public GET/HEAD endpoints with compression (deflate, gzip, brotli, none) match opts.compression.as_ref() { - "brotli" => tokio::task::spawn( - warp::serve( - public_head.or(warp::get() - .and(cache::accept_encoding("br")) - .and( - warp::fs::dir(root_dir) - .map(cache::control_headers) - .with(warp::trace::request()) - .with(warp::compression::brotli(true)) - .recover(move |rej| { - let page404_c = page404.clone(); - let page50x_c = page50x.clone(); - async move { - rejection::handle_rejection(page404_c, page50x_c, rej).await - } - }), - ) - .or(public_get_default)), - ) - .run((host, port)), - ), - "deflate" => tokio::task::spawn( - warp::serve( - public_head.or(warp::get() - .and(cache::accept_encoding("deflate")) - .and( - warp::fs::dir(root_dir) - .map(cache::control_headers) - .with(warp::trace::request()) - .with(warp::compression::deflate(true)) - .recover(move |rej| { - let page404_c = page404.clone(); - let page50x_c = page50x.clone(); - async move { - rejection::handle_rejection(page404_c, page50x_c, rej).await - } - }), - ) - .or(public_get_default)), - ) - .run((host, port)), - ), - "gzip" => tokio::task::spawn( - warp::serve( - public_head.or(warp::get() - .and(cache::accept_encoding("gzip")) - .and( - warp::fs::dir(root_dir) - .map(cache::control_headers) - .with(warp::trace::request()) - .with(warp::compression::gzip(true)) - .recover(move |rej| { - let page404_c = page404.clone(); - let page50x_c = page50x.clone(); - async move { - rejection::handle_rejection(page404_c, page50x_c, rej).await - } - }), - ) - .or(public_get_default)), - ) - .run((host, port)), - ), - _ => tokio::task::spawn(warp::serve(public_head.or(public_get_default)).run((host, port))), + "brotli" => tokio::task::spawn(async move { + let with_dir = warp::fs::dir(root_dir) + .map(cache::control_headers) + .with(warp::trace::request()) + .with(warp::compression::brotli(true)) + .recover(move |rej| { + let page404 = page404.clone(); + let page50x = page50x.clone(); + async move { rejection::handle_rejection(page404, page50x, rej).await } + }); + + if let Some(cors_filter) = cors_filter { + tracing::info!( + cors_enabled = ?true, + allowed_origins = ?cors_allowed_origins + ); + warp::serve( + public_head.with(cors_filter.clone()).or(warp::get() + .and(cache::has_accept_encoding("br")) + .and(with_dir) + .with(cors_filter.clone()) + .or(public_get_default.with(cors_filter))), + ) + .run((host, port)) + .await + } else { + warp::serve( + public_head.or(warp::get() + .and(cache::has_accept_encoding("br")) + .and(with_dir) + .or(public_get_default)), + ) + .run((host, port)) + .await + } + }), + "deflate" => tokio::task::spawn(async move { + let with_dir = warp::fs::dir(root_dir) + .map(cache::control_headers) + .with(warp::trace::request()) + .with(warp::compression::deflate(true)) + .recover(move |rej| { + let page404 = page404.clone(); + let page50x = page50x.clone(); + async move { rejection::handle_rejection(page404, page50x, rej).await } + }); + + if let Some(cors_filter) = cors_filter { + tracing::info!( + cors_enabled = ?true, + allowed_origins = ?cors_allowed_origins + ); + warp::serve( + public_head.with(cors_filter.clone()).or(warp::get() + .and(cache::has_accept_encoding("deflate")) + .and(with_dir) + .with(cors_filter.clone()) + .or(public_get_default.with(cors_filter))), + ) + .run((host, port)) + .await + } else { + warp::serve( + public_head.or(warp::get() + .and(cache::has_accept_encoding("deflate")) + .and(with_dir) + .or(public_get_default)), + ) + .run((host, port)) + .await + } + }), + "gzip" => tokio::task::spawn(async move { + let with_dir = warp::fs::dir(root_dir) + .map(cache::control_headers) + .with(warp::trace::request()) + .with(warp::compression::gzip(true)) + .recover(move |rej| { + let page404 = page404.clone(); + let page50x = page50x.clone(); + async move { rejection::handle_rejection(page404, page50x, rej).await } + }); + + if let Some(cors_filter) = cors_filter { + tracing::info!( + cors_enabled = ?true, + allowed_origins = ?cors_allowed_origins + ); + warp::serve( + public_head.with(cors_filter.clone()).or(warp::get() + .and(cache::has_accept_encoding("gzip")) + .and(with_dir) + .with(cors_filter.clone()) + .or(public_get_default.with(cors_filter))), + ) + .run((host, port)) + .await + } else { + warp::serve( + public_head.or(warp::get() + .and(cache::has_accept_encoding("gzip")) + .and(with_dir) + .or(public_get_default)), + ) + .run((host, port)) + .await + } + }), + _ => tokio::task::spawn(async move { + if let Some(cors_filter) = cors_filter { + tracing::info!( + cors_enabled = ?true, + allowed_origins = ?cors_allowed_origins + ); + let public_get_default = warp::get() + .and(base_dir_filter.clone()) + .with(cors_filter.clone()); + warp::serve(public_head.or(public_get_default.with(cors_filter))) + .run((host, port)) + .await + } else { + warp::serve(public_head.or(public_get_default)) + .run((host, port)) + .await + } + }), }; signals::wait(|sig: signals::Signal| { let code = signals::as_int(sig); - warn!("Signal {} caught. Server execution exited.", code); + tracing::warn!("Signal {} caught. Server execution exited.", code); std::process::exit(code) }); diff --git a/src/core/cache.rs b/src/core/cache.rs index 148e7d5..a9f93dd 100644 --- a/src/core/cache.rs +++ b/src/core/cache.rs @@ -38,7 +38,7 @@ fn duration(n: u64) -> u32 { } /// Warp filter in order to check for an `Accept-Encoding` header value. -pub fn accept_encoding( +pub fn has_accept_encoding( val: &'static str, ) -> impl warp::Filter + Copy { warp::header::contains("accept-encoding", val) diff --git a/src/core/config.rs b/src/core/config.rs index 406f43c..d8dfe96 100644 --- a/src/core/config.rs +++ b/src/core/config.rs @@ -3,7 +3,7 @@ use structopt::StructOpt; /// Static Web Server #[derive(Debug, StructOpt)] pub struct Options { - #[structopt(long, short = "s", default_value = "::", env = "SERVER_HOST")] + #[structopt(long, short = "a", default_value = "::", env = "SERVER_HOST")] /// Host address (E.g 127.0.0.1 or ::1) pub host: String, @@ -13,7 +13,7 @@ pub struct Options { #[structopt( long, - short = "t", + short = "n", default_value = "8", env = "SERVER_THREADS_MULTIPLIER" )] @@ -23,7 +23,7 @@ pub struct Options { /// Number of worker threads result should be a number between 1 and 32,768 though it is advised to keep this value on the smaller side. pub threads_multiplier: usize, - #[structopt(long, short = "r", default_value = "./public", env = "SERVER_ROOT")] + #[structopt(long, short = "d", default_value = "./public", env = "SERVER_ROOT")] /// Root directory path of static files pub root: String, @@ -43,12 +43,21 @@ pub struct Options { /// HTML file path for 404 errors. If path is not specified or simply don't exists then server will use a generic HTML error message. pub page404: String, - #[structopt(long, short = "c", default_value = "gzip", env = "SERVER_COMPRESSION")] + #[structopt(long, short = "x", default_value = "gzip", env = "SERVER_COMPRESSION")] /// Compression body support for web text-based file types. Values: "gzip", "deflate" or "brotli". /// Use an empty value to skip compression. pub compression: String, - #[structopt(long, short = "l", default_value = "error", env = "SERVER_LOG_LEVEL")] + #[structopt(long, short = "g", default_value = "error", env = "SERVER_LOG_LEVEL")] /// Specify a logging level in lower case. pub log_level: String, + + #[structopt( + long, + short = "c", + default_value = "", + env = "SERVER_CORS_ALLOW_ORIGINS" + )] + /// Specify a optional CORS list of allowed origin hosts separated by comas. Host ports or protocols aren't being checked. Use an asterisk (*) to allow any host. + pub cors_allow_origins: String, } diff --git a/src/core/cors.rs b/src/core/cors.rs new file mode 100644 index 0000000..9482ba9 --- /dev/null +++ b/src/core/cors.rs @@ -0,0 +1,36 @@ +use std::collections::HashSet; +use warp::filters::cors::Builder; + +/// Warp filter which provides an optional CORS if its supported. +pub fn get_opt_cors_filter(origins: &str) -> (Option, String) { + let mut cors_allowed_hosts = String::new(); + let cors_filter = if origins.is_empty() { + None + } else if origins == "*" { + cors_allowed_hosts = origins.into(); + Some( + warp::cors() + .allow_any_origin() + .allow_methods(vec!["GET", "HEAD", "OPTIONS"]), + ) + } else { + cors_allowed_hosts = origins.into(); + let hosts = cors_allowed_hosts + .split(',') + .map(|s| s.trim().as_ref()) + .collect::>(); + + if hosts.is_empty() { + cors_allowed_hosts = hosts.into_iter().collect::>().join(", "); + None + } else { + Some( + warp::cors() + .allow_origins(hosts) + .allow_methods(vec!["GET", "HEAD", "OPTIONS"]), + ) + } + }; + + (cors_filter, cors_allowed_hosts) +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 188dde3..567cfc7 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,5 +1,6 @@ pub mod cache; pub mod config; +pub mod cors; pub mod helpers; pub mod logger; pub mod rejection; -- libgit2 1.7.2