Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,795 changes: 1,748 additions & 47 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ homepage = "https://github.com/DefGuard/proxy"
repository = "https://github.com/DefGuard/proxy"

[dependencies]
defguard_certs = { git = "https://github.com/DefGuard/defguard.git", rev = "01957186101fc105803d56f1190efbdb5102df2f" }
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "01957186101fc105803d56f1190efbdb5102df2f" }
defguard_certs = { git = "https://github.com/DefGuard/defguard.git", rev = "564dc72c" }
defguard_grpc_tls = { git = "https://github.com/DefGuard/defguard.git", rev = "710b1bfd" }
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "7d28f46e" }
rustls-webpki = { version = "0.103", features = ["aws-lc-rs", "std"] }
rustls-pki-types = "1"
# base `axum` deps
axum = { version = "0.8", features = ["ws"] }
axum-client-ip = "0.7"
Expand Down Expand Up @@ -61,6 +64,9 @@ rustls = { version = "0.23", default-features = false, features = [
instant-acme = { version = "0.8", features = ["hyper-rustls", "aws-lc-rs"] }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "json"] }

[dev-dependencies]
tokio = { version = "1", features = ["sync", "time"] }

[build-dependencies]
tonic-prost-build = "0.14"
vergen-git2 = { version = "9.1", features = ["build"] }
Expand Down
2 changes: 1 addition & 1 deletion proto
97 changes: 56 additions & 41 deletions src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,26 @@ use std::{
};

use axum_extra::extract::cookie::Key;
use defguard_certs::{CertificateError, CertificateInfo};
use defguard_grpc_tls::{certs::server_tls_config, server::certificate_serial_interceptor};
use defguard_version::{
ComponentInfo, DefguardComponent, Version, get_tracing_variables,
server::{DefguardVersionLayer, grpc::DefguardVersionInterceptor},
};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{
Request, Response, Status, Streaming,
transport::{Identity, Server, ServerTlsConfig},
use tokio::{
fs::remove_file,
sync::{broadcast, mpsc, oneshot},
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{Request, Response, Status, Streaming, service::InterceptorLayer, transport::Server};
use tower::ServiceBuilder;
use tracing::Instrument;

use crate::{
LogsReceiver, MIN_CORE_VERSION, VERSION, acme,
acme::Port80Permit,
error::ApiError,
http::{GRPC_CERT_NAME, GRPC_KEY_NAME},
http::{CORE_CLIENT_CERT_NAME, GRPC_CA_CERT_NAME, GRPC_CERT_NAME, GRPC_KEY_NAME},
proto::{
AcmeCertificate, AcmeChallenge, AcmeIssueEvent, AcmeLogs, AcmeProgress, AcmeStep,
CoreRequest, CoreResponse, DeviceInfo, acme_issue_event, core_request, core_response,
Expand All @@ -40,9 +42,13 @@ use crate::{
type ClientMap = HashMap<SocketAddr, mpsc::UnboundedSender<Result<CoreRequest, Status>>>;

#[derive(Debug, Clone, Default)]
pub struct Configuration {
pub struct TlsConfig {
pub grpc_key_pem: String,
pub grpc_cert_pem: String,
/// PEM-encoded CA certificate used to verify Core's mTLS client certificate chain.
pub grpc_ca_cert_pem: String,
/// DER-encoded Core client certificate; used to extract and pin the expected serial.
pub core_client_cert_der: Vec<u8>,
}

pub(crate) struct ProxyServer {
Expand All @@ -51,7 +57,7 @@ pub(crate) struct ProxyServer {
results: Arc<RwLock<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
pub(crate) connected: Arc<AtomicBool>,
pub(crate) core_version: Arc<Mutex<Option<Version>>>,
config: Arc<Mutex<Option<Configuration>>>,
tls_config: Arc<Mutex<Option<TlsConfig>>>,
cookie_key: Arc<RwLock<Option<Key>>>,
cert_dir: PathBuf,
reset_tx: broadcast::Sender<()>,
Expand Down Expand Up @@ -87,7 +93,7 @@ impl ProxyServer {
results: Arc::new(RwLock::new(HashMap::new())),
connected: Arc::new(AtomicBool::new(false)),
core_version: Arc::new(Mutex::new(None)),
config: Arc::new(Mutex::new(None)),
tls_config: Arc::new(Mutex::new(None)),
cert_dir,
reset_tx,
https_cert_tx,
Expand All @@ -98,17 +104,17 @@ impl ProxyServer {
}
}

pub(crate) fn configure(&self, config: Configuration) {
pub(crate) fn configure(&self, config: TlsConfig) {
let mut lock = self
.config
.tls_config
.lock()
.expect("Failed to acquire lock on config mutex when applying proxy configuration");
*lock = Some(config);
}

pub(crate) fn get_configuration(&self) -> Option<Configuration> {
pub(crate) fn get_tls_config(&self) -> Option<TlsConfig> {
let lock = self
.config
.tls_config
.lock()
.expect("Failed to acquire lock on config mutex when retrieving proxy configuration");
lock.clone()
Expand All @@ -119,19 +125,27 @@ impl ProxyServer {
F: Future<Output = ()> + Send + 'static,
{
info!("Starting gRPC server on {addr}");
let config = self.get_configuration();
let (grpc_cert, grpc_key) = if let Some(cfg) = config {
(cfg.grpc_cert_pem, cfg.grpc_key_pem)
} else {
return Err(anyhow::anyhow!("gRPC server configuration is missing"));
};

let identity = Identity::from_pem(grpc_cert, grpc_key);
let mut builder =
Server::builder().tls_config(ServerTlsConfig::new().identity(identity))?;
let tls_config = self
.get_tls_config()
.ok_or_else(|| anyhow::anyhow!("gRPC server TLS configuration is missing"))?;

// Extract Core client cert serial for pinning (None in no-TLS mode).
let expected_serial = CertificateInfo::from_der(&tls_config.core_client_cert_der)
.map_err(|e: CertificateError| anyhow::anyhow!("invalid core client cert DER: {e}"))?
.serial;

let tls_config = server_tls_config(
&tls_config.grpc_cert_pem,
&tls_config.grpc_key_pem,
&tls_config.grpc_ca_cert_pem,
)?;
let mut builder = Server::builder().tls_config(tls_config)?;

let own_version = Version::parse(VERSION)?;
let versioned_service = ServiceBuilder::new()
.layer(InterceptorLayer::new(certificate_serial_interceptor(
expected_serial,
)))
.layer(tonic::service::InterceptorLayer::new(
DefguardVersionInterceptor::new(
own_version.clone(),
Expand Down Expand Up @@ -197,7 +211,7 @@ impl ProxyServer {

pub(crate) fn setup_completed(&self) -> bool {
let lock = self
.config
.tls_config
.lock()
.expect("Failed to acquire lock on config mutex when checking setup status");
lock.is_some()
Expand All @@ -213,7 +227,7 @@ impl Clone for ProxyServer {
connected: Arc::clone(&self.connected),
core_version: Arc::clone(&self.core_version),
cookie_key: Arc::clone(&self.cookie_key),
config: Arc::clone(&self.config),
tls_config: Arc::clone(&self.tls_config),
cert_dir: self.cert_dir.clone(),
reset_tx: self.reset_tx.clone(),
https_cert_tx: self.https_cert_tx.clone(),
Expand Down Expand Up @@ -343,26 +357,27 @@ impl proxy_server::Proxy for ProxyServer {
debug!("Received purge request, removing gRPC certificate files");
let cert_path = self.cert_dir.join(GRPC_CERT_NAME);
let key_path = self.cert_dir.join(GRPC_KEY_NAME);
let ca_cert_path = self.cert_dir.join(GRPC_CA_CERT_NAME);
let core_client_cert_path = self.cert_dir.join(CORE_CLIENT_CERT_NAME);

if let Err(err) = tokio::fs::remove_file(&cert_path).await
&& err.kind() != std::io::ErrorKind::NotFound
{
error!(
"Failed to remove gRPC certificate at {:?}: {err}",
cert_path
);
return Err(Status::internal("Failed to remove gRPC certificate"));
}
let remove_cert_file = async |path: &std::path::Path, label: &str| -> Result<(), Status> {
if let Err(err) = remove_file(path).await
&& err.kind() != std::io::ErrorKind::NotFound
{
error!("Failed to remove {label} at {}: {err}", path.display());
return Err(Status::internal(format!("Failed to remove {label}")));
}
info!("Removed {label} at {}", path.display());
Ok(())
};

if let Err(err) = tokio::fs::remove_file(&key_path).await
&& err.kind() != std::io::ErrorKind::NotFound
{
error!("Failed to remove gRPC key at {:?}: {err}", key_path);
return Err(Status::internal("Failed to remove gRPC key"));
}
remove_cert_file(&cert_path, "gRPC certificate").await?;
remove_cert_file(&key_path, "gRPC key").await?;
remove_cert_file(&ca_cert_path, "CA certificate").await?;
remove_cert_file(&core_client_cert_path, "Core client certificate").await?;

*self
.config
.tls_config
.lock()
.expect("Failed to lock config mutex during purge") = None;
*self
Expand Down
53 changes: 36 additions & 17 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use crate::{
config::EnvConfig,
enterprise::handlers::openid_login,
error::ApiError,
grpc::{Configuration, ProxyServer},
grpc::{ProxyServer, TlsConfig},
handlers::{desktop_client_mfa, enrollment, password_reset, polling},
setup::ProxySetupServer,
};
Expand All @@ -55,6 +55,8 @@ const X_FORWARDED_FOR: &str = "x-forwarded-for";
const X_POWERED_BY: &str = "x-powered-by";
pub const GRPC_CERT_NAME: &str = "proxy_grpc_cert.pem";
pub const GRPC_KEY_NAME: &str = "proxy_grpc_key.pem";
pub const GRPC_CA_CERT_NAME: &str = "grpc_ca_cert.pem";
pub const CORE_CLIENT_CERT_NAME: &str = "core_client_cert.pem";

#[derive(Clone)]
pub(crate) struct AppState {
Expand Down Expand Up @@ -178,10 +180,7 @@ async fn powered_by_header<B>(mut response: Response<B>) -> Response<B> {
response
}

pub async fn run_setup(
env_config: &EnvConfig,
logs_rx: LogsReceiver,
) -> anyhow::Result<Configuration> {
pub async fn run_setup(env_config: &EnvConfig, logs_rx: LogsReceiver) -> anyhow::Result<TlsConfig> {
let setup_server = ProxySetupServer::new(logs_rx);
let cert_dir = Path::new(&env_config.cert_dir);
if !cert_dir.exists() {
Expand All @@ -204,7 +203,7 @@ pub async fn run_setup(
"No gRPC TLS certificates found at {}, new certificates will be obtained during setup",
cert_dir.display()
);
let configuration = setup_server
let tls_config = setup_server
.await_initial_setup(SocketAddr::new(
env_config
.grpc_bind_address
Expand All @@ -214,11 +213,12 @@ pub async fn run_setup(
.await?;
info!("Generated new gRPC TLS certificates and signed by Defguard Core");

let Configuration {
let TlsConfig {
grpc_cert_pem,
grpc_key_pem,
..
} = &configuration;
grpc_ca_cert_pem,
core_client_cert_der,
} = &tls_config;

let cert_path = cert_dir.join(GRPC_CERT_NAME);
let key_path = cert_dir.join(GRPC_KEY_NAME);
Expand Down Expand Up @@ -247,6 +247,7 @@ pub async fn run_setup(
})?;
// Write key to a file.
options
.clone()
.open(&key_path)
.await?
.write_all(grpc_key_pem.as_bytes())
Expand All @@ -262,8 +263,26 @@ pub async fn run_setup(
err.into()
}
})?;
// Write CA certificate to a file.
options
.clone()
.open(cert_dir.join(GRPC_CA_CERT_NAME))
.await?
.write_all(grpc_ca_cert_pem.as_bytes())
.await?;
// Write Core client certificate (PEM-encoded) to a file for serial pinning on restart.
let core_client_cert_pem =
defguard_certs::der_to_pem(core_client_cert_der, defguard_certs::PemLabel::Certificate)
.map_err(|err| {
anyhow::anyhow!("Failed to PEM-encode Core client certificate: {err}")
})?;
options
.open(cert_dir.join(CORE_CLIENT_CERT_NAME))
.await?
.write_all(core_client_cert_pem.as_bytes())
.await?;

Ok(configuration)
Ok(tls_config)
}

/// Middleware that gates all HTTP endpoints except health checks until the proxy
Expand Down Expand Up @@ -306,7 +325,7 @@ async fn build_tls_config(cert_pem: &str, key_pem: &str) -> anyhow::Result<Rustl

pub async fn run_server(
env_config: EnvConfig,
config: Option<Configuration>,
tls_config: Option<TlsConfig>,
logs_rx: Option<LogsReceiver>,
) -> anyhow::Result<()> {
info!("Starting Defguard Proxy server");
Expand Down Expand Up @@ -346,8 +365,8 @@ pub async fn run_server(

// Preload existing TLS configuration so /api/v1/info can report "disconnected"
// immediately on startup
if let Some(existing_configuration) = config.clone() {
grpc_server.configure(existing_configuration);
if let Some(existing_tls_config) = tls_config.clone() {
grpc_server.configure(existing_tls_config);
}

let server_clone = grpc_server.clone();
Expand All @@ -356,17 +375,17 @@ pub async fn run_server(
// Start gRPC server.
debug!("Spawning gRPC server task");
tasks.spawn(async move {
let mut proxy_configuration = config;
let mut proxy_tls_config = tls_config;

loop {
let configuration = if let Some(conf) = proxy_configuration.clone() {
let configuration = if let Some(conf) = proxy_tls_config.clone() {
debug!("Using existing gRPC certificates, skipping setup process");
conf
} else {
info!("gRPC certificates not found, running setup process");
let conf = run_setup(&env_config_clone, Arc::clone(&logs_rx)).await?;
info!("Setup process completed successfully");
proxy_configuration = Some(conf.clone());
proxy_tls_config = Some(conf.clone());
conf
};

Expand Down Expand Up @@ -399,7 +418,7 @@ pub async fn run_server(
result = reset_rx.recv() => {
if result.is_ok() {
info!("Reset requested, restarting setup process");
proxy_configuration = None;
proxy_tls_config = None;
} else {
error!("Reset channel closed; gRPC server will keep running");
}
Expand Down
8 changes: 3 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ pub mod http;
pub mod logging;
mod setup;

#[cfg(test)]
mod tests;

pub(crate) mod generated {
pub(crate) mod defguard {
pub(crate) mod proxy {
Expand Down Expand Up @@ -48,9 +51,4 @@ extern crate tracing;
pub static VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "+", env!("VERGEN_GIT_SHA"));
pub const MIN_CORE_VERSION: Version = Version::new(2, 0, 0);

type CommsChannel<T> = (
Arc<tokio::sync::Mutex<mpsc::Sender<T>>>,
Arc<tokio::sync::Mutex<mpsc::Receiver<T>>>,
);

type LogsReceiver = Arc<tokio::sync::Mutex<mpsc::Receiver<LogEntry>>>;
Loading