aboutsummaryrefslogtreecommitdiffstats
path: root/atuin-server/src/lib.rs
diff options
context:
space:
mode:
authorEric Hodel <drbrain@segment7.net>2023-12-27 06:15:48 -0800
committerGitHub <noreply@github.com>2023-12-27 14:15:48 +0000
commitd52e57612942cbe0c6a0dd774fcc2caac8f439d5 (patch)
tree6abc226ffa71156b0ac747529e7effaa21c75c15 /atuin-server/src/lib.rs
parentfeat: add semver checking to client requests (#1456) (diff)
downloadatuin-d52e57612942cbe0c6a0dd774fcc2caac8f439d5.zip
feat: Add TLS to atuin-server (#1457)
* Add TLS to atuin-server atuin as a project already includes most of the dependencies necessary for server-side TLS. This allows `atuin server start` to use a TLS certificate when self-hosting in order to avoid the complication of wrapping it in a TLS-aware proxy server. Configuration is handled similar to the metrics server with its own struct and currently accepts only the private key and certificate file paths. Starting a TLS server and a TCP server are divergent because the tests need to bind to an arbitrary port to avoid collisions across tests. The API to accomplish this for a TLS server is much more verbose. * Fix clippy, fmt * Add TLS section to self-hosting
Diffstat (limited to 'atuin-server/src/lib.rs')
-rw-r--r--atuin-server/src/lib.rs75
1 files changed, 61 insertions, 14 deletions
diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs
index 2d2a9c78..b505a8ec 100644
--- a/atuin-server/src/lib.rs
+++ b/atuin-server/src/lib.rs
@@ -1,10 +1,13 @@
#![forbid(unsafe_code)]
+use std::net::SocketAddr;
+use std::sync::Arc;
use std::{future::Future, net::TcpListener};
use atuin_server_database::Database;
use axum::Router;
use axum::Server;
+use axum_server::Handle;
use eyre::{Context, Result};
mod handlers;
@@ -12,6 +15,7 @@ mod metrics;
mod router;
mod utils;
+use rustls::ServerConfig;
pub use settings::example_config;
pub use settings::Settings;
@@ -44,27 +48,26 @@ async fn shutdown_signal() {
pub async fn launch<Db: Database>(
settings: Settings<Db::Settings>,
- host: &str,
- port: u16,
+ addr: SocketAddr,
) -> Result<()> {
- launch_with_listener::<Db>(
- settings,
- TcpListener::bind((host, port)).context("could not connect to socket")?,
- shutdown_signal(),
- )
- .await
+ if settings.tls.enable {
+ launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
+ } else {
+ launch_with_tcp_listener::<Db>(
+ settings,
+ TcpListener::bind(addr).context("could not connect to socket")?,
+ shutdown_signal(),
+ )
+ .await
+ }
}
-pub async fn launch_with_listener<Db: Database>(
+pub async fn launch_with_tcp_listener<Db: Database>(
settings: Settings<Db::Settings>,
listener: TcpListener,
shutdown: impl Future<Output = ()>,
) -> Result<()> {
- let db = Db::new(&settings.db_settings)
- .await
- .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
-
- let r = router::router(db, settings);
+ let r = make_router::<Db>(settings).await?;
Server::from_tcp(listener)
.context("could not launch server")?
@@ -75,6 +78,40 @@ pub async fn launch_with_listener<Db: Database>(
Ok(())
}
+async fn launch_with_tls<Db: Database>(
+ settings: Settings<Db::Settings>,
+ addr: SocketAddr,
+ shutdown: impl Future<Output = ()>,
+) -> Result<()> {
+ let certificates = settings.tls.certificates()?;
+ let pkey = settings.tls.private_key()?;
+
+ let server_config = ServerConfig::builder()
+ .with_safe_defaults()
+ .with_no_client_auth()
+ .with_single_cert(certificates, pkey)?;
+
+ let server_config = Arc::new(server_config);
+ let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config);
+
+ let r = make_router::<Db>(settings).await?;
+
+ let handle = Handle::new();
+
+ let server = axum_server::bind_rustls(addr, rustls_config)
+ .handle(handle.clone())
+ .serve(r.into_make_service());
+
+ tokio::select! {
+ _ = server => {}
+ _ = shutdown => {
+ handle.graceful_shutdown(None);
+ }
+ }
+
+ Ok(())
+}
+
// The separate listener means it's much easier to ensure metrics are not accidentally exposed to
// the public.
pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
@@ -95,3 +132,13 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
Ok(())
}
+
+async fn make_router<Db: Database>(
+ settings: Settings<<Db as Database>::Settings>,
+) -> Result<Router, eyre::Error> {
+ let db = Db::new(&settings.db_settings)
+ .await
+ .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
+ let r = router::router(db, settings);
+ Ok(r)
+}