refactor: reject requests for non head or get methods
- it skips compression for HEAD requests only
- it skips error page content on HEAD requests only
Diff
src/compression.rs | 16 ++++++++++++----
src/error_page.rs | 2 +-
src/handler.rs | 8 +++++---
src/static_files.rs | 10 ++++++++--
4 files changed, 26 insertions(+), 10 deletions(-)
@@ -5,10 +5,9 @@ use async_compression::tokio::bufread::{BrotliEncoder, DeflateEncoder, GzipEncod
use bytes::Bytes;
use futures::Stream;
use headers::{AcceptEncoding, ContentCoding, ContentType, HeaderMap, HeaderMapExt};
use http::header::HeaderValue;
use hyper::{
header::{CONTENT_ENCODING, CONTENT_LENGTH},
Body, Response,
header::{HeaderValue, CONTENT_ENCODING, CONTENT_LENGTH},
Body, Method, Response,
};
use pin_project::pin_project;
use std::convert::TryFrom;
@@ -42,7 +41,16 @@ pub const TEXT_MIME_TYPES: [&str; 16] = [
pub fn auto(headers: &HeaderMap<HeaderValue>, resp: Response<Body>) -> Result<Response<Body>> {
pub fn auto(
method: &Method,
headers: &HeaderMap<HeaderValue>,
resp: Response<Body>,
) -> Result<Response<Body>> {
if method == Method::HEAD {
return Ok(resp);
}
if let Some(content_type) = resp.headers().typed_get::<ContentType>() {
let content_type = content_type.to_string();
@@ -74,7 +74,7 @@ pub fn get_error_response(method: &Method, status_code: &StatusCode) -> Result<R
let mut body = Body::empty();
let len = error_page_content.len() as u64;
if method == Method::GET {
if method != Method::HEAD {
body = Body::from(error_page_content)
}
@@ -7,8 +7,10 @@ use crate::{error::Result, error_page};
pub async fn handle_request(base: &Path, req: Request<Body>) -> Result<Response<Body>> {
let headers = req.headers();
match static_files::handle_request(base, headers, req.uri().path()).await {
Ok(resp) => compression::auto(headers, resp),
Err(status) => error_page::get_error_response(req.method(), &status),
let method = req.method();
match static_files::handle_request(method, headers, base, req.uri().path()).await {
Ok(resp) => compression::auto(method, headers, resp),
Err(status) => error_page::get_error_response(method, &status),
}
}
@@ -8,7 +8,7 @@ use headers::{
AcceptRanges, ContentLength, ContentRange, ContentType, HeaderMap, HeaderMapExt, HeaderValue,
IfModifiedSince, IfRange, IfUnmodifiedSince, LastModified, Range,
};
use hyper::{Body, Response, StatusCode};
use hyper::{Body, Method, Response, StatusCode};
use percent_encoding::percent_decode_str;
use std::fs::Metadata;
use std::future::Future;
@@ -36,10 +36,16 @@ impl AsRef<Path> for ArcPath {
pub async fn handle_request(
base: &Path,
method: &Method,
headers: &HeaderMap<HeaderValue>,
base: &Path,
uri_path: &str,
) -> Result<Response<Body>, StatusCode> {
if !(method == Method::HEAD || method == Method::GET) {
return Err(StatusCode::METHOD_NOT_ALLOWED);
}
let base = Arc::new(base.into());
let res = path_from_tail(base, uri_path).await?;
file_reply(headers, res).await