diff options
| author | Eric Hodel <drbrain@segment7.net> | 2023-12-27 06:15:48 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-12-27 14:15:48 +0000 |
| commit | d52e57612942cbe0c6a0dd774fcc2caac8f439d5 (patch) | |
| tree | 6abc226ffa71156b0ac747529e7effaa21c75c15 /atuin-server/src | |
| parent | feat: add semver checking to client requests (#1456) (diff) | |
| download | atuin-d52e57612942cbe0c6a0dd774fcc2caac8f439d5.zip | |
feat: Add TLS to atuin-server (#1457)
* Add TLS to atuin-server
atuin as a project already includes most of the dependencies necessary
for server-side TLS. This allows `atuin server start` to use a TLS
certificate when self-hosting in order to avoid the complication of
wrapping it in a TLS-aware proxy server.
Configuration is handled similar to the metrics server with its own
struct and currently accepts only the private key and certificate file
paths.
Starting a TLS server and a TCP server are divergent because the tests
need to bind to an arbitrary port to avoid collisions across tests. The
API to accomplish this for a TLS server is much more verbose.
* Fix clippy, fmt
* Add TLS section to self-hosting
Diffstat (limited to '')
| -rw-r--r-- | atuin-server/src/lib.rs | 75 | ||||
| -rw-r--r-- | atuin-server/src/settings.rs | 54 |
2 files changed, 114 insertions, 15 deletions
diff --git a/atuin-server/src/lib.rs b/atuin-server/src/lib.rs index 2d2a9c78..b505a8ec 100644 --- a/atuin-server/src/lib.rs +++ b/atuin-server/src/lib.rs @@ -1,10 +1,13 @@ #![forbid(unsafe_code)] +use std::net::SocketAddr; +use std::sync::Arc; use std::{future::Future, net::TcpListener}; use atuin_server_database::Database; use axum::Router; use axum::Server; +use axum_server::Handle; use eyre::{Context, Result}; mod handlers; @@ -12,6 +15,7 @@ mod metrics; mod router; mod utils; +use rustls::ServerConfig; pub use settings::example_config; pub use settings::Settings; @@ -44,27 +48,26 @@ async fn shutdown_signal() { pub async fn launch<Db: Database>( settings: Settings<Db::Settings>, - host: &str, - port: u16, + addr: SocketAddr, ) -> Result<()> { - launch_with_listener::<Db>( - settings, - TcpListener::bind((host, port)).context("could not connect to socket")?, - shutdown_signal(), - ) - .await + if settings.tls.enable { + launch_with_tls::<Db>(settings, addr, shutdown_signal()).await + } else { + launch_with_tcp_listener::<Db>( + settings, + TcpListener::bind(addr).context("could not connect to socket")?, + shutdown_signal(), + ) + .await + } } -pub async fn launch_with_listener<Db: Database>( +pub async fn launch_with_tcp_listener<Db: Database>( settings: Settings<Db::Settings>, listener: TcpListener, shutdown: impl Future<Output = ()>, ) -> Result<()> { - let db = Db::new(&settings.db_settings) - .await - .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; - - let r = router::router(db, settings); + let r = make_router::<Db>(settings).await?; Server::from_tcp(listener) .context("could not launch server")? @@ -75,6 +78,40 @@ pub async fn launch_with_listener<Db: Database>( Ok(()) } +async fn launch_with_tls<Db: Database>( + settings: Settings<Db::Settings>, + addr: SocketAddr, + shutdown: impl Future<Output = ()>, +) -> Result<()> { + let certificates = settings.tls.certificates()?; + let pkey = settings.tls.private_key()?; + + let server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certificates, pkey)?; + + let server_config = Arc::new(server_config); + let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config); + + let r = make_router::<Db>(settings).await?; + + let handle = Handle::new(); + + let server = axum_server::bind_rustls(addr, rustls_config) + .handle(handle.clone()) + .serve(r.into_make_service()); + + tokio::select! { + _ = server => {} + _ = shutdown => { + handle.graceful_shutdown(None); + } + } + + Ok(()) +} + // The separate listener means it's much easier to ensure metrics are not accidentally exposed to // the public. pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { @@ -95,3 +132,13 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { Ok(()) } + +async fn make_router<Db: Database>( + settings: Settings<<Db as Database>::Settings>, +) -> Result<Router, eyre::Error> { + let db = Db::new(&settings.db_settings) + .await + .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; + let r = router::router(db, settings); + Ok(r) +} diff --git a/atuin-server/src/settings.rs b/atuin-server/src/settings.rs index d6f1867c..70008fbc 100644 --- a/atuin-server/src/settings.rs +++ b/atuin-server/src/settings.rs @@ -1,7 +1,7 @@ use std::{io::prelude::*, path::PathBuf}; use config::{Config, Environment, File as ConfigFile, FileFormat}; -use eyre::{eyre, Result}; +use eyre::{bail, eyre, Context, Result}; use fs_err::{create_dir_all, File}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -36,6 +36,7 @@ pub struct Settings<DbSettings> { pub register_webhook_url: Option<String>, pub register_webhook_username: String, pub metrics: Metrics, + pub tls: Tls, #[serde(flatten)] pub db_settings: DbSettings, @@ -67,6 +68,9 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> { .set_default("metrics.enable", false)? .set_default("metrics.host", "127.0.0.1")? .set_default("metrics.port", 9001)? + .set_default("tls.enable", false)? + .set_default("tls.cert_path", "")? + .set_default("tls.key_path", "")? .add_source( Environment::with_prefix("atuin") .prefix_separator("_") @@ -97,3 +101,51 @@ impl<DbSettings: DeserializeOwned> Settings<DbSettings> { pub fn example_config() -> &'static str { EXAMPLE_CONFIG } + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct Tls { + pub enable: bool, + pub cert_path: PathBuf, + pub pkey_path: PathBuf, +} + +impl Tls { + pub fn certificates(&self) -> Result<Vec<rustls::Certificate>> { + let cert_file = std::fs::File::open(&self.cert_path) + .with_context(|| format!("tls.cert_path {:?} is missing", self.cert_path))?; + let mut reader = std::io::BufReader::new(cert_file); + let certs: Vec<_> = rustls_pemfile::certs(&mut reader) + .with_context(|| format!("tls.cert_path {:?} is invalid", self.cert_path))? + .into_iter() + .map(rustls::Certificate) + .collect(); + + if certs.is_empty() { + bail!( + "tls.cert_path {:?} must have at least one certificate", + self.cert_path + ); + } + + Ok(certs) + } + + pub fn private_key(&self) -> Result<rustls::PrivateKey> { + let pkey_file = std::fs::File::open(&self.pkey_path) + .with_context(|| format!("tls.pkey_path {:?} is missing", self.pkey_path))?; + let mut reader = std::io::BufReader::new(pkey_file); + let keys = rustls_pemfile::pkcs8_private_keys(&mut reader) + .with_context(|| format!("tls.pkey_path {:?} is not PKCS8-encoded", self.pkey_path))?; + + if keys.is_empty() { + bail!( + "tls.pkey_path {:?} must have at least one private key", + self.pkey_path + ); + } + + let key = rustls::PrivateKey(keys[0].clone()); + + Ok(key) + } +} |
