use hyper::server::conn::AddrIncoming;
use hyper::server::Server as HyperServer;
use listenfd::ListenFd;
use std::net::{IpAddr, SocketAddr, TcpListener};
use std::sync::Arc;
use crate::handler::{RequestHandler, RequestHandlerOpts};
use crate::tls::{TlsAcceptor, TlsConfigBuilder};
use crate::{cors, helpers, logger, signals, Settings};
use crate::{service::RouterService, Context, Result};
/// The static web server: its loaded settings together with the number of
/// Tokio worker threads it should run with.
pub struct Server {
    /// Number of worker threads for the multi-threaded Tokio runtime.
    threads: usize,
    /// Settings loaded via `Settings::get()`.
    opts: Settings,
}
impl Server {
pub fn new() -> Result<Server> {
let opts = Settings::get()?;
let cpus = num_cpus::get();
let threads = match opts.general.threads_multiplier {
0 | 1 => cpus,
n => cpus * n,
};
Ok(Server { opts, threads })
}
pub fn run(self) -> Result {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(self.threads)
.thread_name("static-web-server")
.enable_all()
.build()?
.block_on(async {
let r = self.start_server().await;
if r.is_err() {
println!("server failed to start up: {:?}", r.unwrap_err());
std::process::exit(1)
}
});
Ok(())
}
async fn start_server(self) -> Result {
let general = self.opts.general;
let advanced_opts = self.opts.advanced;
let log_level = &general.log_level.to_lowercase();
logger::init(log_level).with_context(|| "failed to initialize logging")?;
tracing::info!("logging level: {}", log_level.to_lowercase());
if general.config_file.is_some() && general.config_file.is_some() {
tracing::info!("config file: {}", general.config_file.unwrap().display());
}
let (tcp_listener, addr_str);
match general.fd {
Some(fd) => {
addr_str = format!("@FD({})", fd);
tcp_listener = ListenFd::from_env()
.take_tcp_listener(fd)?
.with_context(|| "failed to convert inherited FD into a TCP listener")?;
tracing::info!(
"converted inherited file descriptor {} to a TCP listener",
fd
);
}
None => {
let ip = general
.host
.parse::<IpAddr>()
.with_context(|| format!("failed to parse {} address", general.host))?;
let addr = SocketAddr::from((ip, general.port));
tcp_listener = TcpListener::bind(addr)
.with_context(|| format!("failed to bind to {} address", addr))?;
addr_str = addr.to_string();
tracing::info!("server bound to TCP socket {}", addr_str);
}
}
let root_dir = helpers::get_valid_dirpath(&general.root)
.with_context(|| "root directory was not found or inaccessible")?;
let page404 = helpers::read_file_content(&general.page404);
let page50x = helpers::read_file_content(&general.page50x);
let page_fallback = helpers::read_file_content(&general.page_fallback);
let threads = self.threads;
tracing::info!("runtime worker threads: {}", self.threads);
let security_headers = general.security_headers;
tracing::info!("security headers: enabled={}", security_headers);
let compression = general.compression;
tracing::info!("auto compression: enabled={}", compression);
let dir_listing = general.directory_listing;
tracing::info!("directory listing: enabled={}", dir_listing);
let dir_listing_order = general.directory_listing_order;
tracing::info!("directory listing order code: {}", dir_listing_order);
let cache_control_headers = general.cache_control_headers;
tracing::info!("cache control headers: enabled={}", cache_control_headers);
let cors = cors::new(
general.cors_allow_origins.trim(),
general.cors_allow_headers.trim(),
);
let basic_auth = general.basic_auth.trim().to_owned();
tracing::info!(
"basic authentication: enabled={}",
!general.basic_auth.is_empty()
);
let grace_period = general.grace_period;
tracing::info!("grace period before graceful shutdown: {}s", grace_period);
let router_service = RouterService::new(RequestHandler {
opts: Arc::from(RequestHandlerOpts {
root_dir,
compression,
dir_listing,
dir_listing_order,
cors,
security_headers,
cache_control_headers,
page404,
page50x,
page_fallback,
basic_auth,
advanced_opts,
}),
});
if general.http2 {
tcp_listener
.set_nonblocking(true)
.expect("cannot set non-blocking");
let listener = tokio::net::TcpListener::from_std(tcp_listener)
.with_context(|| "failed to create tokio::net::TcpListener")?;
let mut incoming = AddrIncoming::from_listener(listener).with_context(|| {
"failed to create an AddrIncoming from the current tokio::net::TcpListener"
})?;
incoming.set_nodelay(true);
let tls = TlsConfigBuilder::new()
.cert_path(&general.http2_tls_cert)
.key_path(&general.http2_tls_key)
.build()
.with_context(|| {
"failed to initialize TLS, probably wrong cert/key or file missing"
})?;
#[cfg(unix)]
let signals = signals::create_signals()
.with_context(|| "failed to register termination signals")?;
#[cfg(unix)]
let handle = signals.handle();
let server =
HyperServer::builder(TlsAcceptor::new(tls, incoming)).serve(router_service);
#[cfg(unix)]
let server =
server.with_graceful_shutdown(signals::wait_for_signals(signals, grace_period));
#[cfg(windows)]
let server = server.with_graceful_shutdown(signals::wait_for_ctrl_c(grace_period));
tracing::info!(
parent: tracing::info_span!("Server::start_server", ?addr_str, ?threads),
"listening on https://{}",
addr_str
);
tracing::info!("press ctrl+c to shut down the server");
server.await?;
#[cfg(unix)]
handle.close();
} else {
#[cfg(unix)]
let signals = signals::create_signals()
.with_context(|| "failed to register termination signals")?;
#[cfg(unix)]
let handle = signals.handle();
let server = HyperServer::from_tcp(tcp_listener)
.unwrap()
.tcp_nodelay(true)
.serve(router_service);
#[cfg(unix)]
let server =
server.with_graceful_shutdown(signals::wait_for_signals(signals, grace_period));
#[cfg(windows)]
let server = server.with_graceful_shutdown(signals::wait_for_ctrl_c(grace_period));
tracing::info!(
parent: tracing::info_span!("Server::start_server", ?addr_str, ?threads),
"listening on http://{}",
addr_str
);
tracing::info!("press ctrl+c to shut down the server");
server.await?;
#[cfg(unix)]
handle.close();
}
tracing::warn!("termination signal caught, shutting down the server execution");
Ok(())
}
}