use hyper::server::conn::AddrIncoming;
use hyper::server::Server as HyperServer;
use listenfd::ListenFd;
use std::net::{IpAddr, SocketAddr, TcpListener};
use structopt::StructOpt;
use crate::handler::{RequestHandler, RequestHandlerOpts};
use crate::tls::{TlsAcceptor, TlsConfigBuilder};
use crate::Result;
use crate::{config::Config, service::RouterService};
use crate::{cors, error_page, helpers, logger};
/// The static web server: parsed command-line configuration plus the number
/// of Tokio runtime worker threads it will run with.
pub struct Server {
    /// Worker thread count for the multi-threaded Tokio runtime.
    threads: usize,
    /// Options parsed from the command line.
    opts: Config,
}
impl Server {
    /// Creates a server by parsing command-line arguments into a [`Config`] and
    /// computing the runtime worker thread count.
    ///
    /// The thread count is `num_cpus::get()` multiplied by `threads_multiplier`;
    /// a multiplier of `0` or `1` means one worker per reported CPU.
    pub fn new() -> Server {
        let opts = Config::from_args();
        let cpus = num_cpus::get();
        let threads = match opts.threads_multiplier {
            0 | 1 => cpus,
            n => cpus * n,
        };
        Server { opts, threads }
    }

    /// Builds a multi-threaded Tokio runtime with `self.threads` workers and runs
    /// the server on it until the process exits or the server fails.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime cannot be built, or propagates any error
    /// produced while starting or running the server.
    pub fn run(self) -> Result {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(self.threads)
            .thread_name("static-web-server")
            .enable_all()
            .build()?
            .block_on(self.start_server())
    }

    /// Binds the listening socket, prepares shared state and serves requests.
    ///
    /// The listener is either inherited via `opts.fd` or bound to
    /// `opts.host:opts.port`. When `opts.http2` is set, connections are wrapped in
    /// TLS built from `opts.http2_tls_cert` / `opts.http2_tls_key`.
    ///
    /// # Errors
    ///
    /// Returns an error when the logger cannot be initialized, the host is not a
    /// valid IP address, the socket cannot be bound or converted, the root
    /// directory is invalid, or the spawned server task fails or panics.
    ///
    /// # Panics
    ///
    /// Panics if an error page was already set, or if the TLS configuration
    /// cannot be built.
    async fn start_server(self) -> Result {
        let opts = &self.opts;
        logger::init(&opts.log_level)?;
        tracing::info!("runtime worker threads: {}", self.threads);

        let (tcplistener, addr_string) = match opts.fd {
            Some(fd) => {
                // `take_tcp_listener` returns `Ok(None)` when no listener is available
                // for `fd`; that is a configuration problem, so report it as an error.
                let listener = ListenFd::from_env()
                    .take_tcp_listener(fd)?
                    .ok_or_else(|| {
                        std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Failed to convert inherited FD into a TCP listener",
                        )
                    })?;
                tracing::info!(
                    "Converted inherited file descriptor {} to a TCP listener",
                    fd
                );
                (listener, format!("@FD({})", fd))
            }
            None => {
                let ip = opts.host.parse::<IpAddr>()?;
                let addr = SocketAddr::from((ip, opts.port));
                let listener = TcpListener::bind(addr)?;
                let addr_string = format!("{:?}", addr);
                tracing::info!("Bound to TCP socket {}", addr_string);
                (listener, addr_string)
            }
        };

        let root_dir = helpers::get_valid_dirpath(&opts.root)?;

        // NOTE(review): these look like one-time cells (presumably `OnceCell`);
        // `set` fails only if a value is already present.
        error_page::PAGE_404
            .set(helpers::read_file_content(opts.page404.as_ref()))
            .expect("page 404 is already initialized");
        error_page::PAGE_50X
            .set(helpers::read_file_content(opts.page50x.as_ref()))
            .expect("page 50x is already initialized");

        let security_headers = opts.security_headers;
        tracing::info!("security headers: enabled={}", security_headers);
        let compression = opts.compression;
        tracing::info!("auto compression: enabled={}", compression);
        let dir_listing = opts.directory_listing;
        tracing::info!("directory listing: enabled={}", dir_listing);

        let threads = self.threads;
        let cors = cors::new(opts.cors_allow_origins.trim().to_string());
        let router_service = RouterService::new(RequestHandler {
            opts: RequestHandlerOpts {
                root_dir,
                compression,
                dir_listing,
                cors,
                security_headers,
            },
        });

        // All fallible setup happens before spawning, so failures reach the caller
        // instead of silently killing a detached task.
        let server_task = if opts.http2 {
            tcplistener.set_nonblocking(true)?;
            let listener = tokio::net::TcpListener::from_std(tcplistener)?;
            let mut incoming = AddrIncoming::from_listener(listener)?;
            incoming.set_nodelay(true);
            let tls = TlsConfigBuilder::new()
                .cert_path(opts.http2_tls_cert.clone())
                .key_path(opts.http2_tls_key.clone())
                .build()
                .expect(
                    "error during TLS server initialization, probably cert or key file missing",
                );
            let server =
                HyperServer::builder(TlsAcceptor::new(tls, incoming)).serve(router_service);
            tracing::info!(
                parent: tracing::info_span!("Server::start_server", ?addr_string, ?threads),
                "listening on https://{}",
                addr_string
            );
            tokio::task::spawn(server)
        } else {
            let server = HyperServer::from_tcp(tcplistener)?
                .tcp_nodelay(true)
                .serve(router_service);
            tracing::info!(
                parent: tracing::info_span!("Server::start_server", ?addr_string, ?threads),
                "listening on http://{}",
                addr_string
            );
            tokio::task::spawn(server)
        };

        // `handle_signals` may return (the Windows variant is a no-op). Returning
        // from here would end `block_on` in `run` and drop the runtime, stopping the
        // server, so keep waiting on the server task instead.
        handle_signals();
        server_task.await??;
        Ok(())
    }
}
impl Default for Server {
fn default() -> Self {
Self::new()
}
}
#[cfg(not(windows))]
fn handle_signals() {
use crate::signals;
signals::wait(|sig: signals::Signal| {
let code = signals::as_int(sig);
tracing::warn!("Signal {} caught. Server execution exited.", code);
std::process::exit(code)
});
}
/// Windows variant: no signal handling is installed, so this returns
/// immediately without blocking.
#[cfg(windows)]
fn handle_signals() {}