From af9a32951c866e180dee6c8c359be13cc1f79d29 Mon Sep 17 00:00:00 2001 From: Jose Quintana Date: Tue, 1 Jun 2021 14:17:11 +0200 Subject: [PATCH] feat: cors support --- src/cors.rs | 267 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/handler.rs | 22 ++++++++++++++++++++-- src/lib.rs | 1 + src/server.rs | 8 +++++--- 4 files changed, 293 insertions(+), 5 deletions(-) create mode 100644 src/cors.rs diff --git a/src/cors.rs b/src/cors.rs new file mode 100644 index 0000000..c110a30 --- /dev/null +++ b/src/cors.rs @@ -0,0 +1,267 @@ +// CORS handler for incoming requests. +// -> Part of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/filters/cors.rs + +use headers::{HeaderName, HeaderValue, Origin}; +use http::header; +use std::{collections::HashSet, convert::TryFrom, sync::Arc}; + +/// It defines CORS instance. +#[derive(Clone, Debug)] +pub struct Cors { + allowed_headers: HashSet, + max_age: Option, + allowed_methods: HashSet, + origins_str: String, + origins: Option>, +} + +/// It builds a new CORS instance. +pub fn new(origins_str: String) -> Option> { + let cors = Cors::new(origins_str.clone()); + let cors = if origins_str.is_empty() { + None + } else if origins_str == "*" { + Some(cors.allow_any_origin().allow_methods(vec!["GET", "HEAD"])) + } else { + let hosts = origins_str + .split(',') + .map(|s| s.trim().as_ref()) + .collect::>(); + + if hosts.is_empty() { + None + } else { + Some(cors.allow_origins(hosts).allow_methods(vec!["GET", "HEAD"])) + } + }; + if cors.is_some() { + tracing::info!( + "enabled=true, allow_methods=[GET, HEAD], allow_origins={}", + origins_str + ); + } + Cors::build(cors) +} + +impl Cors { + /// Creates a new Cors instance. + pub fn new(origins_str: String) -> Self { + Self { + origins: None, + allowed_headers: HashSet::new(), + allowed_methods: HashSet::new(), + max_age: None, + origins_str, + } + } + + /// Adds multiple methods to the existing list of allowed request methods. + /// + /// # Panics + /// + /// Panics if the provided argument is not a valid `http::Method`. + pub fn allow_methods(mut self, methods: I) -> Self + where + I: IntoIterator, + http::Method: TryFrom, + { + let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) { + Ok(m) => m, + Err(_) => panic!("cors: illegal method"), + }); + self.allowed_methods.extend(iter); + self + } + + /// Sets that *any* `Origin` header is allowed. + /// + /// # Warning + /// + /// This can allow websites you didn't intend to access this resource, + /// it is usually better to set an explicit list. + pub fn allow_any_origin(mut self) -> Self { + self.origins = None; + self + } + + /// Add multiple origins to the existing list of allowed `Origin`s. + /// + /// # Panics + /// + /// Panics if the provided argument is not a valid `Origin`. + pub fn allow_origins(mut self, origins: I) -> Self + where + I: IntoIterator, + I::Item: IntoOrigin, + { + let iter = origins + .into_iter() + .map(IntoOrigin::into_origin) + .map(|origin| { + origin + .to_string() + .parse() + .expect("cors: Origin is always a valid HeaderValue") + }); + + self.origins.get_or_insert_with(HashSet::new).extend(iter); + self + } + + /// Sets the `Access-Control-Max-Age` header. + /// TODO: we could enable this in the future. + /// + /// # Example + /// + /// ``` + /// let cors = cors::new("*") + /// .max_age(30) // 30u32 seconds + /// .max_age(Duration::from_secs(30)); // or a Duration + /// ``` + pub fn max_age(mut self, seconds: impl Seconds) -> Self { + self.max_age = Some(seconds.seconds()); + self + } + + /// Builds the `Cors` wrapper from the configured settings. + pub fn build(cors: Option) -> Option> { + cors.as_ref()?; + let cors = cors?; + Some(Arc::new(Configured { cors })) + } +} + +impl Default for Cors { + fn default() -> Self { + Self::new("*".to_string()) + } +} + +#[derive(Clone, Debug)] +pub struct Configured { + cors: Cors, +} + +#[derive(Debug)] +pub enum Validated { + Preflight(HeaderValue), + Simple(HeaderValue), + NotCors, +} + +#[derive(Debug)] +pub enum Forbidden { + Origin, + Method, + Header, +} + +impl Default for Forbidden { + fn default() -> Self { + Self::Origin + } +} + +impl Configured { + pub fn check_request( + &self, + method: &http::Method, + headers: &http::HeaderMap, + ) -> Result { + match (headers.get(header::ORIGIN), method) { + (Some(origin), &http::Method::OPTIONS) => { + // OPTIONS requests are preflight CORS requests... + + if !self.is_origin_allowed(origin) { + return Err(Forbidden::Origin); + } + + if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) { + if !self.is_method_allowed(req_method) { + return Err(Forbidden::Method); + } + } else { + tracing::trace!( + "cors: preflight request missing access-control-request-method header" + ); + return Err(Forbidden::Method); + } + + if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) { + let headers = req_headers.to_str().map_err(|_| Forbidden::Header)?; + for header in headers.split(',') { + if !self.is_header_allowed(header.trim()) { + return Err(Forbidden::Header); + } + } + } + + Ok(Validated::Preflight(origin.clone())) + } + (Some(origin), _) => { + // Any other method, simply check for a valid origin... + tracing::trace!("cors origin header: {:?}", origin); + + if self.is_origin_allowed(origin) { + Ok(Validated::Simple(origin.clone())) + } else { + Err(Forbidden::Origin) + } + } + (None, _) => { + // No `ORIGIN` header means this isn't CORS! + Ok(Validated::NotCors) + } + } + } + + pub fn is_method_allowed(&self, header: &HeaderValue) -> bool { + http::Method::from_bytes(header.as_bytes()) + .map(|method| self.cors.allowed_methods.contains(&method)) + .unwrap_or(false) + } + + pub fn is_header_allowed(&self, header: &str) -> bool { + HeaderName::from_bytes(header.as_bytes()) + .map(|header| self.cors.allowed_headers.contains(&header)) + .unwrap_or(false) + } + + pub fn is_origin_allowed(&self, origin: &HeaderValue) -> bool { + if let Some(ref allowed) = self.cors.origins { + allowed.contains(origin) + } else { + true + } + } +} + +pub trait Seconds { + fn seconds(self) -> u64; +} + +impl Seconds for u32 { + fn seconds(self) -> u64 { + self.into() + } +} + +impl Seconds for ::std::time::Duration { + fn seconds(self) -> u64 { + self.as_secs() + } +} + +pub trait IntoOrigin { + fn into_origin(self) -> Origin; +} + +impl<'a> IntoOrigin for &'a str { + fn into_origin(self) -> Origin { + let mut parts = self.splitn(2, "://"); + let scheme = parts.next().expect("cors::into_origin: missing url scheme"); + let rest = parts.next().expect("cors::into_origin: missing url scheme"); + + Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin") + } +} diff --git a/src/handler.rs b/src/handler.rs index 04bd213..e6f643a 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,7 +1,8 @@ +use http::StatusCode; use hyper::{Body, Request, Response}; -use std::{future::Future, path::PathBuf}; +use std::{future::Future, path::PathBuf, sync::Arc}; -use crate::{compression, control_headers, static_files}; +use crate::{compression, control_headers, cors, static_files}; use crate::{error_page, Error, Result}; // It defines options for a request handler. @@ -9,6 +10,7 @@ pub struct RequestHandlerOpts { pub root_dir: PathBuf, pub compression: bool, pub dir_listing: bool, + pub cors: Option>, } // It defines the main request handler for Hyper service request. @@ -23,11 +25,27 @@ impl RequestHandler { ) -> impl Future, Error>> + Send + 'a { let method = req.method(); let headers = req.headers(); + let root_dir = self.opts.root_dir.as_path(); let uri_path = req.uri().path(); let dir_listing = self.opts.dir_listing; async move { + // CORS + if self.opts.cors.is_some() { + let cors = self.opts.cors.as_ref().unwrap(); + match cors.check_request(method, headers) { + Ok(r) => { + tracing::debug!("cors ok: {:?}", r); + } + Err(e) => { + tracing::debug!("cors error kind: {:?}", e); + return error_page::get_error_response(method, &StatusCode::FORBIDDEN); + } + }; + } + + // Static files match static_files::handle_request(method, headers, root_dir, uri_path, dir_listing) .await { diff --git a/src/lib.rs b/src/lib.rs index a9b5e38..1b3d95c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ extern crate anyhow; pub mod compression; pub mod config; pub mod control_headers; +pub mod cors; pub mod error_page; pub mod handler; pub mod helpers; diff --git a/src/server.rs b/src/server.rs index ef53295..5d54942 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,7 +8,7 @@ use crate::handler::{RequestHandler, RequestHandlerOpts}; use crate::tls::{TlsAcceptor, TlsConfigBuilder}; use crate::Result; use crate::{config::Config, service::RouterService}; -use crate::{error_page, helpers, logger}; +use crate::{cors, error_page, helpers, logger}; /// Define a multi-thread HTTP or HTTP/2 web server. pub struct Server { @@ -89,8 +89,6 @@ impl Server { .set(helpers::read_file_content(opts.page50x.as_ref())) .expect("page 50x is not initialized"); - // TODO: CORS support - // Auto compression based on the `Accept-Encoding` header let compression = opts.compression; tracing::info!("auto compression compression: enabled={}", compression); @@ -102,12 +100,16 @@ impl Server { // Spawn a new Tokio asynchronous server task with its given options let threads = self.threads; + // CORS support + let cors = cors::new(opts.cors_allow_origins.trim().to_string()); + // Create a service router for Hyper let router_service = RouterService::new(RequestHandler { opts: RequestHandlerOpts { root_dir, compression, dir_listing, + cors, }, }); -- libgit2 1.7.2