//! Tests for the CORS request checks built by `static_web_server::cors::new`.
#![forbid(unsafe_code)]
#![deny(warnings)]
#![deny(rust_2018_idioms)]
#![deny(dead_code)]

#[cfg(test)]
mod tests {
    use headers::HeaderMap;
    use http::Method;
    use static_web_server::cors;

    /// GET, HEAD and OPTIONS: the methods `allow_methods` expects to pass.
    fn safe_methods() -> [Method; 3] {
        [Method::GET, Method::HEAD, Method::OPTIONS]
    }

    /// Builds a request `HeaderMap` from `(name, value)` pairs.
    ///
    /// # Panics
    /// Panics if a value is not a valid header value (test fixtures only).
    fn request_headers(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, value.parse().unwrap());
        }
        headers
    }

    #[test]
    fn allow_methods() {
        let methods = safe_methods();

        // Wildcard origin; the request carries no `Origin` header.
        let cors = cors::new("*", "", "").unwrap();
        let headers = HeaderMap::new();
        for method in &methods {
            assert!(cors.check_request(method, &headers).is_ok());
        }

        // Explicit origin that matches the request's `Origin` header.
        let cors = cors::new("https://localhost", "", "").unwrap();
        let headers = request_headers(&[
            ("origin", "https://localhost"),
            ("access-control-request-method", "GET"),
        ]);
        for method in &methods {
            assert!(cors.check_request(method, &headers).is_ok());
        }
    }

    #[test]
    fn disallow_methods() {
        // Without an `Origin` header the request is not treated as CORS, so even
        // methods outside the allowed set are passed through: the check succeeds,
        // adds no headers and reports `Validated::NotCors`.
        let cors = cors::new("*", "", "").unwrap();
        let headers = HeaderMap::new();
        let methods = [
            Method::CONNECT,
            Method::DELETE,
            Method::PATCH,
            Method::POST,
            Method::PUT,
            Method::TRACE,
        ];
        for method in &methods {
            let res = cors
                .check_request(method, &headers)
                .expect("a request without an Origin header must not be rejected");
            assert!(res.0.is_empty());
            assert!(matches!(res.1, cors::Validated::NotCors));
        }
    }

    #[test]
    fn origin_allowed() {
        let cors = cors::new("*", "", "").unwrap();
        let headers = request_headers(&[("origin", "https://localhost")]);
        for method in safe_methods() {
            let res = cors.check_request(&method, &headers);
            if method == Method::OPTIONS {
                // Forbidden (403) - preflight request missing access-control-request-method header
                assert!(res.is_err());
            } else {
                assert!(res.is_ok());
            }
        }
    }

    #[test]
    fn origin_not_allowed() {
        let cors = cors::new("https://localhost.rs", "", "").unwrap();
        let headers = request_headers(&[("origin", "https://localhost")]);
        for method in safe_methods() {
            let res = cors.check_request(&method, &headers);
            assert!(matches!(res, Err(cors::Forbidden::Origin)));
        }
    }

    #[test]
    fn method_allowed() {
        let cors = cors::new("*", "", "").unwrap();
        let headers = request_headers(&[
            ("origin", "https://localhost"),
            ("access-control-request-method", "GET"),
        ]);
        for method in safe_methods() {
            assert!(cors.check_request(&method, &headers).is_ok());
        }
    }

    #[test]
    fn method_disallowed() {
        let cors = cors::new("*", "", "").unwrap();
        let headers = request_headers(&[
            ("origin", "https://localhost"),
            ("access-control-request-method", "POST"),
        ]);
        for method in safe_methods() {
            let res = cors.check_request(&method, &headers);
            if method == Method::OPTIONS {
                // Forbidden (403) - the preflight asks for POST, which is not an
                // allowed method here (the same request asking for GET passes in
                // `method_allowed`).
                assert!(res.is_err());
            } else {
                assert!(res.is_ok());
            }
        }
    }

    #[test]
    fn headers_allowed() {
        let cors = cors::new("*", "", "").unwrap();
        let headers = request_headers(&[
            ("origin", "https://localhost"),
            ("access-control-request-method", "GET"),
            ("access-control-request-headers", "origin,content-type"),
        ]);
        assert!(cors.check_request(&Method::OPTIONS, &headers).is_ok());
    }

    #[test]
    fn headers_invalid() {
        let cors = cors::new("*", "", "").unwrap();
        let methods = safe_methods();

        // Malformed preflight: the request-method value lists several methods and
        // the request-headers value has a space after the comma; either may be what
        // triggers the rejection. Only the OPTIONS request is expected to fail.
        let headers = request_headers(&[
            ("origin", "https://localhost"),
            ("access-control-request-method", "GET,HEAD,OPTIONS"),
            ("access-control-request-headers", "origin, content-type"),
        ]);
        for method in &methods {
            let res = cors.check_request(method, &headers);
            if method == Method::OPTIONS {
                assert!(res.is_err());
            } else {
                assert!(res.is_ok());
            }
        }

        // `authorization` is not an accepted request header for this configuration
        // (unlike `origin,content-type` in `headers_allowed`).
        let headers = request_headers(&[
            ("origin", "https://localhost"),
            ("access-control-request-method", "GET"),
            ("access-control-request-headers", "origin,authorization"),
        ]);
        for method in &methods {
            let res = cors.check_request(method, &headers);
            if method == Method::OPTIONS {
                assert!(res.is_err());
            } else {
                assert!(res.is_ok());
            }
        }
    }
}