diff --git a/Cargo.lock b/Cargo.lock index 9e9ea2c3..96ff221f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2201,10 +2201,12 @@ dependencies = [ "arc-swap", "arrayvec", "async-stream", + "async-trait", "atomic-time", "bounded-integer", "bytes", "compact_str", + "dashmap", "enum_dispatch", "exponential-backoff", "futures", diff --git a/docs/proposal/connection-draining-multiple-listener-versions.md b/docs/proposal/connection-draining-multiple-listener-versions.md new file mode 100644 index 00000000..ff71ca47 --- /dev/null +++ b/docs/proposal/connection-draining-multiple-listener-versions.md @@ -0,0 +1,511 @@ +--- +title: Connection Draining for Multiple Listener Versions +authors: +- "@Eeshu-Yadav" +reviewers: +- "@YaoZengzeng" +- "@dawid-nowak" +- "@hzxuzhonghu" +approvers: +- "@YaoZengzeng" +- "@dawid-nowak" +- "@hzxuzhonghu" +creation-date: 2025-10-09 +--- + +## Connection Draining for Multiple Listener Versions + +### Summary + +This proposal implements Envoy-compatible connection draining for multiple listener versions in Orion. When listener configurations are updated via LDS (Listener Discovery Service), existing connections continue on old listener versions while new connections seamlessly transition to updated listeners. This ensures zero-downtime updates and follows Envoy's graceful draining behavior with protocol-specific timeout handling. + +### Motivation + +Currently, when updating a listener configuration in Orion: + +1. **Connection Interruption**: Old connections get abruptly terminated +2. **No Graceful Period**: No graceful shutdown mechanism for existing connections +3. **Non-Envoy Compliant**: Doesn't follow Envoy's standard draining behavior +4. **Protocol Agnostic**: No protocol-specific handling (HTTP/1.1, HTTP/2, TCP) + +This causes service disruptions during configuration updates, making Orion unsuitable for production environments requiring high availability. + +#### Goals + +- Implement Envoy-compatible connection draining with protocol-specific behavior +- Support graceful listener updates via LDS without connection interruption +- Follow Envoy's timeout mechanisms and drain sequence +- Maintain full backward compatibility with existing configurations +- Enable production-ready zero-downtime deployments + +#### Non-Goals + +- Implementing custom draining protocols beyond Envoy compatibility +- Supporting non-standard timeout configurations outside Envoy's model +- Backward compatibility with non-Envoy proxy behaviors + +### Proposal + +The proposal implements a comprehensive connection draining system that manages multiple listener versions during LDS updates. When a listener configuration is updated, the system: + +1. **Starts New Version**: Creates the new listener version with updated configuration +2. **Begins Draining**: Marks old versions for graceful draining +3. **Protocol-Specific Handling**: Applies appropriate draining behavior per protocol +4. **Timeout Management**: Enforces Envoy-compatible timeout sequences +5. **Resource Cleanup**: Removes old versions after all connections drain + +The implementation follows Envoy's draining documentation and maintains compatibility with existing Envoy configurations. 
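+
+As a rough illustration of this sequence (a minimal sketch with hypothetical names such as `ListenerVersion`, `apply_lds_update`, and `sweep_drained`, not the actual Orion API), the version bookkeeping reduces to:
+
+```rust
+use std::time::{Duration, Instant};
+
+struct ListenerVersion {
+    version: u64,
+    drain_started: Option<Instant>, // None while the version is still active
+}
+
+/// Steps 1-2: start the new version, then mark all older versions as draining.
+fn apply_lds_update(versions: &mut Vec<ListenerVersion>, next: u64) {
+    versions.push(ListenerVersion { version: next, drain_started: None });
+    for v in versions.iter_mut().filter(|v| v.version < next) {
+        v.drain_started.get_or_insert_with(Instant::now);
+    }
+}
+
+/// Step 5: drop an old version once its connections are gone, or force it out
+/// after the global drain timeout elapses.
+fn sweep_drained(
+    versions: &mut Vec<ListenerVersion>,
+    drain_time: Duration,
+    active_conns: impl Fn(u64) -> usize,
+) {
+    versions.retain(|v| match v.drain_started {
+        None => true,
+        Some(started) => active_conns(v.version) > 0 && started.elapsed() < drain_time,
+    });
+}
+```
+
+Steps 3 and 4 (protocol-specific signaling and timeout enforcement) act on the versions marked as draining and are detailed in the design below.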
+
+### Design Details
+
+#### Architecture Overview
+
+```
+┌─────────────────────────────────────────────────────────────────┐
+│                     Orion Listener Manager                      │
+│                                                                 │
+│  ┌──────────────┐    LDS Update    ┌──────────────┐             │
+│  │  Listener    │◄─────────────────┤     New      │             │
+│  │  Version 1   │                  │   Listener   │             │
+│  │  (Active)    │                  │  Version 2   │             │
+│  └──────┬───────┘                  └──────┬───────┘             │
+│         │                                 │                     │
+│         │ Start Draining                  │ Accept New          │
+│         ▼                                 ▼ Connections         │
+│  ┌─────────────────────────────────────────────────────┐        │
+│  │              Drain Signaling Manager                │        │
+│  │   ┌─────────────────────────────────────────────┐   │        │
+│  │   │     Protocol-Specific Drain Handlers        │   │        │
+│  │   │  ┌─────────────┬─────────────┬───────────┐  │   │        │
+│  │   │  │  HTTP/1.1   │   HTTP/2    │    TCP    │  │   │        │
+│  │   │  │             │             │           │  │   │        │
+│  │   │  │ Connection: │   GOAWAY    │ SO_LINGER │  │   │        │
+│  │   │  │    close    │   Frame     │  Timeout  │  │   │        │
+│  │   │  └─────────────┴─────────────┴───────────┘  │   │        │
+│  │   └─────────────────────────────────────────────┘   │        │
+│  └─────────────────────────────────────────────────────┘        │
+│         │                                                       │
+│         │ Timeout Management                                    │
+│         ▼                                                       │
+│  ┌─────────────────────────────────────────────────────┐        │
+│  │               Drain Timeout Handler                 │        │
+│  │  • HTTP: drain_timeout from HCM config              │        │
+│  │  • TCP: Global server drain timeout                 │        │
+│  │  • Absolute: 600s maximum (--drain-time-s)          │        │
+│  └─────────────────────────────────────────────────────┘        │
+└─────────────────────────────────────────────────────────────────┘
+```
+
+#### Component Details
+
+##### 1. Drain Signaling Manager
+
+**Location**: `orion-lib/src/listeners/drain_signaling.rs`
+
+```rust
+#[derive(Debug)]
+pub struct DrainSignalingManager {
+    drain_contexts: Arc<RwLock<HashMap<String, Arc<ListenerDrainContext>>>>,
+    global_drain_timeout: Duration,
+    default_http_drain_timeout: Duration,
+    listener_drain_state: Arc<RwLock<Option<ListenerDrainState>>>,
+}
+
+#[derive(Debug)]
+pub struct ListenerDrainContext {
+    pub listener_id: String,
+    pub strategy: DrainStrategy,
+    pub drain_start: Instant,
+    pub initial_connections: usize,
+    pub active_connections: Arc<RwLock<usize>>,
+    pub completed: Arc<RwLock<bool>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct ListenerDrainState {
+    pub started_at: Instant,
+    pub strategy: super::listeners_manager::DrainStrategy,
+    pub protocol_behavior: super::listeners_manager::ProtocolDrainBehavior,
+    pub drain_scenario: DrainScenario,
+    pub drain_type: ConfigDrainType,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum DrainScenario {
+    HealthCheckFail,
+    ListenerUpdate,
+    HotRestart,
+}
+
+#[derive(Debug, Clone)]
+pub enum DrainStrategy {
+    Tcp { global_timeout: Duration },
+    Http { global_timeout: Duration, drain_timeout: Duration },
+    Mixed { global_timeout: Duration, http_drain_timeout: Duration, tcp_connections: bool, http_connections: bool },
+    Immediate,
+}
+```
+
+**Key Operations**:
+- `start_listener_draining`: Initialize listener-wide draining
+- `stop_listener_draining`: Terminate draining process
+- `is_listener_draining`: Check draining status
+- `create_drain_context`: Create per-listener drain tracking
+
+##### 2. 
Connection Manager with Drain Support + +**Location**: `orion-lib/src/listeners/listeners_manager.rs` + +```rust +pub trait ConnectionManager: Send + Sync { + fn on_connection_established(&self, listener_name: &str, conn_info: ConnectionInfo); + fn on_connection_closed(&self, listener_name: &str, connection_id: &str); + fn start_connection_draining( + &self, + listener_name: &str, + connection_id: &str, + protocol_behavior: &ProtocolDrainBehavior, + ); + fn get_active_connections(&self, listener_name: &str) -> Vec; + fn force_close_connection(&self, listener_name: &str, connection_id: &str); +} + +#[derive(Debug, Default)] +pub struct DefaultConnectionManager { + connections: Arc>, + listener_connection_counts: Arc>, + http_managers: Arc>>, +} + +#[derive(Debug, Clone)] +pub struct ConnectionInfo { + pub id: String, + pub protocol: ConnectionProtocol, + pub established_at: Instant, + pub last_activity: Instant, + pub state: ConnectionState, +} + +#[derive(Debug, Clone)] +pub enum ConnectionProtocol { + Http1, + Http2, + Tcp, + Unknown, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ConnectionState { + Active, + Draining, + Closing, + Closed, +} + +#[derive(Debug, Clone)] +pub enum ProtocolDrainBehavior { + Http1 { connection_close: bool }, + Http2 { send_goaway: bool }, + Tcp { force_close_after: Duration }, + Auto, +} +``` + +**Key Operations**: +- `on_connection_established`: Track new connections +- `start_connection_draining`: Begin per-connection draining +- `get_connections_by_state`: Filter connections by state +- `cleanup_stale_draining_connections`: Force close timeout connections +- `start_draining_http_managers`: Integrate with HTTP connection manager + +##### 3. Enhanced Listener Manager with Drain Support + +**Location**: `orion-lib/src/listeners/listeners_manager.rs` + +```rust +pub struct ListenersManager { + listener_configuration_channel: mpsc::Receiver, + route_configuration_channel: mpsc::Receiver, + listener_handles: MultiMap, + version_counter: u64, + config: ListenerManagerConfig, + connection_manager: Arc, +} + +#[derive(Debug)] +struct ListenerInfo { + handle: abort_on_drop::ChildTask<()>, + listener_conf: ListenerConfig, + version: u64, + state: ListenerState, + connections_count: Arc, + drain_manager_handle: Option>, +} + +#[derive(Debug, Clone)] +enum ListenerState { + Active, + Draining { started_at: Instant, drain_config: ListenerDrainConfig }, +} + +#[derive(Debug, Clone)] +pub struct ListenerDrainConfig { + pub drain_time: Duration, + pub drain_strategy: DrainStrategy, + pub protocol_handling: ProtocolDrainBehavior, +} + +#[derive(Debug, Clone)] +pub enum DrainStrategy { + Gradual, + Immediate, +} +``` + +**Key Operations**: +- `start_listener`: Handle LDS updates with graceful transition using MultiMap +- `start_draining`: Initiate protocol-aware draining for listener version +- `start_drain_monitor`: Monitor drain progress with configurable timeouts +- `drain_old_listeners`: Cleanup policy enforcement for version management +- `resolve_address_conflicts`: Handle address binding conflicts during updates + +##### 4. 
Configuration Integration
+
+**Listener Configuration** (`orion-configuration/src/config/listener.rs`):
+```rust
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct Listener {
+    pub name: CompactString,
+    pub address: ListenerAddress,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub version_info: Option<String>,
+    #[serde(with = "serde_filterchains")]
+    pub filter_chains: HashMap<FilterChainMatch, FilterChain>,
+    #[serde(default)]
+    pub drain_type: DrainType,
+    // ... other fields
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum DrainType {
+    #[default]
+    Default,
+    ModifyOnly,
+}
+```
+
+**HTTP Connection Manager Drain Configuration** (`orion-configuration/src/config/network_filters/http_connection_manager.rs`):
+```rust
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
+pub struct HttpConnectionManager {
+    pub codec_type: CodecType,
+    #[serde(with = "humantime_serde")]
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub request_timeout: Option<Duration>,
+    #[serde(with = "humantime_serde")]
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub drain_timeout: Option<Duration>,
+    // ... other fields
+}
+```
+
+**Listener Manager Configuration**:
+```rust
+#[derive(Debug, Clone)]
+pub struct ListenerManagerConfig {
+    pub max_versions_per_listener: usize,
+    pub cleanup_policy: CleanupPolicy,
+    pub cleanup_interval: Duration,
+    pub drain_config: ListenerDrainConfig,
+}
+
+#[derive(Debug, Clone)]
+pub enum CleanupPolicy {
+    CountBasedOnly(usize),
+}
+```
+
+##### 5. Example Configuration
+
+**Bootstrap Configuration with Drain Settings**:
+```yaml
+static_resources:
+  listeners:
+  - name: "listener_0"
+    address:
+      socket_address:
+        address: "0.0.0.0"
+        port_value: 10000
+    drain_type: DEFAULT  # or MODIFY_ONLY
+    filter_chains:
+    - filters:
+      - name: "http_connection_manager"
+        typed_config:
+          "@type": "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager"
+          stat_prefix: "ingress_http"
+          drain_timeout: "5s"  # HTTP drain timeout
+          route_config:
+            name: "local_route"
+            virtual_hosts:
+            - name: "backend"
+              domains: ["*"]
+              routes:
+              - match: { prefix: "/" }
+                route: { cluster: "backend_service" }
+
+# Listener manager configuration (applied at startup)
+listener_manager:
+  max_versions_per_listener: 2
+  cleanup_policy:
+    count_based_only: 2
+  cleanup_interval: "60s"
+  drain_config:
+    drain_time: "600s"         # Global drain timeout (equivalent to --drain-time-s)
+    drain_strategy: "gradual"  # or "immediate"
+    protocol_handling: "auto"  # auto-detect protocol behavior
+```
+
+#### Implementation Timeline
+
+The implementation is divided into manageable phases aligned with current development:
+
+**Phase 1: Core Drain Infrastructure** ✅ **COMPLETED in PR #104**
+- ✅ MultiMap-based listener version management
+- ✅ Basic drain signaling manager with protocol awareness
+- ✅ Connection state tracking (Active, Draining, Closing, Closed)
+- ✅ Integration with existing listener manager
+- ✅ HTTP connection manager drain integration
+
+**Phase 2: Enhanced Protocol Handlers** 🚧 **IN PROGRESS**
+- 🚧 HTTP/1.1 Connection: close header injection
+- 🚧 HTTP/2 GOAWAY frame implementation
+- 🚧 TCP SO_LINGER graceful shutdown
+- 🚧 Protocol detection from filter chain analysis
+
+**Phase 3: Advanced Timeout Management** 📋 **PLANNED**
+- ⏳ Configurable timeout policies with cascade handling
+- ⏳ Integration with HTTP connection manager drain_timeout field
+- ⏳ Global 
server drain timeout support (--drain-time-s equivalent) +- ⏳ Force close mechanisms with protocol-specific timeouts + +**Phase 4: Production Hardening** 📋 **PLANNED** +- ⏳ Comprehensive error handling and recovery +- ⏳ Metrics and observability integration +- ⏳ Performance optimization for high connection counts +- ⏳ Edge case handling and stress testing + +#### Envoy Compatibility Matrix + +| Feature | Envoy Behavior | Orion Implementation Status | +|---------|----------------|------------------------------| +| LDS Updates | Graceful transition to new listener | ✅ **IMPLEMENTED** - MultiMap-based version management | +| Multiple Listener Versions | Support multiple concurrent versions | ✅ **IMPLEMENTED** - MultiMap with version tracking | +| Drain Type Support | DEFAULT vs MODIFY_ONLY behavior | ✅ **IMPLEMENTED** - Full configuration support | +| Connection State Tracking | Track connection lifecycle | ✅ **IMPLEMENTED** - Active/Draining/Closing/Closed states | +| HTTP Manager Integration | HCM drain timeout field | ✅ **IMPLEMENTED** - HTTP connection manager integration | +| HTTP/1.1 Draining | Connection: close header | 🚧 **IN PROGRESS** - Header injection mechanism | +| HTTP/2 Draining | GOAWAY frame | 🚧 **IN PROGRESS** - RFC 7540 compliant implementation | +| TCP Draining | SO_LINGER timeout | 🚧 **IN PROGRESS** - Socket option configuration | +| Global Timeout | --drain-time-s argument | ⏳ **PLANNED** - Absolute maximum failsafe | +| Protocol Detection | Auto-detect from filter chains | ⏳ **PLANNED** - Filter chain analysis | + +#### Test Plan + +**Unit Tests**: +- Protocol-specific drain handler behavior +- Timeout enforcement and cascade handling +- Connection state tracking accuracy +- Configuration parsing and validation + +**Integration Tests**: +- LDS update scenarios with multiple listener versions +- End-to-end draining flow for each protocol +- Timeout behavior under various load conditions +- Error handling during drain failures + +**Performance Tests**: +- Connection draining under high connection counts +- Memory usage during extended drain periods +- CPU overhead of drain monitoring +- Latency impact on new connections during drain + +**Compatibility Tests**: +- Envoy configuration compatibility +- XDS protocol compliance +- Metrics format compatibility + +### Alternative Designs Considered + +#### 1. Single Version Replacement +**Approach**: Replace listeners immediately without draining +**Pros**: Simple implementation, no resource overhead +**Cons**: Connection interruption, not Envoy-compatible + +#### 2. Custom Drain Protocol +**Approach**: Implement proprietary draining mechanism +**Pros**: Potentially more efficient for specific use cases +**Cons**: Non-standard, compatibility issues, maintenance burden + +#### 3. 
External Drain Controller +**Approach**: Separate service managing drain operations +**Pros**: Decoupled architecture, independent scaling +**Cons**: Added complexity, network overhead, single point of failure + +### Security Considerations + +- **Resource Exhaustion**: Drain timeouts prevent indefinite resource consumption +- **DoS Protection**: Maximum connection limits during drain periods +- **Information Disclosure**: Drain status metrics don't expose sensitive data +- **Access Control**: Drain operations respect existing RBAC policies + +### Observability + +**Metrics**: +- `orion_listener_versions_active`: Number of active listener versions per listener name +- `orion_listener_drain_duration_seconds`: Time spent draining per listener version +- `orion_listener_drain_connections_active`: Active connections during drain per listener +- `orion_listener_drain_timeouts_total`: Count of drain timeout events +- `orion_connection_state_transitions_total`: Connection state change events +- `orion_drain_strategy_usage_total`: Usage count per drain strategy type + +**Logging**: +- Listener version creation and drain initiation events +- Connection state transitions (Active → Draining → Closing → Closed) +- Protocol-specific drain progress and timeout warnings +- Configuration validation errors and conflicts +- Force close events and cleanup operations + +**Tracing**: +- End-to-end LDS update and drain operation spans +- Per-connection lifecycle tracking and state transitions +- Protocol-specific drain handler execution +- Timeout enforcement and decision points + +--- + +## References + +1. [Envoy Draining Documentation](https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/operations/draining) +2. [Envoy HTTP Connection Manager Proto](https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto) +3. [Envoy Listener Proto](https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/listener/v3/listener.proto) +4. [Envoy LDS Protocol](https://www.envoyproxy.io/docs/envoy/latest/configuration/listeners/lds) + +--- + +## Appendix + +### Glossary + +- **Connection Draining**: Process of gracefully closing existing connections while preventing new ones +- **LDS (Listener Discovery Service)**: XDS protocol for dynamic listener configuration updates +- **Drain Timeout**: Maximum time allowed for connections to close gracefully before force termination +- **GOAWAY Frame**: HTTP/2 control frame indicating no new streams should be created +- **SO_LINGER**: Socket option controlling close behavior and timeout +- **MultiMap**: Data structure allowing multiple values per key for listener version management + +### Acknowledgments + +This feature addresses GitHub issues [#98](https://github.com/kmesh-net/orion/issues/98) and [#102](https://github.com/kmesh-net/orion/issues/102), implementing connection draining for the multiple listener versions functionality that was previously merged in PR [#99](https://github.com/kmesh-net/orion/pull/99). The design incorporates valuable feedback from @hzxuzhonghu about Envoy compliance and protocol-specific handling, @YaoZengzeng and @dawid-nowak's architectural guidance, and follows Envoy's draining specification while adapting to Orion's Rust-based architecture. 
+
+The current implementation in PR #104 builds upon the MultiMap-based listener version management from the merged multiple listener versions feature, ensuring seamless integration with existing LDS handling and configuration management systems.
\ No newline at end of file
diff --git a/orion-configuration/src/config/listener.rs b/orion-configuration/src/config/listener.rs
index b7939596..92da56b0 100644
--- a/orion-configuration/src/config/listener.rs
+++ b/orion-configuration/src/config/listener.rs
@@ -38,10 +38,20 @@ use std::{
 
 // Removed unused import
 
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum DrainType {
+    #[default]
+    Default,
+    ModifyOnly,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub struct Listener {
     pub name: CompactString,
     pub address: ListenerAddress,
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub version_info: Option<String>,
     #[serde(with = "serde_filterchains")]
     pub filter_chains: HashMap<FilterChainMatch, FilterChain>,
     #[serde(skip_serializing_if = "Option::is_none", default = "Default::default")]
@@ -54,6 +64,8 @@ pub struct Listener {
     pub with_tlv_listener_filter: bool,
     #[serde(skip_serializing_if = "Option::is_none", default = "Default::default")]
     pub tlv_listener_filter_config: Option,
+    #[serde(default)]
+    pub drain_type: DrainType,
 }
 
 impl Listener {
@@ -333,7 +345,7 @@ mod envoy_conversions {
     use std::hash::{DefaultHasher, Hash, Hasher};
     use std::str::FromStr;
 
-    use super::{FilterChain, FilterChainMatch, Listener, MainFilter, ServerNameMatch, TlsConfig};
+    use super::{DrainType, FilterChain, FilterChainMatch, Listener, MainFilter, ServerNameMatch, TlsConfig};
     use crate::config::{
         common::*,
         core::{Address, CidrRange},
@@ -414,7 +426,7 @@ mod envoy_conversions {
             per_connection_buffer_limit_bytes,
             metadata,
             deprecated_v1,
             drain_type,
             // listener_filters,
             listener_filters_timeout,
             continue_on_listener_filters_timeout,
@@ -500,6 +512,7 @@ mod envoy_conversions {
                     .with_node("socket_options");
             }
             let bind_device = bind_device.into_iter().next();
+            let drain_type = DrainType::try_from(drain_type).unwrap_or_default();
             Ok(Self {
                 name,
                 address,
@@ -509,12 +522,26 @@ mod envoy_conversions {
                 proxy_protocol_config,
                 with_tlv_listener_filter,
                 tlv_listener_filter_config,
+                drain_type,
+                version_info: None,
             })
         }())
         .with_name(name)
     }
 }
 
+impl TryFrom<i32> for DrainType {
+    type Error = GenericError;
+
+    fn try_from(value: i32) -> Result<Self, Self::Error> {
+        match value {
+            0 => Ok(DrainType::Default),
+            1 => Ok(DrainType::ModifyOnly),
+            _ => Err(GenericError::from_msg(format!("Unknown drain type: {}", value))),
+        }
+    }
+}
+
 struct FilterChainWrapper((FilterChainMatch, FilterChain));
 impl TryFrom for FilterChainWrapper {
diff --git a/orion-configuration/src/config/network_filters/http_connection_manager.rs b/orion-configuration/src/config/network_filters/http_connection_manager.rs
index 0ddb4b0b..6f4af2a9 100644
--- a/orion-configuration/src/config/network_filters/http_connection_manager.rs
+++ b/orion-configuration/src/config/network_filters/http_connection_manager.rs
@@ -51,6 +51,9 @@ pub struct HttpConnectionManager {
     #[serde(with = "humantime_serde")]
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub request_timeout: Option<Duration>,
+    #[serde(with = "humantime_serde")]
+    #[serde(skip_serializing_if = "Option::is_none", default)]
+    pub drain_timeout: Option<Duration>,
     #[serde(skip_serializing_if = "Vec::is_empty", default)]
     pub http_filters: Vec<HttpFilter>,
     #[serde(skip_serializing_if = "Vec::is_empty", 
default)] @@ -564,6 +567,61 @@ mod tests { assert!(MatchHostScoreLPM::Wildcard < MatchHostScoreLPM::Suffix("foo.bar.test.com".len())); assert!(MatchHostScoreLPM::Wildcard == MatchHostScoreLPM::Wildcard); } + + #[test] + fn test_drain_timeout_configuration() { + let config = HttpConnectionManager { + codec_type: CodecType::Auto, + route_specifier: RouteSpecifier::RouteConfig(RouteConfiguration { + name: "test_route".into(), + most_specific_header_mutations_wins: false, + response_header_modifier: Default::default(), + request_headers_to_add: vec![], + request_headers_to_remove: vec![], + virtual_hosts: vec![], + }), + http_filters: vec![], + enabled_upgrades: vec![], + access_log: vec![], + xff_settings: Default::default(), + generate_request_id: false, + preserve_external_request_id: false, + always_set_request_id_in_response: false, + tracing: None, + request_timeout: Some(Duration::from_secs(30)), + drain_timeout: Some(Duration::from_secs(10)), + }; + + assert_eq!(config.drain_timeout, Some(Duration::from_secs(10))); + assert_eq!(config.request_timeout, Some(Duration::from_secs(30))); + } + + #[test] + fn test_drain_timeout_default() { + let config = HttpConnectionManager { + codec_type: CodecType::Http1, + route_specifier: RouteSpecifier::RouteConfig(RouteConfiguration { + name: "test_route_default".into(), + most_specific_header_mutations_wins: false, + response_header_modifier: Default::default(), + request_headers_to_add: vec![], + request_headers_to_remove: vec![], + virtual_hosts: vec![], + }), + http_filters: vec![], + enabled_upgrades: vec![], + access_log: vec![], + xff_settings: Default::default(), + generate_request_id: false, + preserve_external_request_id: false, + always_set_request_id_in_response: false, + tracing: None, + request_timeout: None, + drain_timeout: None, + }; + + assert_eq!(config.drain_timeout, None); + } } #[cfg(feature = "envoy-conversions")] @@ -702,7 +760,7 @@ mod envoy_conversions { stream_idle_timeout, // request_timeout, request_headers_timeout, - drain_timeout, + // drain_timeout, delayed_close_timeout, // access_log, access_log_flush_interval, @@ -753,6 +811,11 @@ mod envoy_conversions { .transpose() .map_err(|_| GenericError::from_msg("failed to convert into Duration")) .with_node("request_timeout")?; + let drain_timeout = drain_timeout + .map(duration_from_envoy) + .transpose() + .map_err(|_| GenericError::from_msg("failed to convert into Duration")) + .with_node("drain_timeout")?; let enabled_upgrades = upgrade_configs .iter() .filter(|upgrade_config| upgrade_config.enabled.map(|enabled| enabled.value).unwrap_or(true)) @@ -819,6 +882,7 @@ mod envoy_conversions { enabled_upgrades, route_specifier, request_timeout, + drain_timeout, access_log, xff_settings, generate_request_id: generate_request_id.map(|v| v.value).unwrap_or(true), diff --git a/orion-configuration/tests/test_internal_listener.rs b/orion-configuration/tests/test_internal_listener.rs index c87e8e48..cf5ad567 100644 --- a/orion-configuration/tests/test_internal_listener.rs +++ b/orion-configuration/tests/test_internal_listener.rs @@ -52,6 +52,8 @@ fn test_internal_listener_serialization() { proxy_protocol_config: None, with_tlv_listener_filter: false, tlv_listener_filter_config: None, + drain_type: orion_configuration::config::listener::DrainType::Default, + version_info: None, }; let yaml = serde_yaml::to_string(&listener).unwrap(); @@ -146,6 +148,8 @@ fn test_complete_internal_listener_config() { proxy_protocol_config: None, with_tlv_listener_filter: false, 
tlv_listener_filter_config: None, + drain_type: orion_configuration::config::listener::DrainType::Default, + version_info: None, }; let internal_addr = InternalEndpointAddress { diff --git a/orion-lib/Cargo.toml b/orion-lib/Cargo.toml index 3319e985..df4cbb9d 100644 --- a/orion-lib/Cargo.toml +++ b/orion-lib/Cargo.toml @@ -11,9 +11,11 @@ ahash = "0.8.11" arc-swap = "1.7.1" arrayvec = "0.7.6" async-stream = "0.3" +async-trait = "0.1.77" atomic-time = "0.1.4" bytes.workspace = true compact_str.workspace = true +dashmap = "6.0" enum_dispatch = "0.3.13" exponential-backoff.workspace = true futures.workspace = true diff --git a/orion-lib/src/listeners/drain_signaling.rs b/orion-lib/src/listeners/drain_signaling.rs new file mode 100644 index 00000000..256d09f9 --- /dev/null +++ b/orion-lib/src/listeners/drain_signaling.rs @@ -0,0 +1,547 @@ +// Copyright 2025 The kmesh Authors +// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +use crate::{Error, Result}; +use orion_configuration::config::listener::{DrainType as ConfigDrainType, FilterChain, MainFilter}; +use pingora_timeout::fast_timeout::fast_timeout; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; +use tokio::time::sleep; +use tracing::{debug, info, warn}; + +#[derive(Debug, Clone)] +pub enum ListenerProtocolConfig { + Http { drain_timeout: Option }, + Tcp, + Mixed { http_drain_timeout: Option, has_tcp: bool, has_http: bool }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DrainScenario { + HealthCheckFail, + ListenerUpdate, + HotRestart, +} + +impl DrainScenario { + pub fn should_drain(self, drain_type: ConfigDrainType) -> bool { + match (self, drain_type) { + (_, ConfigDrainType::Default) + | (DrainScenario::ListenerUpdate | DrainScenario::HotRestart, ConfigDrainType::ModifyOnly) => true, + (DrainScenario::HealthCheckFail, ConfigDrainType::ModifyOnly) => false, + } + } +} + +#[derive(Debug, Clone)] +pub enum DrainStrategy { + Tcp { global_timeout: Duration }, + Http { global_timeout: Duration, drain_timeout: Duration }, + Mixed { global_timeout: Duration, http_drain_timeout: Duration, tcp_connections: bool, http_connections: bool }, + Immediate, +} + +#[derive(Debug, Clone)] +pub struct ListenerDrainState { + pub started_at: Instant, + pub strategy: super::listeners_manager::DrainStrategy, + pub protocol_behavior: super::listeners_manager::ProtocolDrainBehavior, + pub drain_scenario: DrainScenario, + pub drain_type: ConfigDrainType, +} + +#[derive(Debug)] +pub struct ListenerDrainContext { + pub listener_id: String, + pub strategy: DrainStrategy, + pub drain_start: Instant, + pub initial_connections: usize, + pub active_connections: Arc>, + pub completed: Arc>, +} + +impl ListenerDrainContext { + pub fn new(listener_id: String, strategy: DrainStrategy, initial_connections: usize) -> Self { + Self { + listener_id, + strategy, + drain_start: Instant::now(), + initial_connections, + active_connections: Arc::new(RwLock::new(initial_connections)), 
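+            // Latched to true by update_connection_count() once the count reaches zero.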
+ completed: Arc::new(RwLock::new(false)), + } + } + + pub async fn update_connection_count(&self, count: usize) { + let mut active = self.active_connections.write().await; + *active = count; + + if count == 0 { + let mut completed = self.completed.write().await; + *completed = true; + debug!("Listener drain completed - all connections closed for {}", self.listener_id); + } + } + + pub async fn is_completed(&self) -> bool { + *self.completed.read().await + } + + pub async fn get_active_connections(&self) -> usize { + *self.active_connections.read().await + } + + pub fn is_timeout_exceeded(&self) -> bool { + let global_timeout = match &self.strategy { + DrainStrategy::Tcp { global_timeout } + | DrainStrategy::Http { global_timeout, .. } + | DrainStrategy::Mixed { global_timeout, .. } => *global_timeout, + DrainStrategy::Immediate => Duration::from_secs(0), + }; + + self.drain_start.elapsed() >= global_timeout + } + + pub fn get_http_drain_timeout(&self) -> Option { + match &self.strategy { + DrainStrategy::Http { drain_timeout, .. } => Some(*drain_timeout), + DrainStrategy::Mixed { http_drain_timeout, .. } => Some(*http_drain_timeout), + _ => None, + } + } +} + +#[derive(Debug)] +pub struct DrainSignalingManager { + drain_contexts: Arc>>>, + global_drain_timeout: Duration, + default_http_drain_timeout: Duration, + listener_drain_state: Arc>>, +} + +impl ListenerProtocolConfig { + pub fn from_listener_analysis( + has_http_connection_manager: bool, + has_tcp_proxy: bool, + http_drain_timeout: Option, + ) -> Self { + match (has_http_connection_manager, has_tcp_proxy) { + (true, true) => Self::Mixed { http_drain_timeout, has_tcp: true, has_http: true }, + (true, false) => Self::Http { drain_timeout: http_drain_timeout }, + (false, true) => Self::Tcp, + (false, false) => { + warn!("No HTTP connection manager or TCP proxy found in listener, defaulting to TCP draining"); + Self::Tcp + }, + } + } +} + +impl DrainSignalingManager { + pub fn new() -> Self { + Self { + drain_contexts: Arc::new(RwLock::new(HashMap::new())), + global_drain_timeout: Duration::from_secs(600), + default_http_drain_timeout: Duration::from_secs(5), + listener_drain_state: Arc::new(RwLock::new(None)), + } + } + + pub fn with_timeouts(global_drain_timeout: Duration, default_http_drain_timeout: Duration) -> Self { + Self { + drain_contexts: Arc::new(RwLock::new(HashMap::new())), + global_drain_timeout, + default_http_drain_timeout, + listener_drain_state: Arc::new(RwLock::new(None)), + } + } + + pub async fn start_listener_draining(&self, drain_state: ListenerDrainState) { + if !drain_state.drain_scenario.should_drain(drain_state.drain_type) { + debug!( + "Skipping drain for scenario {:?} with drain_type {:?}", + drain_state.drain_scenario, drain_state.drain_type + ); + return; + } + + info!("Starting listener-wide draining with strategy {:?}", drain_state.strategy); + let mut state = self.listener_drain_state.write().await; + *state = Some(drain_state); + } + + pub async fn stop_listener_draining(&self) { + info!("Stopping listener-wide draining"); + let mut state = self.listener_drain_state.write().await; + *state = None; + + let mut contexts = self.drain_contexts.write().await; + contexts.clear(); + } + + pub async fn is_listener_draining(&self) -> bool { + self.listener_drain_state.read().await.is_some() + } + + pub async fn apply_http1_drain_signal(&self, response: &mut hyper::Response) { + if let Some(drain_state) = &*self.listener_drain_state.read().await { + match &drain_state.protocol_behavior { + 
super::listeners_manager::ProtocolDrainBehavior::Http1 { connection_close: true } + | super::listeners_manager::ProtocolDrainBehavior::Auto => { + use hyper::header::{HeaderValue, CONNECTION}; + response.headers_mut().insert(CONNECTION, HeaderValue::from_static("close")); + debug!("Applied 'Connection: close' header for HTTP/1.1 drain signaling"); + }, + _ => { + debug!("Skipping Connection: close header for non-HTTP/1.1 protocol"); + }, + } + } + } + + pub fn apply_http1_drain_signal_sync(response: &mut hyper::Response) -> bool { + use hyper::header::{HeaderValue, CONNECTION}; + response.headers_mut().insert(CONNECTION, HeaderValue::from_static("close")); + debug!("Applied 'Connection: close' header for HTTP/1.1 drain signaling"); + true + } + + pub async fn get_http2_drain_timeout(&self, listener_id: &str) -> Option { + let contexts = self.drain_contexts.read().await; + if let Some(context) = contexts.get(listener_id) { + context.get_http_drain_timeout() + } else { + Some(self.default_http_drain_timeout) + } + } + + pub async fn initiate_listener_drain( + &self, + listener_id: String, + protocol_config: ListenerProtocolConfig, + active_connections: usize, + ) -> Result> { + let strategy = match protocol_config { + ListenerProtocolConfig::Http { drain_timeout } => DrainStrategy::Http { + global_timeout: self.global_drain_timeout, + drain_timeout: drain_timeout.unwrap_or(self.default_http_drain_timeout), + }, + ListenerProtocolConfig::Tcp => DrainStrategy::Tcp { global_timeout: self.global_drain_timeout }, + ListenerProtocolConfig::Mixed { http_drain_timeout, has_tcp, has_http } => DrainStrategy::Mixed { + global_timeout: self.global_drain_timeout, + http_drain_timeout: http_drain_timeout.unwrap_or(self.default_http_drain_timeout), + tcp_connections: has_tcp, + http_connections: has_http, + }, + }; + + let context = Arc::new(ListenerDrainContext::new(listener_id.clone(), strategy.clone(), active_connections)); + + { + let mut contexts = self.drain_contexts.write().await; + contexts.insert(listener_id.clone(), context.clone()); + } + + info!( + "Initiated listener draining for {}, strategy: {:?}, active_connections: {}", + listener_id, strategy, active_connections + ); + + let context_clone = context.clone(); + let manager_clone = self.clone(); + let listener_id_clone = listener_id.clone(); + tokio::spawn(async move { + let () = manager_clone.monitor_drain_progress(context_clone, listener_id_clone).await; + }); + + Ok(context) + } + + async fn monitor_drain_progress(&self, context: Arc, listener_id: String) { + let check_interval = Duration::from_secs(1); + + loop { + sleep(check_interval).await; + + if context.is_completed().await { + self.complete_drain(listener_id.clone()).await; + return; + } + + if context.is_timeout_exceeded() { + let elapsed = context.drain_start.elapsed(); + let active_connections = context.get_active_connections().await; + warn!( + "Global drain timeout exceeded for listener {}, elapsed: {:?}, active_connections: {}", + listener_id, elapsed, active_connections + ); + self.force_complete_drain(listener_id.clone()).await; + return; + } + + let elapsed = context.drain_start.elapsed(); + let active_connections = context.get_active_connections().await; + debug!( + "Drain progress check for listener {}, elapsed: {:?}, active_connections: {}", + listener_id, elapsed, active_connections + ); + } + } + + async fn complete_drain(&self, listener_id: String) { + let mut contexts = self.drain_contexts.write().await; + if let Some(context) = contexts.remove(&listener_id) { + let 
duration = context.drain_start.elapsed(); + info!("Listener drain completed successfully for {}, duration: {:?}", listener_id, duration); + } + } + + async fn force_complete_drain(&self, listener_id: String) { + let mut contexts = self.drain_contexts.write().await; + if let Some(context) = contexts.remove(&listener_id) { + let mut completed = context.completed.write().await; + *completed = true; + let duration = context.drain_start.elapsed(); + warn!("Listener drain force completed due to timeout for {}, duration: {:?}", listener_id, duration); + } + } + + pub async fn get_drain_context(&self, listener_id: &str) -> Option> { + let contexts = self.drain_contexts.read().await; + contexts.get(listener_id).cloned() + } + + pub async fn has_draining_listeners(&self) -> bool { + let contexts = self.drain_contexts.read().await; + !contexts.is_empty() + } + + pub async fn get_draining_listeners(&self) -> Vec { + let contexts = self.drain_contexts.read().await; + contexts.keys().cloned().collect() + } + + pub async fn wait_for_drain_completion(&self, timeout_duration: Duration) -> Result<()> { + let result = fast_timeout(timeout_duration, async { + loop { + if !self.has_draining_listeners().await { + break; + } + sleep(Duration::from_millis(100)).await; + } + }) + .await; + + if let Ok(()) = result { + info!("All listener draining completed successfully"); + Ok(()) + } else { + let draining = self.get_draining_listeners().await; + warn!("Timeout waiting for drain completion, draining_listeners: {:?}", draining); + Err(Error::new("Timeout waiting for listener drain completion")) + } + } + + pub async fn initiate_listener_drain_from_filter_analysis( + &self, + listener_id: String, + filter_chains: &[FilterChain], + active_connections: usize, + ) -> Result> { + let mut has_http = false; + let mut has_tcp = false; + let mut http_drain_timeout: Option = None; + + for filter_chain in filter_chains { + match &filter_chain.terminal_filter { + MainFilter::Http(http_config) => { + has_http = true; + http_drain_timeout = http_config.drain_timeout; + }, + MainFilter::Tcp(_) => { + has_tcp = true; + }, + } + } + + let protocol_config = ListenerProtocolConfig::from_listener_analysis(has_http, has_tcp, http_drain_timeout); + + self.initiate_listener_drain(listener_id, protocol_config, active_connections).await + } +} + +impl Clone for DrainSignalingManager { + fn clone(&self) -> Self { + Self { + drain_contexts: self.drain_contexts.clone(), + global_drain_timeout: self.global_drain_timeout, + default_http_drain_timeout: self.default_http_drain_timeout, + listener_drain_state: self.listener_drain_state.clone(), + } + } +} + +impl Default for DrainSignalingManager { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +pub struct DefaultConnectionHandler {} + +impl DefaultConnectionHandler { + pub fn new() -> Self { + Self {} + } + + pub fn register_connection( + _connection_id: String, + _protocol: super::listeners_manager::ConnectionProtocol, + _peer_addr: std::net::SocketAddr, + ) { + } +} + +impl Default for DefaultConnectionHandler { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::time::sleep; + + #[tokio::test] + async fn test_tcp_drain_context() { + let strategy = DrainStrategy::Tcp { global_timeout: Duration::from_secs(1) }; + let context = ListenerDrainContext::new("test-tcp".to_string(), strategy, 5); + + assert_eq!(context.get_active_connections().await, 5); + assert!(!context.is_completed().await); + + 
context.update_connection_count(0).await; + assert!(context.is_completed().await); + } + + #[tokio::test] + async fn test_http_drain_context() { + let strategy = DrainStrategy::Http { + global_timeout: Duration::from_secs(600), + drain_timeout: Duration::from_millis(5000), + }; + let context = ListenerDrainContext::new("test-http".to_string(), strategy, 3); + + assert_eq!(context.get_active_connections().await, 3); + assert!(!context.is_completed().await); + assert_eq!(context.get_http_drain_timeout(), Some(Duration::from_millis(5000))); + assert!(!context.is_timeout_exceeded()); + } + + #[tokio::test] + async fn test_drain_manager_basic() { + let manager = DrainSignalingManager::new(); + assert!(!manager.has_draining_listeners().await); + + let context = + manager.initiate_listener_drain("test".to_string(), ListenerProtocolConfig::Tcp, 1).await.unwrap(); + + assert!(manager.has_draining_listeners().await); + assert_eq!(manager.get_draining_listeners().await, vec!["test"]); + + context.update_connection_count(0).await; + + sleep(Duration::from_millis(100)).await; + } + + #[tokio::test] + async fn test_timeout_behavior() { + let manager = DrainSignalingManager::with_timeouts(Duration::from_millis(50), Duration::from_millis(25)); + + let context = manager + .initiate_listener_drain( + "timeout-test".to_string(), + ListenerProtocolConfig::Http { drain_timeout: None }, + 5, + ) + .await + .unwrap(); + + sleep(Duration::from_millis(10)).await; + sleep(Duration::from_millis(60)).await; + assert!(context.is_timeout_exceeded()); + + let mut attempts = 0; + while attempts < 20 && !context.is_completed().await { + sleep(Duration::from_millis(100)).await; + attempts += 1; + } + + assert!(context.is_completed().await, "Expected context to be completed after timeout"); + assert!( + !manager.has_draining_listeners().await, + "Expected manager to no longer track the listener after timeout" + ); + } + + #[tokio::test] + async fn test_mixed_protocol_drain_context() { + let strategy = DrainStrategy::Mixed { + global_timeout: Duration::from_secs(600), + http_drain_timeout: Duration::from_secs(5), + tcp_connections: true, + http_connections: true, + }; + let context = ListenerDrainContext::new("test-mixed".to_string(), strategy, 10); + + assert_eq!(context.get_active_connections().await, 10); + assert!(!context.is_completed().await); + assert_eq!(context.get_http_drain_timeout(), Some(Duration::from_secs(5))); + assert!(!context.is_timeout_exceeded()); + } + + #[tokio::test] + async fn test_listener_protocol_config_analysis() { + let http_config = ListenerProtocolConfig::from_listener_analysis(true, false, Some(Duration::from_secs(10))); + match http_config { + ListenerProtocolConfig::Http { drain_timeout } => { + assert_eq!(drain_timeout, Some(Duration::from_secs(10))); + }, + _ => panic!("Expected HTTP config"), + } + + let tcp_config = ListenerProtocolConfig::from_listener_analysis(false, true, None); + match tcp_config { + ListenerProtocolConfig::Tcp => {}, + _ => panic!("Expected TCP config"), + } + + let mixed_config = ListenerProtocolConfig::from_listener_analysis(true, true, Some(Duration::from_secs(3))); + match mixed_config { + ListenerProtocolConfig::Mixed { http_drain_timeout, has_tcp, has_http } => { + assert_eq!(http_drain_timeout, Some(Duration::from_secs(3))); + assert!(has_tcp); + assert!(has_http); + }, + _ => panic!("Expected Mixed config"), + } + } +} diff --git a/orion-lib/src/listeners/filterchain.rs b/orion-lib/src/listeners/filterchain.rs index d23eb248..0373d285 100644 --- 
a/orion-lib/src/listeners/filterchain.rs +++ b/orion-lib/src/listeners/filterchain.rs @@ -235,8 +235,10 @@ impl FilterchainType { .serve_connection_with_upgrades( stream, hyper::service::service_fn(|req: Request| { - let handler_req = - ExtendedRequest { request: req, downstream_metadata: downstream_metadata.clone() }; + let handler_req = ExtendedRequest { + request: req, + downstream_metadata: Arc::new(downstream_metadata.connection.clone()), + }; req_handler.call(handler_req).map_err(orion_error::Error::into_inner) }), ) diff --git a/orion-lib/src/listeners/http_connection_manager.rs b/orion-lib/src/listeners/http_connection_manager.rs index 21f3a2c7..49ae5f02 100644 --- a/orion-lib/src/listeners/http_connection_manager.rs +++ b/orion-lib/src/listeners/http_connection_manager.rs @@ -42,7 +42,8 @@ use orion_configuration::config::GenericError; use orion_format::types::ResponseFlags as FmtResponseFlags; use orion_tracing::span_state::SpanState; use orion_tracing::{attributes::HTTP_RESPONSE_STATUS_CODE, with_client_span, with_server_span}; -use std::sync::atomic::AtomicUsize; +use std::collections::HashMap; +use tracing::{debug, info}; use orion_configuration::config::network_filters::http_connection_manager::http_filters::{ FilterConfigOverride, FilterOverride, @@ -59,6 +60,7 @@ use orion_configuration::config::network_filters::{ RdsSpecifier, RouteSpecifier, UpgradeType, }, }; +use orion_configuration::config::TlvType; use orion_format::context::{ DownstreamResponse, FinishContext, HttpRequestDuration, HttpResponseDuration, InitHttpContext, }; @@ -68,13 +70,18 @@ use orion_metrics::{metrics::http, with_metric}; use parking_lot::Mutex; use route::MatchedRequest; use scopeguard::defer; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; +use std::sync::atomic::AtomicUsize; use std::thread::ThreadId; use std::time::Instant; -use std::{fmt, future::Future, result::Result as StdResult, sync::Arc}; +use std::{ + fmt, + future::Future, + result::Result as StdResult, + sync::{Arc, LazyLock}, +}; use tokio::sync::mpsc::Permit; use tokio::sync::watch; -use tracing::debug; use upgrades as upgrade_utils; use crate::event_error::{EventKind, UpstreamTransportEventError}; @@ -89,7 +96,11 @@ use crate::{ use crate::{ body::body_with_timeout::BodyWithTimeout, listeners::{ - access_log::AccessLogContext, filter_state::DownstreamMetadata, rate_limiter::LocalRateLimit, + access_log::AccessLogContext, + drain_signaling::DrainSignalingManager, + filter_state::{DownstreamConnectionMetadata, DownstreamMetadata}, + listeners_manager::ConnectionManager, + rate_limiter::LocalRateLimit, synthetic_http_response::SyntheticHttpResponse, }, utils::http::{request_head_size, response_head_size}, @@ -99,6 +110,9 @@ use orion_tracing::http_tracer::{HttpTracer, SpanKind, SpanName}; use orion_tracing::request_id::{RequestId, RequestIdManager}; use orion_tracing::trace_context::TraceContext; +static EMPTY_HASHMAP: LazyLock>>>> = + LazyLock::new(|| Arc::new(HashMap::new())); + #[derive(Debug, Clone)] pub struct HttpConnectionManagerBuilder { listener_name: Option<&'static str>, @@ -131,6 +145,7 @@ impl HttpConnectionManagerBuilder { http_filters_per_route: ArcSwap::new(Arc::new(partial.http_filters_per_route)), enabled_upgrades: partial.enabled_upgrades, request_timeout: partial.request_timeout, + drain_timeout: partial.drain_timeout, access_log: partial.access_log, xff_settings: partial.xff_settings, request_id_handler: RequestIdManager::new( @@ -142,6 +157,8 @@ impl HttpConnectionManagerBuilder { 
Some(tracing) => HttpTracer::new().with_config(tracing), None => HttpTracer::new(), }, + drain_signaling: Arc::new(DrainSignalingManager::new()), + connection_manager: None, // Will be set during listener startup }) } @@ -163,6 +180,7 @@ pub struct PartialHttpConnectionManager { http_filters_per_route: HashMap>>, enabled_upgrades: Vec, request_timeout: Option, + drain_timeout: Option, access_log: Vec, xff_settings: XffSettings, generate_request_id: bool, @@ -257,6 +275,7 @@ impl TryFrom> for PartialHttp .map(|f| Arc::new(HttpFilter::from(f))) .collect::>>(); let request_timeout = configuration.request_timeout; + let drain_timeout = configuration.drain_timeout; let access_log = configuration.access_log; let xff_settings = configuration.xff_settings; let generate_request_id = configuration.generate_request_id; @@ -285,6 +304,7 @@ impl TryFrom> for PartialHttp http_filters_per_route, enabled_upgrades, request_timeout, + drain_timeout, access_log, xff_settings, generate_request_id, @@ -320,7 +340,6 @@ impl AlpnCodecs { } } -#[derive(Debug)] pub struct HttpConnectionManager { pub listener_name: &'static str, pub filter_chain_match_hash: u64, @@ -331,10 +350,47 @@ pub struct HttpConnectionManager { http_filters_per_route: ArcSwap>>>, enabled_upgrades: Vec, request_timeout: Option, + drain_timeout: Option, access_log: Vec, xff_settings: XffSettings, request_id_handler: RequestIdManager, pub http_tracer: HttpTracer, + drain_signaling: Arc, + connection_manager: Option>, +} + +impl Clone for HttpConnectionManager { + fn clone(&self) -> Self { + Self { + listener_name: self.listener_name, + filter_chain_match_hash: self.filter_chain_match_hash, + router_sender: self.router_sender.clone(), + codec_type: self.codec_type, + dynamic_route_name: self.dynamic_route_name.clone(), + http_filters_hcm: self.http_filters_hcm.clone(), + http_filters_per_route: ArcSwap::new(self.http_filters_per_route.load_full()), + enabled_upgrades: self.enabled_upgrades.clone(), + request_timeout: self.request_timeout, + drain_timeout: self.drain_timeout, + access_log: self.access_log.clone(), + xff_settings: self.xff_settings.clone(), + request_id_handler: self.request_id_handler.clone(), + http_tracer: self.http_tracer.clone(), + drain_signaling: self.drain_signaling.clone(), + connection_manager: self.connection_manager.clone(), + } + } +} + +impl std::fmt::Debug for HttpConnectionManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HttpConnectionManager") + .field("listener_name", &self.listener_name) + .field("filter_chain_match_hash", &self.filter_chain_match_hash) + .field("codec_type", &self.codec_type) + .field("connection_manager", &"") + .finish() + } } impl fmt::Display for HttpConnectionManager { @@ -360,9 +416,184 @@ impl HttpConnectionManager { } pub fn remove_route(&self) { + self.http_filters_per_route.swap(EMPTY_HASHMAP.clone()); let _ = self.router_sender.send_replace(None); } + pub fn get_drain_signaling(&self) -> Arc { + Arc::clone(&self.drain_signaling) + } + + pub fn get_drain_timeout(&self) -> Option { + self.drain_timeout + } + + pub async fn start_draining(&self, drain_state: crate::listeners::drain_signaling::ListenerDrainState) { + let listener_id = format!("{}-{}", self.listener_name, self.filter_chain_match_hash); + let protocol_config = + crate::listeners::drain_signaling::ListenerProtocolConfig::Http { drain_timeout: self.drain_timeout }; + let _ = self.drain_signaling.initiate_listener_drain(listener_id, protocol_config, 0).await; + + 
self.drain_signaling.start_listener_draining(drain_state).await; + } + + pub async fn stop_draining(&self) { + self.drain_signaling.stop_listener_draining().await; + } + + pub fn set_connection_manager(&mut self, connection_manager: Arc) { + self.connection_manager = Some(connection_manager); + } + + pub fn on_connection_established( + &self, + connection_id: String, + protocol: crate::listeners::listeners_manager::ConnectionProtocol, + ) { + use crate::listeners::listeners_manager::{ConnectionInfo, ConnectionState}; + use std::time::Instant; + + let conn_info = ConnectionInfo { + id: connection_id.clone(), + protocol: protocol.clone(), + established_at: Instant::now(), + last_activity: Instant::now(), + state: ConnectionState::Active, + }; + + info!("HTTP connection {} established with protocol {:?}", connection_id, protocol); + + if let Some(ref connection_manager) = self.connection_manager { + connection_manager.on_connection_established(self.listener_name, conn_info); + } + } + + pub fn on_connection_closed(&self, connection_id: &str) { + info!("HTTP connection {} closed", connection_id); + + if let Some(ref connection_manager) = self.connection_manager { + connection_manager.on_connection_closed(self.listener_name, connection_id); + } + } + + pub async fn apply_response_drain_signaling(&self, response: &mut hyper::Response) { + if self.drain_signaling.is_listener_draining().await { + self.drain_signaling.apply_http1_drain_signal(response).await; + } + } + + pub fn apply_response_drain_signaling_sync(response: &mut hyper::Response) -> bool { + DrainSignalingManager::apply_http1_drain_signal_sync(response) + } + + pub fn extract_connection_id(downstream_metadata: &DownstreamConnectionMetadata) -> String { + format!("{}:{}", downstream_metadata.local_address(), downstream_metadata.peer_address()) + } + + pub fn extract_connection_protocol( + downstream_metadata: &DownstreamConnectionMetadata, + ) -> crate::listeners::listeners_manager::ConnectionProtocol { + match downstream_metadata { + DownstreamConnectionMetadata::FromProxyProtocol { protocol, tlv_data, .. } => { + Self::detect_from_proxy_protocol(protocol, tlv_data) + }, + DownstreamConnectionMetadata::FromTlv { tlv_data, .. } => Self::detect_from_tlv_data(tlv_data), + DownstreamConnectionMetadata::FromSocket { local_address, .. 
} => { + Self::detect_from_port(local_address.port()) + }, + } + } + + fn detect_from_proxy_protocol( + protocol: &ppp::v2::Protocol, + tlv_data: &HashMap>, + ) -> crate::listeners::listeners_manager::ConnectionProtocol { + use crate::listeners::listeners_manager::ConnectionProtocol; + + if let Some(alpn_data) = tlv_data.get(&TlvType::Custom(0x01)) { + return Self::parse_alpn_protocol(alpn_data); + } + + if tlv_data.contains_key(&TlvType::Custom(0x20)) { + return ConnectionProtocol::Http2; + } + + match protocol { + ppp::v2::Protocol::Stream => ConnectionProtocol::Http1, + _ => ConnectionProtocol::Unknown, + } + } + + fn detect_from_tlv_data( + tlv_data: &HashMap>, + ) -> crate::listeners::listeners_manager::ConnectionProtocol { + use crate::listeners::listeners_manager::ConnectionProtocol; + + if let Some(alpn_data) = tlv_data.get(&0x01) { + return Self::parse_alpn_protocol(alpn_data); + } + + if tlv_data.contains_key(&0x20) { + return ConnectionProtocol::Http2; + } + + ConnectionProtocol::Http1 + } + + fn parse_alpn_protocol(alpn_data: &[u8]) -> crate::listeners::listeners_manager::ConnectionProtocol { + use crate::listeners::listeners_manager::ConnectionProtocol; + + if alpn_data.is_empty() { + return ConnectionProtocol::Http1; + } + + let mut offset = 0; + while offset < alpn_data.len() { + if offset + 1 > alpn_data.len() { + break; + } + + let proto_len = alpn_data[offset] as usize; + offset += 1; + + if offset + proto_len > alpn_data.len() { + break; + } + + let protocol = &alpn_data[offset..offset + proto_len]; + match protocol { + b"h2" => return ConnectionProtocol::Http2, + b"http/1.1" => return ConnectionProtocol::Http1, + b"http/1.0" => return ConnectionProtocol::Http1, + _ => { + debug!("Unknown ALPN protocol: {:?}", String::from_utf8_lossy(protocol)); + }, + } + + offset += proto_len; + } + + ConnectionProtocol::Http1 + } + + fn detect_from_port(port: u16) -> crate::listeners::listeners_manager::ConnectionProtocol { + use crate::listeners::listeners_manager::ConnectionProtocol; + + match port { + 443 | 8443 => ConnectionProtocol::Http2, + 80 | 8080 | 8000 => ConnectionProtocol::Http1, + _ => ConnectionProtocol::Http1, + } + } + + pub async fn should_apply_drain_signaling(&self, _request: &hyper::Request) -> bool { + self.drain_signaling.is_listener_draining().await + } + + pub async fn is_draining(&self) -> bool { + self.drain_signaling.is_listener_draining().await + } + pub(crate) fn request_handler( self: &Arc, ) -> Box< @@ -408,7 +639,7 @@ pub(crate) struct HttpRequestHandler { pub struct ExtendedRequest { pub request: Request, - pub downstream_metadata: Arc, + pub downstream_metadata: Arc, } #[derive(Debug)] @@ -837,6 +1068,16 @@ impl self.response_header_modifier.modify(resp_headers); } + match connection_manager.codec_type { + CodecType::Http1 => { + connection_manager.drain_signaling.apply_http1_drain_signal(&mut response).await; + }, + CodecType::Http2 => {}, + CodecType::Auto => { + connection_manager.drain_signaling.apply_http1_drain_signal(&mut response).await; + }, + } + Ok(response) } else { // We should not be here @@ -926,6 +1167,12 @@ impl Service> for HttpRequestHandler { let listener_name_for_trace = listener_name; Box::pin(async move { let ExtendedRequest { request, downstream_metadata } = req; + + let connection_id = HttpConnectionManager::extract_connection_id(&downstream_metadata); + let protocol = HttpConnectionManager::extract_connection_protocol(&downstream_metadata); + + manager.on_connection_established(connection_id.clone(), protocol); + let (parts, 
body) = request.into_parts(); let request = Request::from_parts(parts, BodyWithTimeout::new(req_timeout, body)); let permit = log_access_reserve_balanced().await; @@ -940,7 +1187,7 @@ impl Service> for HttpRequestHandler { // // 1. evaluate InitHttpContext, if logging is enabled - eval_http_init_context(&request, &trans_handler, downstream_metadata.server_name.as_deref()); + eval_http_init_context(&request, &trans_handler, None); // // 2. create the MetricsBody, which will track the size of the request body @@ -1098,9 +1345,12 @@ impl Service> for HttpRequestHandler { return Ok(response); }; + let downstream_metadata_with_server_name = + Arc::new(DownstreamMetadata::new(downstream_metadata.as_ref().clone(), None::)); + let response = trans_handler .clone() - .handle_transaction(route_conf, manager, permit, request, downstream_metadata) + .handle_transaction(route_conf, manager, permit, request, downstream_metadata_with_server_name) .await; trans_handler.trace_status_code(response, listener_name_for_trace) @@ -1175,7 +1425,13 @@ fn apply_authorization_rules(rbac: &HttpRbac, req: &Request) -> FilterDeci #[cfg(test)] mod tests { + use crate::listeners::{ + drain_signaling::{DrainScenario, DrainSignalingManager, ListenerDrainState}, + listeners_manager::{DrainStrategy, ProtocolDrainBehavior}, + }; use orion_configuration::config::network_filters::http_connection_manager::MatchHost; + use std::time::Instant; + use tracing_test::traced_test; use super::*; @@ -1214,4 +1470,98 @@ mod tests { let request = Request::builder().header("host", "domain2.com").body(()).unwrap(); assert_eq!(select_virtual_host(&request, &[vh1.clone(), vh2.clone(), vh3.clone()]), None); } + + #[traced_test] + #[tokio::test] + async fn test_drain_signaling_integration() { + let drain_signaling = Arc::new(DrainSignalingManager::new()); + + assert!(!drain_signaling.is_listener_draining().await); + + let drain_state = ListenerDrainState { + started_at: Instant::now(), + strategy: DrainStrategy::Immediate, + protocol_behavior: ProtocolDrainBehavior::Http1 { connection_close: true }, + drain_scenario: DrainScenario::ListenerUpdate, + drain_type: orion_configuration::config::listener::DrainType::Default, + }; + drain_signaling.start_listener_draining(drain_state).await; + assert!(drain_signaling.is_listener_draining().await); + + let mut response = Response::builder().status(200).body("response body").unwrap(); + + drain_signaling.apply_http1_drain_signal(&mut response).await; + assert_eq!(response.headers().get("connection").unwrap(), "close"); + } + + #[traced_test] + #[tokio::test] + async fn test_http1_drain_signal_application() { + let drain_signaling = Arc::new(DrainSignalingManager::new()); + + let drain_state = ListenerDrainState { + started_at: Instant::now(), + strategy: DrainStrategy::Gradual, + protocol_behavior: ProtocolDrainBehavior::Http1 { connection_close: true }, + drain_scenario: DrainScenario::ListenerUpdate, + drain_type: orion_configuration::config::listener::DrainType::Default, + }; + drain_signaling.start_listener_draining(drain_state).await; + + let mut response = Response::builder().status(200).body("response body").unwrap(); + + drain_signaling.apply_http1_drain_signal(&mut response).await; + assert_eq!(response.headers().get("connection").unwrap(), "close"); + } + + #[traced_test] + #[tokio::test] + async fn test_auto_drain_behavior() { + let drain_signaling = Arc::new(DrainSignalingManager::new()); + + let drain_state = ListenerDrainState { + started_at: Instant::now(), + strategy: 
DrainStrategy::Gradual, + protocol_behavior: ProtocolDrainBehavior::Auto, + drain_scenario: DrainScenario::ListenerUpdate, + drain_type: orion_configuration::config::listener::DrainType::Default, + }; + drain_signaling.start_listener_draining(drain_state).await; + + let mut response = Response::builder().status(200).body("response body").unwrap(); + + drain_signaling.apply_http1_drain_signal(&mut response).await; + assert_eq!(response.headers().get("connection").unwrap(), "close"); + } + + #[traced_test] + #[tokio::test] + async fn test_no_drain_signal_when_not_draining() { + let drain_signaling = Arc::new(DrainSignalingManager::new()); + + let mut response = Response::builder().status(200).body("response body").unwrap(); + + drain_signaling.apply_http1_drain_signal(&mut response).await; + assert!(!response.headers().contains_key("connection")); + } + + #[traced_test] + #[tokio::test] + async fn test_http2_drain_behavior() { + let drain_signaling = Arc::new(DrainSignalingManager::new()); + + let drain_state = ListenerDrainState { + started_at: Instant::now(), + strategy: DrainStrategy::Gradual, + protocol_behavior: ProtocolDrainBehavior::Http2 { send_goaway: true }, + drain_scenario: DrainScenario::ListenerUpdate, + drain_type: orion_configuration::config::listener::DrainType::Default, + }; + drain_signaling.start_listener_draining(drain_state).await; + + let mut response = Response::builder().status(200).body("response body").unwrap(); + + drain_signaling.apply_http1_drain_signal(&mut response).await; + assert!(!response.headers().contains_key("connection")); + } } diff --git a/orion-lib/src/listeners/lds_update.rs b/orion-lib/src/listeners/lds_update.rs new file mode 100644 index 00000000..7fdceaf4 --- /dev/null +++ b/orion-lib/src/listeners/lds_update.rs @@ -0,0 +1,232 @@ +// Copyright 2025 The kmesh Authors +// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// + +use multimap::MultiMap; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::RwLock; +use tracing::{info, warn}; + +use crate::listeners::listener::Listener; +use orion_configuration::config::listener::ListenerAddress; +use orion_configuration::config::Listener as ListenerConfig; + +#[derive(Debug, Clone, PartialEq)] +pub enum ListenerState { + Active, + Draining { started_at: std::time::Instant }, +} + +#[derive(Debug)] +pub struct LdsListenerInfo { + pub handle: tokio::task::JoinHandle, + pub config: ListenerConfig, + pub state: ListenerState, + pub created_at: std::time::Instant, +} + +impl LdsListenerInfo { + pub fn new(handle: tokio::task::JoinHandle, config: ListenerConfig) -> Self { + Self { handle, config, state: ListenerState::Active, created_at: std::time::Instant::now() } + } + + pub fn start_draining(&mut self) { + self.state = ListenerState::Draining { started_at: std::time::Instant::now() }; + } + + pub fn is_draining(&self) -> bool { + matches!(self.state, ListenerState::Draining { .. 
})
+    }
+}
+
+pub struct LdsManager {
+    listeners: Arc<RwLock<MultiMap<String, LdsListenerInfo>>>,
+    drain_timeout: Duration,
+}
+
+impl LdsManager {
+    pub fn new() -> Self {
+        Self { listeners: Arc::new(RwLock::new(MultiMap::new())), drain_timeout: Duration::from_secs(600) }
+    }
+
+    pub async fn handle_lds_update(
+        &self,
+        listener: Listener,
+        config: ListenerConfig,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        let listener_name = config.name.to_string();
+        let mut listeners = self.listeners.write().await;
+
+        if let Some(existing_versions) = listeners.get_vec_mut(&listener_name) {
+            info!("LDS: Updating existing listener '{}' with {} versions", listener_name, existing_versions.len());
+
+            for existing in existing_versions {
+                if !existing.is_draining() {
+                    existing.start_draining();
+                    info!("LDS: Old version of listener '{}' placed in draining state", listener_name);
+                }
+            }
+
+            self.start_drain_timeout_for_existing(&listener_name);
+        } else {
+            info!("LDS: Adding new listener '{}'", listener_name);
+        }
+
+        let handle = tokio::spawn(async move { listener.start().await });
+
+        let new_listener_info = LdsListenerInfo::new(handle, config);
+        listeners.insert(listener_name.clone(), new_listener_info);
+
+        info!("LDS: Listener '{}' successfully updated", listener_name);
+        Ok(())
+    }
+
+    pub async fn remove_listener(&self, listener_name: &str) -> Result<(), Box<dyn std::error::Error>> {
+        let mut listeners = self.listeners.write().await;
+
+        if let Some(versions) = listeners.get_vec_mut(listener_name) {
+            info!("LDS: Removing listener '{}' with {} versions", listener_name, versions.len());
+
+            for listener_info in versions {
+                if !listener_info.is_draining() {
+                    listener_info.start_draining();
+                    info!("LDS: Version of listener '{}' placed in draining state for removal", listener_name);
+                }
+            }
+
+            self.start_drain_timeout_for_removal(listener_name);
+
+            Ok(())
+        } else {
+            warn!("LDS: Attempted to remove non-existent listener '{}'", listener_name);
+            Ok(())
+        }
+    }
+
+    fn start_drain_timeout_for_existing(&self, listener_name: &str) {
+        let listeners = self.listeners.clone();
+        let timeout = self.drain_timeout;
+        let name = listener_name.to_owned();
+
+        tokio::spawn(async move {
+            tokio::time::sleep(timeout).await;
+
+            let mut listeners_guard = listeners.write().await;
+            if let Some(versions) = listeners_guard.get_vec_mut(&name) {
+                versions.iter_mut().filter(|listener_info| listener_info.is_draining()).for_each(|listener_info| {
+                    listener_info.handle.abort();
+                    info!("LDS: Draining version of listener '{}' forcibly closed after timeout", name);
+                });
+                versions.retain(|listener_info| !listener_info.is_draining());
+                if versions.is_empty() {
+                    listeners_guard.remove(&name);
+                }
+            }
+        });
+    }
+
+    fn start_drain_timeout_for_removal(&self, listener_name: &str) {
+        let listeners = self.listeners.clone();
+        let timeout = self.drain_timeout;
+        let name = listener_name.to_owned();
+
+        tokio::spawn(async move {
+            tokio::time::sleep(timeout).await;
+
+            let mut listeners_guard = listeners.write().await;
+            if let Some(versions) = listeners_guard.remove(&name) {
+                for listener_info in versions {
+                    listener_info.handle.abort();
+                    info!("LDS: Listener '{}' forcibly closed after drain timeout during removal", name);
+                }
+            }
+        });
+    }
+
+    pub async fn get_listener_info(&self, name: &str) -> Option<ListenerState> {
+        let listeners = self.listeners.read().await;
+        listeners.get_vec(name)?.last().map(|info| info.state.clone())
+    }
+
+    pub async fn list_listeners(&self) -> HashMap<String, Vec<ListenerState>> {
+        let listeners = self.listeners.read().await;
+        listeners
+            .iter_all()
+            .map(|(name, versions)| {
+                let states =
versions.iter().map(|info| info.state.clone()).collect(); + (name.clone(), states) + }) + .collect() + } +} + +impl Default for LdsManager { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::broadcast; + + #[tokio::test] + async fn test_lds_listener_update() { + let manager = LdsManager::new(); + + let config = ListenerConfig { + name: "test-listener".to_string().into(), + address: ListenerAddress::Socket("127.0.0.1:8080".parse().unwrap()), + filter_chains: Default::default(), + bind_device: None, + with_tls_inspector: false, + proxy_protocol_config: None, + with_tlv_listener_filter: false, + tlv_listener_filter_config: None, + drain_type: orion_configuration::config::listener::DrainType::Default, + version_info: None, + }; + + let (route_tx, route_rx) = broadcast::channel(10); + let (sec_tx, sec_rx) = broadcast::channel(10); + let listener = Listener::test_listener("test-listener", route_rx, sec_rx); + + manager.handle_lds_update(listener, config.clone()).await.unwrap(); + + let state = manager.get_listener_info("test-listener").await.unwrap(); + assert_eq!(state, ListenerState::Active); + + let (route_tx2, route_rx2) = broadcast::channel(10); + let (sec_tx2, sec_rx2) = broadcast::channel(10); + let listener2 = Listener::test_listener("test-listener", route_rx2, sec_rx2); + + let mut config2 = config; + config2.address = ListenerAddress::Socket("127.0.0.1:8081".parse().unwrap()); + + manager.handle_lds_update(listener2, config2).await.unwrap(); + + let state = manager.get_listener_info("test-listener").await.unwrap(); + assert_eq!(state, ListenerState::Active); + + drop(route_tx); + drop(sec_tx); + drop(route_tx2); + drop(sec_tx2); + } +} diff --git a/orion-lib/src/listeners/listener.rs b/orion-lib/src/listeners/listener.rs index de1e37fa..3e72c6ce 100644 --- a/orion-lib/src/listeners/listener.rs +++ b/orion-lib/src/listeners/listener.rs @@ -16,7 +16,9 @@ // use super::{ + drain_signaling::DefaultConnectionHandler, filterchain::{ConnectionHandler, FilterchainBuilder, FilterchainType}, + http_connection_manager::HttpConnectionManager, listeners_manager::TlsContextChange, }; use crate::{ @@ -43,7 +45,7 @@ use std::{ fmt::Debug, net::SocketAddr, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }, time::Instant, @@ -54,6 +56,8 @@ use tokio::{ }; use tracing::{debug, info, warn}; +static CONNECTION_COUNTER: AtomicU64 = AtomicU64::new(1); + #[derive(Debug, Clone)] struct PartialListener { name: &'static str, @@ -158,6 +162,7 @@ impl ListenerFactory { with_tlv_listener_filter, route_updates_receiver, secret_updates_receiver, + drain_handler: Some(Arc::new(DefaultConnectionHandler::new())), }) } } @@ -181,6 +186,7 @@ pub struct Listener { with_tlv_listener_filter: bool, route_updates_receiver: broadcast::Receiver, secret_updates_receiver: broadcast::Receiver, + drain_handler: Option>, } impl Listener { @@ -201,6 +207,7 @@ impl Listener { with_tlv_listener_filter: false, route_updates_receiver: route_rx, secret_updates_receiver: secret_rx, + drain_handler: None, } } @@ -229,6 +236,7 @@ impl Listener { with_tlv_listener_filter, route_updates_receiver, secret_updates_receiver, + drain_handler, } = self; match address { ListenerAddress::Socket(local_address) => { @@ -246,6 +254,7 @@ impl Listener { with_tlv_listener_filter, route_updates_receiver, secret_updates_receiver, + drain_handler, ) .await }, @@ -274,6 +283,7 @@ impl Listener { with_tlv_listener_filter: bool, mut route_updates_receiver: 
broadcast::Receiver, mut secret_updates_receiver: broadcast::Receiver, + drain_handler: Option>, ) -> Error { let mut filter_chains = Arc::new(filter_chains); let listener_name = name; @@ -294,6 +304,7 @@ impl Listener { let filter_chains = Arc::clone(&filter_chains); let proxy_protocol_config = proxy_protocol_config.clone(); + let drain_handler_clone = drain_handler.clone(); // spawn a separate task for handling this client<->proxy connection // we spawn before we know if we want to process this route because we might need to run the tls_inspector which could // stall if the client is slow to send the ClientHello and end up blocking the acceptance of new connections @@ -303,7 +314,7 @@ impl Listener { // or pick a specific filter_chain to run, or we could simply if-else on the with_tls_inspector variable. let local_address = listener.local_addr().unwrap_or_else(|_| "0.0.0.0:0".parse().expect("Failed to parse fallback address")); let start = Instant::now(); - tokio::spawn(Self::process_listener_update(listener_name, filter_chains, with_tls_inspector, proxy_protocol_config, with_tlv_listener_filter, local_address, peer_addr, Box::new(stream), start)); + tokio::spawn(Self::process_listener_update(name, filter_chains, with_tls_inspector, proxy_protocol_config, with_tlv_listener_filter, local_address, peer_addr, Box::new(stream), start, drain_handler_clone)); }, Err(e) => {warn!("failed to accept tcp connection: {e}");} } @@ -467,19 +478,28 @@ impl Listener { peer_addr: SocketAddr, mut stream: AsyncStream, start_instant: std::time::Instant, + drain_handler: Option>, ) -> Result<()> { let shard_id = std::thread::current().id(); + let connection_id = + format!("{}:{}:{}", local_address, peer_addr, CONNECTION_COUNTER.fetch_add(1, Ordering::Relaxed)); + + debug!("New connection {} established on listener {}", connection_id, listener_name); let ssl = AtomicBool::new(false); + let connection_id_for_cleanup = connection_id.clone(); + let listener_name_for_cleanup = listener_name; + defer! 
{ - with_metric!(listeners::DOWNSTREAM_CX_DESTROY, add, 1, shard_id, &[KeyValue::new("listener", listener_name.to_string())]); - with_metric!(listeners::DOWNSTREAM_CX_ACTIVE, sub, 1, shard_id, &[KeyValue::new("listener", listener_name.to_string())]); + debug!("Connection {} closed on listener {}", connection_id_for_cleanup, listener_name_for_cleanup); + with_metric!(listeners::DOWNSTREAM_CX_DESTROY, add, 1, shard_id, &[KeyValue::new("listener", listener_name_for_cleanup)]); + with_metric!(listeners::DOWNSTREAM_CX_ACTIVE, sub, 1, shard_id, &[KeyValue::new("listener", listener_name_for_cleanup)]); if ssl.load(Ordering::Relaxed) { - with_metric!(http::DOWNSTREAM_CX_SSL_ACTIVE, add, 1, shard_id, &[KeyValue::new("listener", listener_name)]); + with_metric!(http::DOWNSTREAM_CX_SSL_ACTIVE, add, 1, shard_id, &[KeyValue::new("listener", listener_name_for_cleanup)]); } let ms = u64::try_from(start_instant.elapsed().as_millis()) .unwrap_or(u64::MAX); - with_histogram!(listeners::DOWNSTREAM_CX_LENGTH_MS, record, ms, &[KeyValue::new("listener", listener_name)]); + with_histogram!(listeners::DOWNSTREAM_CX_LENGTH_MS, record, ms, &[KeyValue::new("listener", listener_name_for_cleanup)]); } let server_name = if with_tls_inspector { @@ -554,6 +574,10 @@ impl Listener { let selected_filterchain = Self::select_filterchain(&filter_chains, &downstream_metadata, server_name.as_deref())?; + if let Some(_drain_handler) = &drain_handler { + let protocol = HttpConnectionManager::extract_connection_protocol(&downstream_metadata); + DefaultConnectionHandler::register_connection(connection_id.clone(), protocol, peer_addr); + } if let Some(filterchain) = selected_filterchain { debug!( "{listener_name} : mapping connection from {peer_addr} to filter chain {}", diff --git a/orion-lib/src/listeners/listeners_manager.rs b/orion-lib/src/listeners/listeners_manager.rs index 7817f57b..ecd9eded 100644 --- a/orion-lib/src/listeners/listeners_manager.rs +++ b/orion-lib/src/listeners/listeners_manager.rs @@ -15,16 +15,30 @@ // // +use dashmap::DashMap; use multimap::MultiMap; +use std::collections::HashMap; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::time::{Duration, Instant}; + use tokio::sync::{broadcast, mpsc}; -use tracing::{info, warn}; +use tokio::time::interval; +use tracing::{debug, info, warn}; use orion_configuration::config::{ - listener::ListenerAddress, network_filters::http_connection_manager::RouteConfiguration, Listener as ListenerConfig, + network_filters::http_connection_manager::RouteConfiguration, Listener as ListenerConfig, }; -use super::listener::{Listener, ListenerFactory}; +use super::{ + drain_signaling::ListenerDrainState, + http_connection_manager::HttpConnectionManager, + listener::{Listener, ListenerFactory}, +}; use crate::{secrets::TransportSecret, ConfigDump, Result}; + #[derive(Debug, Clone)] pub enum ListenerConfigurationChange { Added(Box<(ListenerFactory, ListenerConfig)>), @@ -38,19 +52,603 @@ pub enum RouteConfigurationChange { Added((String, RouteConfiguration)), Removed(String), } + #[derive(Debug, Clone)] pub enum TlsContextChange { Updated((String, TransportSecret)), } +#[derive(Debug, Clone)] +pub struct ConnectionInfo { + pub id: String, + pub protocol: ConnectionProtocol, + pub established_at: Instant, + pub last_activity: Instant, + pub state: ConnectionState, +} + +#[derive(Debug, Clone)] +pub enum ConnectionProtocol { + Http1, + Http2, + Tcp, + Unknown, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ConnectionState { + Active, + Draining, + Closing, 
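+    // Lifecycle note: connections are tracked Active -> Draining (graceful
+    // signal sent) -> Closing (force close requested) -> Closed; Closed
+    // entries are swept out later by cleanup_completed_drains() below.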
+    Closed,
+}
+
+#[derive(Debug, Clone)]
+pub struct DrainProgress {
+    pub total_connections: usize,
+    pub active_connections: usize,
+    pub draining_connections: usize,
+    pub percentage: f64,
+    pub elapsed: Duration,
+    pub remaining_time: Option<Duration>,
+}
+
+#[derive(Debug, Clone)]
+pub struct ListenerDrainStatusReport {
+    pub listener_name: String,
+    pub total_versions: usize,
+    pub draining_versions: usize,
+    pub version_statuses: Vec<VersionDrainStatus>,
+    pub last_updated: Instant,
+}
+
+#[derive(Debug, Clone)]
+pub struct VersionDrainStatus {
+    pub version: u64,
+    pub state: DrainPhase,
+    pub drain_progress: Option<DrainProgress>,
+    pub started_at: Option<Instant>,
+    pub estimated_completion: Option<Instant>,
+}
+
+#[derive(Debug, Clone)]
+pub enum DrainPhase {
+    Active,
+    Draining,
+    ForceClosing,
+    Completed,
+}
+
+#[derive(Debug, Clone)]
+pub struct GlobalDrainStatistics {
+    pub total_listeners: usize,
+    pub draining_listeners: usize,
+    pub total_connections: usize,
+    pub draining_connections: usize,
+    pub oldest_drain_start: Option<Instant>,
+    pub drain_efficiency: f64,
+    pub estimated_global_completion: Option<Instant>,
+}
+
+pub trait ConnectionManager: Send + Sync {
+    fn on_connection_established(&self, listener_name: &str, conn_info: ConnectionInfo);
+    fn on_connection_closed(&self, listener_name: &str, connection_id: &str);
+    fn start_connection_draining(
+        &self,
+        listener_name: &str,
+        connection_id: &str,
+        protocol_behavior: &ProtocolDrainBehavior,
+    );
+    fn get_active_connections(&self, listener_name: &str) -> Vec<ConnectionInfo>;
+    fn force_close_connection(&self, listener_name: &str, connection_id: &str);
+}
+
+#[derive(Debug, Default)]
+pub struct DefaultConnectionManager {
+    connections: Arc<DashMap<String, ConnectionInfo>>,
+    listener_connection_counts: Arc<DashMap<String, AtomicUsize>>,
+    http_managers: Arc<DashMap<String, Arc<HttpConnectionManager>>>,
+}
+
+impl DefaultConnectionManager {
+    pub fn new() -> Self {
+        Self {
+            connections: Arc::new(DashMap::new()),
+            listener_connection_counts: Arc::new(DashMap::new()),
+            http_managers: Arc::new(DashMap::new()),
+        }
+    }
+
+    fn make_connection_key(listener_name: &str, connection_id: &str) -> String {
+        format!("{}:{}", listener_name, connection_id)
+    }
+
+    fn parse_connection_key(key: &str) -> Option<(String, String)> {
+        if let Some(pos) = key.find(':') {
+            let (listener, conn_id) = key.split_at(pos);
+            Some((listener.to_string(), conn_id[1..].to_string()))
+        } else {
+            None
+        }
+    }
+
+    fn make_listener_prefix(listener_name: &str) -> String {
+        format!("{}:", listener_name)
+    }
+
+    pub fn get_total_connections(&self) -> usize {
+        self.connections.len()
+    }
+
+    pub fn get_listener_connections(&self, listener_name: &str) -> Vec<ConnectionInfo> {
+        let prefix = Self::make_listener_prefix(listener_name);
+        self.connections
+            .iter()
+            .filter(|entry| entry.key().starts_with(&prefix))
+            .map(|entry| entry.value().clone())
+            .collect()
+    }
+
+    pub fn get_listener_summary(&self) -> Vec<(String, usize)> {
+        self.listener_connection_counts
+            .iter()
+            .map(|entry| (entry.key().clone(), entry.value().load(Ordering::Relaxed)))
+            .collect()
+    }
+
+    pub fn update_connection_state(
+        &self,
+        listener_name: &str,
+        connection_id: &str,
+        new_state: ConnectionState,
+    ) -> bool {
+        let conn_key = Self::make_connection_key(listener_name, connection_id);
+        if let Some(mut conn_entry) = self.connections.get_mut(&conn_key) {
+            let conn_info = conn_entry.value_mut();
+            let old_state = conn_info.state.clone();
+            conn_info.state = new_state.clone();
+            conn_info.last_activity = Instant::now();
+
+            info!("Connection {} state changed: {:?} -> {:?}", connection_id, old_state, new_state);
+            return true;
+        }
+        false
+    }
+
+    pub fn
get_connection_state(&self, listener_name: &str, connection_id: &str) -> Option { + let conn_key = Self::make_connection_key(listener_name, connection_id); + self.connections.get(&conn_key).map(|entry| entry.value().state.clone()) + } + + pub fn register_http_manager(&self, listener_name: String, manager: Arc) { + self.http_managers.insert(listener_name.clone(), manager); + debug!("Registered HTTP connection manager for listener: {}", listener_name); + } + + pub fn unregister_http_manager(&self, listener_name: &str) -> Option> { + let manager = self.http_managers.remove(listener_name); + if manager.is_some() { + debug!("Unregistered HTTP connection manager for listener: {}", listener_name); + } + manager.map(|(_, v)| v) + } + + pub fn get_all_http_managers(&self) -> Vec<(String, Arc)> { + self.http_managers.iter().map(|entry| (entry.key().clone(), entry.value().clone())).collect() + } + + pub fn get_http_manager(&self, listener_name: &str) -> Option> { + self.http_managers.get(listener_name).map(|entry| entry.value().clone()) + } + + pub async fn start_draining_http_managers( + &self, + listener_name: &str, + ) -> std::result::Result<(), Box> { + if let Some(manager) = self.get_http_manager(listener_name) { + info!("Starting drain signaling for HTTP connection manager on listener {}", listener_name); + + let drain_manager = manager.get_drain_signaling(); + let drain_state = super::drain_signaling::ListenerDrainState { + started_at: std::time::Instant::now(), + strategy: DrainStrategy::Gradual, + protocol_behavior: ProtocolDrainBehavior::Auto, + drain_scenario: super::drain_signaling::DrainScenario::ListenerUpdate, + drain_type: orion_configuration::config::listener::DrainType::Default, + }; + drain_manager.start_listener_draining(drain_state).await; + + info!("HTTP connection manager drain signaling started for listener {}", listener_name); + Ok(()) + } else { + warn!("No HTTP connection manager found for listener: {}", listener_name); + Ok(()) + } + } + + pub fn remove_connection(&self, listener_name: &str, connection_id: &str) -> bool { + let conn_key = Self::make_connection_key(listener_name, connection_id); + let removed = self.connections.remove(&conn_key).is_some(); + + if removed { + if let Some(count) = self.listener_connection_counts.get(listener_name) { + count.fetch_sub(1, Ordering::Relaxed); + } + } + + removed + } + + pub fn cleanup_connection(&self, listener_name: &str, connection_id: &str) -> bool { + let removed = self.remove_connection(listener_name, connection_id); + + if removed { + debug!("Cleaned up connection state for {} on listener {}", connection_id, listener_name); + + if let Some(count) = self.listener_connection_counts.get(listener_name) { + if count.load(Ordering::Relaxed) == 0 { + info!("All connections drained for listener {}", listener_name); + } + } + } + + removed + } + + pub fn cleanup_completed_drains(&self) { + debug!("Running cleanup of completed drain connections"); + let closed_connections: Vec = self + .connections + .iter() + .filter(|entry| matches!(entry.value().state, ConnectionState::Closed)) + .map(|entry| entry.key().clone()) + .collect(); + + for conn_key in closed_connections { + if let Some((listener_name, conn_id)) = Self::parse_connection_key(&conn_key) { + self.connections.remove(&conn_key); + if let Some(count) = self.listener_connection_counts.get(&listener_name) { + count.fetch_sub(1, Ordering::Relaxed); + } + debug!("Cleaned up closed connection {} from listener {}", conn_id, listener_name); + } + } + } + + pub fn 
notify_connection_closed(&self, listener_name: &str, connection_id: &str) { + self.update_connection_state(listener_name, connection_id, ConnectionState::Closed); + + if self.cleanup_connection(listener_name, connection_id) { + debug!("Connection {} closed and cleaned up from listener {}", connection_id, listener_name); + } + } + + pub fn get_connections_by_state(&self, listener_name: &str, state: ConnectionState) -> Vec { + let prefix = Self::make_listener_prefix(listener_name); + self.connections + .iter() + .filter(|entry| entry.key().starts_with(&prefix) && entry.value().state == state) + .map(|entry| entry.value().clone()) + .collect() + } + + pub fn get_all_draining_connections(&self) -> Vec<(String, ConnectionInfo)> { + self.connections + .iter() + .filter(|entry| matches!(entry.value().state, ConnectionState::Draining)) + .filter_map(|entry| { + if let Some((listener_name, _)) = Self::parse_connection_key(entry.key()) { + Some((listener_name, entry.value().clone())) + } else { + None + } + }) + .collect() + } + + pub fn cleanup_stale_draining_connections(&self, max_drain_time: std::time::Duration) -> usize { + let mut closed_count = 0; + let now = Instant::now(); + + let connections_to_close: Vec<(String, String)> = self + .connections + .iter() + .filter(|entry| { + matches!(entry.value().state, ConnectionState::Draining) + && now.duration_since(entry.value().last_activity) > max_drain_time + }) + .filter_map(|entry| { + if let Some((listener_name, conn_id)) = Self::parse_connection_key(entry.key()) { + Some((listener_name, conn_id)) + } else { + None + } + }) + .collect(); + + for (listener_name, conn_id) in connections_to_close { + self.force_close_connection(&listener_name, &conn_id); + closed_count += 1; + } + + if closed_count > 0 { + warn!("Force closed {} stale draining connections (drain time > {:?})", closed_count, max_drain_time); + } + + closed_count + } +} + +impl ConnectionManager for DefaultConnectionManager { + fn on_connection_established(&self, listener_name: &str, conn_info: ConnectionInfo) { + debug!("Connection {} established on listener {}", conn_info.id, listener_name); + + let conn_key = Self::make_connection_key(listener_name, &conn_info.id); + self.connections.insert(conn_key, conn_info); + + let count = + self.listener_connection_counts.entry(listener_name.to_string()).or_insert_with(|| AtomicUsize::new(0)); + count.fetch_add(1, Ordering::Relaxed); + } + + fn on_connection_closed(&self, listener_name: &str, connection_id: &str) { + debug!("Connection {} closed on listener {}", connection_id, listener_name); + + let conn_key = Self::make_connection_key(listener_name, connection_id); + if self.connections.remove(&conn_key).is_some() { + if let Some(count) = self.listener_connection_counts.get(listener_name) { + count.fetch_sub(1, Ordering::Relaxed); + } + } + } + + fn start_connection_draining( + &self, + listener_name: &str, + connection_id: &str, + protocol_behavior: &ProtocolDrainBehavior, + ) { + debug!( + "Starting drain for connection {} on listener {} with protocol {:?}", + connection_id, listener_name, protocol_behavior + ); + + let conn_key = Self::make_connection_key(listener_name, connection_id); + if let Some(mut conn_entry) = self.connections.get_mut(&conn_key) { + let conn_info = conn_entry.value_mut(); + conn_info.state = ConnectionState::Draining; + conn_info.last_activity = Instant::now(); + + info!("Connection {} on listener {} is now draining", connection_id, listener_name); + } + } + + fn get_active_connections(&self, listener_name: 
&str) -> Vec { + self.get_listener_connections(listener_name) + } + + fn force_close_connection(&self, listener_name: &str, connection_id: &str) { + warn!("Force closing connection {} on listener {}", connection_id, listener_name); + + let conn_key = Self::make_connection_key(listener_name, connection_id); + if let Some(mut conn_entry) = self.connections.get_mut(&conn_key) { + let conn_info = conn_entry.value_mut(); + conn_info.state = ConnectionState::Closing; + conn_info.last_activity = Instant::now(); + info!("Connection {} marked for force close (protocol: {:?})", connection_id, conn_info.protocol); + } + + if self.connections.remove(&conn_key).is_some() { + if let Some(count) = self.listener_connection_counts.get(listener_name) { + count.fetch_sub(1, Ordering::Relaxed); + } + } + } +} + +#[derive(Debug, Clone)] +pub struct ListenerManagerConfig { + pub max_versions_per_listener: usize, + pub cleanup_policy: CleanupPolicy, + pub cleanup_interval: Duration, + pub drain_config: ListenerDrainConfig, +} + +#[derive(Debug, Clone)] +pub struct ListenerDrainConfig { + pub drain_time: Duration, + pub drain_strategy: DrainStrategy, + pub protocol_handling: ProtocolDrainBehavior, +} + +#[derive(Debug, Clone)] +pub enum DrainStrategy { + Gradual, + Immediate, +} + +#[derive(Debug, Clone)] +pub enum ProtocolDrainBehavior { + Http1 { connection_close: bool }, + Http2 { send_goaway: bool }, + Tcp { force_close_after: Duration }, + Auto, +} + +#[derive(Debug, Clone)] +pub enum CleanupPolicy { + CountBasedOnly(usize), +} + +#[derive(Debug, Clone)] +enum ListenerState { + Active, + Draining { started_at: Instant, drain_config: ListenerDrainConfig }, +} + +impl Default for ListenerDrainConfig { + fn default() -> Self { + Self { + drain_time: Duration::from_secs(600), + drain_strategy: DrainStrategy::Gradual, + protocol_handling: ProtocolDrainBehavior::Auto, + } + } +} + +impl PartialEq for DrainStrategy { + fn eq(&self, other: &Self) -> bool { + std::mem::discriminant(self) == std::mem::discriminant(other) + } +} + +impl Default for ListenerManagerConfig { + fn default() -> Self { + Self { + max_versions_per_listener: 2, + cleanup_policy: CleanupPolicy::CountBasedOnly(2), + cleanup_interval: Duration::from_secs(60), + drain_config: ListenerDrainConfig::default(), + } + } +} + +#[derive(Debug)] struct ListenerInfo { handle: abort_on_drop::ChildTask<()>, listener_conf: ListenerConfig, version: u64, + state: ListenerState, + connections_count: Arc, + drain_manager_handle: Option>, } + impl ListenerInfo { fn new(handle: tokio::task::JoinHandle<()>, listener_conf: ListenerConfig, version: u64) -> Self { - Self { handle: handle.into(), listener_conf, version } + Self { + handle: handle.into(), + listener_conf, + version, + state: ListenerState::Active, + connections_count: Arc::new(AtomicUsize::new(0)), + drain_manager_handle: None, + } + } + + fn start_draining(&mut self, drain_config: ListenerDrainConfig, connection_manager: &DefaultConnectionManager) { + let active_count = self.connections_count.load(Ordering::Relaxed); + info!( + "Starting graceful draining for listener {} version {} with {} active connections", + self.listener_conf.name, self.version, active_count + ); + + self.state = ListenerState::Draining { started_at: Instant::now(), drain_config: drain_config.clone() }; + + let drain_handle = self.start_drain_monitor(drain_config.clone()); + self.drain_manager_handle = Some(drain_handle.into()); + + let protocol_behavior = match &self.state { + ListenerState::Draining { drain_config, .. 
} => drain_config.protocol_handling.clone(), + ListenerState::Active => return, + }; + + let active_connection_infos = + connection_manager.get_connections_by_state(&self.listener_conf.name, ConnectionState::Active); + let active_connection_ids: Vec = + active_connection_infos.iter().map(|conn_info| conn_info.id.clone()).collect(); + + for conn_id in active_connection_ids { + info!("Signaling existing connection {} to start draining", conn_id); + connection_manager.start_connection_draining(&self.listener_conf.name, &conn_id, &protocol_behavior); + } + } + + fn start_drain_monitor(&self, drain_config: ListenerDrainConfig) -> tokio::task::JoinHandle<()> { + let listener_name = self.listener_conf.name.clone(); + let version = self.version; + let connections_count = Arc::clone(&self.connections_count); + let started_at = Instant::now(); + + tokio::spawn(async move { + let mut interval = interval(Duration::from_secs(5)); + let mut last_connection_count = connections_count.load(Ordering::Relaxed); + + info!( + "Starting drain monitor for listener {} version {} with strategy {:?}", + listener_name, version, drain_config.drain_strategy + ); + + loop { + interval.tick().await; + + let elapsed = started_at.elapsed(); + let current_count = connections_count.load(Ordering::Relaxed); + + if elapsed >= drain_config.drain_time { + warn!("Drain timeout ({:?}) reached for listener {} version {}, force closing {} remaining connections", + drain_config.drain_time, listener_name, version, current_count); + + break; + } + + if current_count == 0 { + info!( + "All connections successfully drained for listener {} version {} in {:?}", + listener_name, version, elapsed + ); + break; + } + + if current_count != last_connection_count { + let connections_closed = last_connection_count.saturating_sub(current_count); + info!("Drain progress for listener {} version {}: {} connections closed, {} remaining ({:.1}% complete)", + listener_name, version, connections_closed, current_count, + Self::calculate_drain_percentage(elapsed, &drain_config)); + last_connection_count = current_count; + } + + match drain_config.drain_strategy { + DrainStrategy::Gradual => { + let progress = Self::calculate_drain_percentage(elapsed, &drain_config); + if progress > 50.0 { + debug!("Gradual drain reached 50%, would encourage connection draining"); + } + }, + DrainStrategy::Immediate => { + debug!("Immediate drain strategy - would encourage immediate draining"); + }, + } + + debug!("Drain monitor tick for listener {} version {}: {} active connections, {:.1}% complete, {:?} elapsed", + listener_name, version, current_count, + Self::calculate_drain_percentage(elapsed, &drain_config), + elapsed); + } + + info!("Drain monitor completed for listener {} version {}", listener_name, version); + }) + } + + fn calculate_drain_percentage(elapsed: Duration, drain_config: &ListenerDrainConfig) -> f64 { + match drain_config.drain_strategy { + DrainStrategy::Immediate => 100.0, + DrainStrategy::Gradual => { + (elapsed.as_secs_f64() / drain_config.drain_time.as_secs_f64() * 100.0).min(100.0) + }, + } + } + + fn is_draining(&self) -> bool { + matches!(self.state, ListenerState::Draining { .. 
}) + } + + fn should_force_close(&self) -> bool { + if let ListenerState::Draining { started_at, drain_config } = &self.state { + let elapsed = started_at.elapsed(); + elapsed >= drain_config.drain_time + } else { + false + } } } @@ -59,112 +657,727 @@ pub struct ListenersManager { route_configuration_channel: mpsc::Receiver, listener_handles: MultiMap, version_counter: u64, + config: ListenerManagerConfig, + connection_manager: Arc, } impl ListenersManager { pub fn new( listener_configuration_channel: mpsc::Receiver, route_configuration_channel: mpsc::Receiver, + ) -> Self { + Self::new_with_config( + listener_configuration_channel, + route_configuration_channel, + ListenerManagerConfig::default(), + ) + } + + pub fn new_with_config( + listener_configuration_channel: mpsc::Receiver, + route_configuration_channel: mpsc::Receiver, + config: ListenerManagerConfig, ) -> Self { ListenersManager { listener_configuration_channel, route_configuration_channel, listener_handles: MultiMap::new(), version_counter: 0, + config, + connection_manager: Arc::new(DefaultConnectionManager::new()), + } + } + + pub fn get_connection_manager(&self) -> Arc { + Arc::clone(&self.connection_manager) + } + + fn is_drain_active(listener_info: &ListenerInfo) -> bool { + if let Some(drain_handle) = &listener_info.drain_manager_handle { + !drain_handle.is_finished() + } else { + false + } + } + + fn get_drain_phase(listener_info: &ListenerInfo) -> DrainPhase { + if !listener_info.is_draining() { + return DrainPhase::Active; + } + + if listener_info.should_force_close() { + DrainPhase::ForceClosing + } else if Self::is_drain_active(listener_info) { + DrainPhase::Draining + } else { + DrainPhase::Completed } } - pub async fn start(mut self, ct: tokio_util::sync::CancellationToken) -> Result<()> { - let (tx_secret_updates, _) = broadcast::channel(16); - let (tx_route_updates, _) = broadcast::channel(16); - // TODO: create child token for each listener? - loop { - tokio::select! 
{ - Some(listener_configuration_change) = self.listener_configuration_channel.recv() => { - match listener_configuration_change { - ListenerConfigurationChange::Added(boxed) => { - let (factory, listener_conf) = *boxed; - let listener = factory.clone() - .make_listener(tx_route_updates.subscribe(), tx_secret_updates.subscribe())?; - if let Err(e) = self.start_listener(listener, listener_conf) { - warn!("Failed to start listener: {e}"); + #[allow(dead_code)] + fn count_draining_versions(listener_infos: &[ListenerInfo]) -> usize { + listener_infos.iter().filter(|info| info.is_draining()).count() + } + + fn find_active_versions(listener_infos: &[ListenerInfo]) -> Vec<&ListenerInfo> { + listener_infos.iter().filter(|info| !info.is_draining()).collect() + } + + #[allow(dead_code)] + fn find_latest_active_version(listener_infos: &[ListenerInfo]) -> Option<&ListenerInfo> { + Self::find_active_versions(listener_infos).into_iter().max_by_key(|info| info.version) + } + + pub fn get_listener_drain_status(&self, listener_name: &str) -> Vec { + let mut drain_statuses = Vec::new(); + + if let Some(listener_infos) = self.listener_handles.get_vec(listener_name) { + for listener_info in listener_infos { + if listener_info.is_draining() { + let progress = self.get_drain_progress_for_listener(listener_info); + drain_statuses.push(progress); + } + } + } + + drain_statuses + } + + pub fn get_all_listener_names(&self) -> Vec { + self.listener_handles.keys().cloned().collect() + } + + pub fn get_total_active_connections(&self) -> usize { + let mut total = 0; + for (_, listener_info) in self.listener_handles.iter() { + total += listener_info.connections_count.load(Ordering::Relaxed); + } + total + } + + fn get_drain_progress_for_listener(&self, listener_info: &ListenerInfo) -> DrainProgress { + let (percentage, elapsed, remaining_time) = + if let ListenerState::Draining { started_at, drain_config } = &listener_info.state { + let elapsed = started_at.elapsed(); + let percentage = ListenerInfo::calculate_drain_percentage(elapsed, drain_config); + let remaining = drain_config.drain_time.saturating_sub(elapsed); + (percentage, elapsed, Some(remaining)) + } else { + (0.0, Duration::ZERO, None) + }; + + let all_connections = self.connection_manager.get_listener_connections(&listener_info.listener_conf.name); + let total_connections = all_connections.len(); + let draining_connections = + all_connections.iter().filter(|conn| matches!(conn.state, ConnectionState::Draining)).count(); + let active_connections = total_connections - draining_connections; + + DrainProgress { + total_connections, + active_connections, + draining_connections, + percentage, + elapsed, + remaining_time, + } + } + + fn estimate_drain_completion_for_listener(&self, listener_info: &ListenerInfo) -> Option { + if let ListenerState::Draining { started_at, drain_config } = &listener_info.state { + let current_connections = listener_info.connections_count.load(Ordering::Relaxed); + + if current_connections == 0 { + return Some(Instant::now()); + } + + match drain_config.drain_strategy { + DrainStrategy::Immediate => Some(*started_at + drain_config.drain_time), + DrainStrategy::Gradual => { + let elapsed = started_at.elapsed(); + if elapsed.as_secs() > 10 { + let all_connections = + self.connection_manager.get_listener_connections(&listener_info.listener_conf.name); + let initial_connections = all_connections.len() + current_connections; + + if initial_connections > current_connections { + let drained_count = initial_connections - current_connections; + 
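+                            // Linear extrapolation: derive the observed drain rate
+                            // (connections closed per elapsed second) and project
+                            // how long the remaining connections will take.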
let elapsed_secs = elapsed.as_secs_f64(); + if elapsed_secs > 0.0 { + let drain_rate = drained_count as f64 / elapsed_secs; + if drain_rate > 0.0 { + let estimated_remaining_time = current_connections as f64 / drain_rate; + return Some(Instant::now() + Duration::from_secs_f64(estimated_remaining_time)); + } + } + } + + Some(*started_at + drain_config.drain_time) + } else { + Some(*started_at + drain_config.drain_time) + } + }, + } + } else { + None + } + } + + pub fn get_comprehensive_drain_status(&self) -> HashMap { + let mut reports = HashMap::new(); + + for (listener_name, listener_infos) in self.listener_handles.iter_all() { + let mut version_statuses = Vec::new(); + let mut draining_count = 0; + + for listener_info in listener_infos { + let status = if listener_info.is_draining() { + draining_count += 1; + + let drain_phase = Self::get_drain_phase(listener_info); + let progress = self.get_drain_progress_for_listener(listener_info); + let estimated_completion = self.estimate_drain_completion_for_listener(listener_info); + + VersionDrainStatus { + version: listener_info.version, + state: drain_phase, + drain_progress: Some(progress), + started_at: match &listener_info.state { + ListenerState::Draining { started_at, .. } => Some(*started_at), + ListenerState::Active => None, + }, + estimated_completion, + } + } else { + VersionDrainStatus { + version: listener_info.version, + state: DrainPhase::Active, + drain_progress: None, + started_at: None, + estimated_completion: None, + } + }; + + version_statuses.push(status); + } + + reports.insert( + listener_name.clone(), + ListenerDrainStatusReport { + listener_name: listener_name.clone(), + total_versions: version_statuses.len(), + draining_versions: draining_count, + version_statuses, + last_updated: Instant::now(), + }, + ); + } + + reports + } + + pub fn get_global_drain_statistics(&self) -> GlobalDrainStatistics { + let mut total_listeners = 0; + let mut draining_listeners = 0; + let mut total_connections = 0; + let mut draining_connections = 0; + let mut oldest_drain_start: Option = None; + + for (_, listener_info) in self.listener_handles.iter() { + total_listeners += 1; + let connections = listener_info.connections_count.load(Ordering::Relaxed); + total_connections += connections; + + if listener_info.is_draining() { + if Self::is_drain_active(listener_info) { + draining_listeners += 1; + draining_connections += connections; + + if let ListenerState::Draining { started_at, .. 
} = &listener_info.state {
+                        match oldest_drain_start {
+                            None => oldest_drain_start = Some(*started_at),
+                            Some(existing) if *started_at < existing => oldest_drain_start = Some(*started_at),
+                            _ => {},
+                        }
+                    }
+                }
+            }
+        }
+
+        GlobalDrainStatistics {
+            total_listeners,
+            draining_listeners,
+            total_connections,
+            draining_connections,
+            oldest_drain_start,
+            drain_efficiency: if total_connections > 0 {
+                (total_connections - draining_connections) as f64 / total_connections as f64 * 100.0
+            } else {
+                100.0
+            },
+            estimated_global_completion: self.estimate_global_drain_completion(),
+        }
+    }
+
+    fn estimate_global_drain_completion(&self) -> Option<Instant> {
+        let mut latest_completion: Option<Instant> = None;
+
+        for (_, listener_info) in self.listener_handles.iter() {
+            if listener_info.is_draining() {
+                if let Some(estimated) = self.estimate_drain_completion_for_listener(listener_info) {
+                    match latest_completion {
+                        None => latest_completion = Some(estimated),
+                        Some(existing) if estimated > existing => latest_completion = Some(estimated),
+                        _ => {},
+                    }
+                }
+            }
+        }
+
+        latest_completion
+    }
+
+    pub fn start_drain_monitoring_task(&self) -> tokio::task::JoinHandle<()> {
+        let (_tx, mut rx) = tokio::sync::mpsc::channel::<String>(100);
+        let connection_manager = Arc::clone(&self.connection_manager);
+        let _cleanup_interval = self.config.drain_config.drain_time;
+
+        tokio::spawn(async move {
+            let mut interval = tokio::time::interval(Duration::from_secs(10));
+            let mut cleanup_interval_timer = tokio::time::interval(Duration::from_secs(60));
+
+            loop {
+                tokio::select! {
+                    _ = interval.tick() => {
+                        debug!("Drain monitoring tick - monitoring mechanism needs redesign");
+
+                        connection_manager.cleanup_completed_drains();
+                    },
+                    _ = cleanup_interval_timer.tick() => {
+                        info!("Running periodic cleanup of completed drain states");
+                        connection_manager.cleanup_completed_drains();
+                    },
+                    event = rx.recv() => {
+                        if let Some(event) = event {
+                            debug!("Received drain monitoring event: {}", event);
+                        } else {
+                            info!("Drain monitoring channel closed, exiting monitor");
+                            break;
+                        }
+                    }
+                }
+            }
+
+            info!("Drain monitoring task completed");
+        })
+    }
+
+    pub async fn register_http_connection_manager(&self, listener_name: &str, http_cm: Arc<HttpConnectionManager>) {
+        if let Some(listener_infos) = self.listener_handles.get_vec(listener_name) {
+            if let Some(latest_listener) = listener_infos.iter().max_by_key(|info| info.version) {
+                if latest_listener.is_draining() {
+                    if let ListenerState::Draining { drain_config, ..
} = &latest_listener.state { + let drain_state = ListenerDrainState { + started_at: Instant::now(), + strategy: drain_config.drain_strategy.clone(), + protocol_behavior: drain_config.protocol_handling.clone(), + drain_scenario: super::drain_signaling::DrainScenario::ListenerUpdate, + drain_type: orion_configuration::config::listener::DrainType::Default, + }; + http_cm.start_draining(drain_state).await; + info!("HTTP Connection Manager for listener {} started draining immediately", listener_name); + } + } + } + } + } + + pub async fn start_draining_http_connection_managers(&self, listener_name: &str) { + info!("Starting drain signaling for HTTP connection managers on listener {}", listener_name); + + if let Err(e) = self.connection_manager.start_draining_http_managers(listener_name).await { + warn!("Failed to start draining HTTP managers for listener {}: {}", listener_name, e); + } + } + + pub async fn start(mut self, ct: tokio_util::sync::CancellationToken) -> Result<()> { + let (tx_secret_updates, _) = broadcast::channel(16); + let (tx_route_updates, _) = broadcast::channel(16); + // TODO: create child token for each listener? + loop { + tokio::select! { + Some(listener_configuration_change) = self.listener_configuration_channel.recv() => { + match listener_configuration_change { + ListenerConfigurationChange::Added(boxed) => { + let (factory, listener_conf) = *boxed; + let listener = factory.clone() + .make_listener(tx_route_updates.subscribe(), tx_secret_updates.subscribe())?; + if let Err(e) = self.start_listener(listener, listener_conf) { + warn!("Failed to start listener: {e}"); + } + } + ListenerConfigurationChange::Removed(listener_name) => { + if let Err(e) = self.stop_listener(&listener_name) { + warn!("Failed to stop removed listener {}: {}", listener_name, e); + } + }, + ListenerConfigurationChange::TlsContextChanged((secret_id, secret)) => { + info!("Got tls secret update {secret_id}"); + let res = tx_secret_updates.send(TlsContextChange::Updated((secret_id, secret))); + if let Err(e) = res{ + warn!("Internal problem when updating a secret: {e}"); + } + }, + ListenerConfigurationChange::GetConfiguration(config_dump_tx) => { + let listeners: Vec = self.listener_handles + .iter() + .map(|(_, info)| info.listener_conf.clone()) + .collect(); + config_dump_tx.send(ConfigDump { listeners: Some(listeners), ..Default::default() }).await?; + }, + } + }, + Some(route_configuration_change) = self.route_configuration_channel.recv() => { + // routes could be CachedWatch instead, as they are evaluated lazilly + let res = tx_route_updates.send(route_configuration_change); + if let Err(e) = res{ + warn!("Internal problem when updating a route: {e}"); + } + }, + _ = ct.cancelled() => { + warn!("Listener manager exiting"); + return Ok(()); + } + } + } + } + + pub fn start_listener(&mut self, listener: Listener, listener_conf: ListenerConfig) -> Result<()> { + let listener_name = listener.get_name().to_string(); + if let Some((addr, dev)) = listener.get_socket() { + info!("Listener {} at {addr} (device bind:{})", listener_name, dev.is_some()); + } else { + info!("Internal listener {}", listener_name); + } + + if let Some(existing_versions) = self.listener_handles.get_vec(&listener_name) { + for existing_info in existing_versions { + if Self::listener_configs_equivalent(&existing_info.listener_conf, &listener_conf) { + info!( + "Listener {} configuration unchanged from version {}, skipping duplicate creation", + listener_name, existing_info.version + ); + return Ok(()); + } + } + + info!( + 
"Configuration changed for listener {}, stopping all existing versions to prevent mixed responses", + listener_name + ); + + let existing_versions_to_stop: Vec<_> = Self::find_active_versions(&existing_versions) + .into_iter() + .map(|info| (info.version, info.listener_conf.address.clone())) + .collect(); + + for (old_version, old_addr) in existing_versions_to_stop { + info!("Stopping listener {} version {} (config update) at {:?}", listener_name, old_version, old_addr); + if let Err(e) = self.stop_listener_version(&listener_name, old_version) { + warn!( + "Failed to stop listener {} version {} during config update: {}", + listener_name, old_version, e + ); + } + } + + info!("Starting new version of listener {} with updated configuration", listener_name); + } else { + info!("Starting initial version of listener {}", listener_name); + } + + let version = if let Some(version_info) = &listener_conf.version_info { + version_info.parse::().unwrap_or_else(|_| { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + version_info.hash(&mut hasher); + hasher.finish() + }) + } else { + self.version_counter += 1; + self.version_counter + }; + + let listener_name_for_async = listener_name.clone(); + let join_handle = tokio::spawn(async move { + let error = listener.start().await; + warn!("Listener {} version {} exited with error: {}", listener_name_for_async, version, error); + }); + + let listener_info = ListenerInfo::new(join_handle, listener_conf, version); + + self.listener_handles.insert(listener_name.clone(), listener_info); + let version_count = self.listener_handles.get_vec(&listener_name).map(|v| v.len()).unwrap_or(0); + info!("Started version {} of listener {} ({} total active version(s))", version, listener_name, version_count); + + self.drain_old_listeners(&listener_name); + + Ok(()) + } + + fn listener_configs_equivalent(config1: &ListenerConfig, config2: &ListenerConfig) -> bool { + config1.name == config2.name + && config1.address == config2.address + && config1.with_tls_inspector == config2.with_tls_inspector + && config1.with_tlv_listener_filter == config2.with_tlv_listener_filter + && config1.proxy_protocol_config == config2.proxy_protocol_config + && config1.tlv_listener_filter_config == config2.tlv_listener_filter_config + && config1.drain_type == config2.drain_type + && config1.bind_device == config2.bind_device + && config1.filter_chains == config2.filter_chains + } + + fn resolve_address_conflicts(&mut self, listener_name: &str, new_config: &ListenerConfig) -> Result<()> { + if let Some(existing_versions) = self.listener_handles.get_vec_mut(listener_name) { + let mut conflicts_found = 0; + + for existing_info in existing_versions.iter_mut() { + if existing_info.listener_conf.address == new_config.address && !existing_info.is_draining() { + warn!( + "Address conflict: listener {} version {} at {:?} conflicts with new configuration", + listener_name, existing_info.version, new_config.address + ); + + existing_info.start_draining(self.config.drain_config.clone(), &self.connection_manager); + conflicts_found += 1; + + info!( + "Started graceful drain for conflicting listener {} version {} to resolve address binding", + listener_name, existing_info.version + ); + } + } + + if conflicts_found > 0 { + info!( + "Resolved {} address conflict(s) for listener {} through graceful draining", + conflicts_found, listener_name + ); + } + } + + Ok(()) + } + + pub fn start_listener_with_conflict_resolution( + &mut self, + 
listener: Listener, + listener_conf: ListenerConfig, + ) -> Result<()> { + let listener_name = listener.get_name().to_string(); + + self.resolve_address_conflicts(&listener_name, &listener_conf)?; + self.start_listener(listener, listener_conf) + } + + pub fn stop_listener_version(&mut self, listener_name: &str, version: u64) -> Result<()> { + if let Some(versions) = self.listener_handles.get_vec_mut(listener_name) { + let mut found = false; + for listener_info in versions.iter_mut() { + if listener_info.version == version && !listener_info.is_draining() { + info!( + "Starting graceful draining for listener {} version {} with {} active connections", + listener_name, + version, + listener_info.connections_count.load(Ordering::Relaxed) + ); + + listener_info.start_draining(self.config.drain_config.clone(), &self.connection_manager); + found = true; + + info!( + "Stopping listener {} version {} (drain strategy: {:?})", + listener_name, version, self.config.drain_config.drain_strategy + ); + + if self.config.drain_config.drain_strategy == DrainStrategy::Immediate { + listener_info.handle.abort(); + if let Some(drain_handle) = listener_info.drain_manager_handle.as_ref() { + drain_handle.abort(); + } + } else { + warn!( + "Gracefully draining old listener {} version {} - monitored by background task", + listener_name, version + ); + } + break; + } + } + + if !found { + info!("Listener {} version {} not found or already draining", listener_name, version); + } + } else { + info!("No listeners found with name {}", listener_name); + } + + Ok(()) + } + + pub fn stop_listener(&mut self, listener_name: &str) -> Result<()> { + if let Some(mut listeners) = self.listener_handles.remove(listener_name) { + info!("Gracefully stopping all {} version(s) of listener {}", listeners.len(), listener_name); + + // Start draining for all versions + for listener_info in &mut listeners { + if !listener_info.is_draining() { + listener_info.start_draining(self.config.drain_config.clone(), &self.connection_manager); + } + } + + for listener_info in listeners { + info!( + "Stopping listener {} version {} (drain strategy: {:?})", + listener_name, listener_info.version, self.config.drain_config.drain_strategy + ); + + if self.config.drain_config.drain_strategy == DrainStrategy::Immediate { + listener_info.handle.abort(); + if let Some(drain_handle) = listener_info.drain_manager_handle { + drain_handle.abort(); + } + } else { + info!( + "Listener {} version {} will be managed by drain monitor", + listener_name, listener_info.version + ); + } + } + } else { + info!("No listeners found with name {}", listener_name); + } + + Ok(()) + } + + fn drain_old_listeners(&mut self, listener_name: &str) { + if let Some(versions) = self.listener_handles.get_vec_mut(listener_name) { + let original_count = versions.len(); + + match &self.config.cleanup_policy { + CleanupPolicy::CountBasedOnly(max_count) => { + if versions.len() > *max_count { + let to_remove = versions.len() - max_count; + let mut to_drain = Vec::new(); + + for _ in 0..to_remove { + if let Some(mut old_listener) = versions.drain(0..1).next() { + info!( + "Starting drain for old listener {} version {} (count limit)", + listener_name, old_listener.version + ); + old_listener.start_draining(self.config.drain_config.clone(), &self.connection_manager); + to_drain.push(old_listener); } } - ListenerConfigurationChange::Removed(listener_name) => { - let _ = self.stop_listener(&listener_name); - }, - ListenerConfigurationChange::TlsContextChanged((secret_id, secret)) => { - 
info!("Got tls secret update {secret_id}"); - let res = tx_secret_updates.send(TlsContextChange::Updated((secret_id, secret))); - if let Err(e) = res{ - warn!("Internal problem when updating a secret: {e}"); + + for draining_listener in to_drain { + if self.config.drain_config.drain_strategy == DrainStrategy::Immediate { + info!( + "Immediately stopping old listener {} version {} (immediate drain)", + listener_name, draining_listener.version + ); + draining_listener.handle.abort(); + if let Some(drain_handle) = draining_listener.drain_manager_handle { + drain_handle.abort(); + } + } else { + warn!( + "Gracefully draining old listener {} version {} - monitored by background task", + listener_name, draining_listener.version + ); + draining_listener.handle.abort(); + if let Some(drain_handle) = draining_listener.drain_manager_handle { + drain_handle.abort(); + } } - }, - ListenerConfigurationChange::GetConfiguration(config_dump_tx) => { - let listeners: Vec = self.listener_handles - .iter() - .map(|(_, info)| info.listener_conf.clone()) - .collect(); - config_dump_tx.send(ConfigDump { listeners: Some(listeners), ..Default::default() }).await?; - }, - } - }, - Some(route_configuration_change) = self.route_configuration_channel.recv() => { - // routes could be CachedWatch instead, as they are evaluated lazilly - let res = tx_route_updates.send(route_configuration_change); - if let Err(e) = res{ - warn!("Internal problem when updating a route: {e}"); + } + + let remaining_count = versions.len(); + if original_count != remaining_count { + info!( + "Cleaned up {} old version(s) of listener {}, {} remaining", + original_count - remaining_count, + listener_name, + remaining_count + ); + } } }, - _ = ct.cancelled() => { - warn!("Listener manager exiting"); - return Ok(()); - } } } } - pub fn start_listener(&mut self, listener: Listener, listener_conf: ListenerConfig) -> Result<()> { - let listener_name = listener.get_name().to_string(); - if let Some((addr, dev)) = listener.get_socket() { - info!("Listener {} at {addr} (device bind:{})", listener_name, dev.is_some()); - } else { - info!("Internal listener {}", listener_name); + pub async fn graceful_shutdown(&mut self, timeout: Duration) -> Result<()> { + info!("Starting graceful shutdown with timeout {:?}", timeout); + + let listener_names: Vec = self.listener_handles.iter().map(|(name, _)| name.clone()).collect(); + + for listener_name in &listener_names { + info!("Starting drain for listener: {}", listener_name); + self.start_draining_http_connection_managers(listener_name).await; } - self.version_counter += 1; - let version = self.version_counter; + let start_time = Instant::now(); + let mut interval = tokio::time::interval(Duration::from_secs(1)); - let listener_name_for_async = listener_name.clone(); + loop { + interval.tick().await; - let join_handle = tokio::spawn(async move { - let error = listener.start().await; - info!("Listener {} version {} exited: {}", listener_name_for_async, version, error); - }); + let total_connections = self.get_total_active_connections(); + if total_connections == 0 { + info!("All connections have drained gracefully"); + break; + } - let listener_info = ListenerInfo::new(join_handle, listener_conf, version); - self.listener_handles.insert(listener_name.clone(), listener_info); + if start_time.elapsed() >= timeout { + warn!("Graceful shutdown timeout reached with {} connections still active", total_connections); + break; + } - let version_count = self.listener_handles.get_vec(&listener_name).map(|v| 
+    pub async fn graceful_shutdown(&mut self, timeout: Duration) -> Result<()> {
+        info!("Starting graceful shutdown with timeout {:?}", timeout);
+
+        let listener_names: Vec<String> = self.listener_handles.iter().map(|(name, _)| name.clone()).collect();
+
+        for listener_name in &listener_names {
+            info!("Starting drain for listener: {}", listener_name);
+            self.start_draining_http_connection_managers(listener_name).await;
         }
-        self.version_counter += 1;
-        let version = self.version_counter;
+        let start_time = Instant::now();
+        let mut interval = tokio::time::interval(Duration::from_secs(1));
-        let listener_name_for_async = listener_name.clone();
+        loop {
+            interval.tick().await;
-        let join_handle = tokio::spawn(async move {
-            let error = listener.start().await;
-            info!("Listener {} version {} exited: {}", listener_name_for_async, version, error);
-        });
+            let total_connections = self.get_total_active_connections();
+            if total_connections == 0 {
+                info!("All connections have drained gracefully");
+                break;
+            }
-        let listener_info = ListenerInfo::new(join_handle, listener_conf, version);
-        self.listener_handles.insert(listener_name.clone(), listener_info);
+            if start_time.elapsed() >= timeout {
+                warn!("Graceful shutdown timeout reached with {} connections still active", total_connections);
+                break;
+            }
-        let version_count = self.listener_handles.get_vec(&listener_name).map(|v| v.len()).unwrap_or(0);
-        info!("Started version {} of listener {} ({} total active version(s))", version, listener_name, version_count);
+            debug!(
+                "Waiting for {} connections to drain ({}s remaining)",
+                total_connections,
+                (timeout - start_time.elapsed()).as_secs()
+            );
+        }
-        Ok(())
-    }
+        let total_connections = self.get_total_active_connections();
+        if total_connections > 0 {
+            warn!("Force closing {} remaining connections", total_connections);
+            self.force_close_all_connections();
+        }
-    pub fn stop_listener(&mut self, listener_name: &str) -> Result<()> {
-        if let Some(listeners) = self.listener_handles.get_vec_mut(listener_name) {
-            info!("Stopping all {} version(s) of listener {}", listeners.len(), listener_name);
-            for listener_info in listeners.drain(..) {
-                info!("Stopping listener {} version {}", listener_name, listener_info.version);
-                listener_info.handle.abort();
+        for listener_name in &listener_names {
+            if let Err(e) = self.stop_listener(listener_name) {
+                warn!("Failed to stop listener {}: {}", listener_name, e);
             }
-            self.listener_handles.remove(listener_name);
-        } else {
-            info!("No listeners found with name {}", listener_name);
         }
+        info!("Graceful shutdown completed");
         Ok(())
     }
+
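+    /// Last-resort cleanup used by `graceful_shutdown` once the drain deadline
+    /// has passed: every connection still tracked by the connection manager is
+    /// closed immediately, listener by listener.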
+    fn force_close_all_connections(&self) {
+        for listener_name in self.listener_handles.keys() {
+            let connections = self.connection_manager.get_listener_connections(listener_name);
+            for conn_info in connections {
+                warn!("Force closing connection {} on listener {}", conn_info.id, listener_name);
+                self.connection_manager.force_close_connection(listener_name, &conn_info.id);
+            }
+        }
+    }
 }
 #[cfg(test)]
@@ -172,12 +1385,30 @@ mod tests {
     use std::{
         collections::HashMap,
         net::{IpAddr, Ipv4Addr, SocketAddr},
+        sync::Arc,
     };
 
     use super::*;
+    use orion_configuration::config::listener::ListenerAddress;
     use orion_configuration::config::Listener as ListenerConfig;
+    use tokio::sync::Mutex;
     use tracing_test::traced_test;
 
+    fn create_test_listener_config(name: &str, port: u16) -> ListenerConfig {
+        ListenerConfig {
+            name: name.into(),
+            address: ListenerAddress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)),
+            version_info: None,
+            filter_chains: HashMap::default(),
+            bind_device: None,
+            with_tls_inspector: false,
+            proxy_protocol_config: None,
+            with_tlv_listener_filter: false,
+            tlv_listener_filter_config: None,
+            drain_type: orion_configuration::config::listener::DrainType::Default,
+        }
+    }
+
     #[traced_test]
     #[tokio::test]
     async fn start_listener_dup() {
@@ -203,24 +1434,83 @@ mod tests {
             proxy_protocol_config: None,
             with_tlv_listener_filter: false,
             tlv_listener_filter_config: None,
+            drain_type: orion_configuration::config::listener::DrainType::Default,
+            version_info: None,
         };
 
         man.start_listener(l1, l1_info.clone()).unwrap();
         assert!(routeb_tx1.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
         tokio::task::yield_now().await;
 
-        let (routeb_tx2, routeb_rx) = broadcast::channel(chan);
+        let (_routeb_tx2, routeb_rx) = broadcast::channel(chan);
         let (_secb_tx2, secb_rx) = broadcast::channel(chan);
         let l2 = Listener::test_listener(name, routeb_rx, secb_rx);
-        let l2_info = l1_info;
+        let l2_info = l1_info.clone();
         man.start_listener(l2, l2_info).unwrap();
-        assert!(routeb_tx2.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
         tokio::task::yield_now().await;
 
-        // Both listeners should still be active (multiple versions allowed)
+        // Only original listener should remain active (duplicate was skipped)
         assert!(routeb_tx1.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
-        assert!(routeb_tx2.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
+        assert_eq!(man.listener_handles.get_vec(name).unwrap().len(), 1);
+
+        let (routeb_tx3, routeb_rx) = broadcast::channel(chan);
+        let (_secb_tx3, secb_rx) = broadcast::channel(chan);
+        let l3 = Listener::test_listener(name, routeb_rx, secb_rx);
+        let mut l3_info = l1_info;
+        l3_info.address = ListenerAddress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5678)); // Different port
+        man.start_listener(l3, l3_info).unwrap();
+        tokio::task::yield_now().await;
         assert_eq!(man.listener_handles.get_vec(name).unwrap().len(), 2);
+        assert!(routeb_tx1.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
+        assert!(routeb_tx3.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
+        tokio::task::yield_now().await;
+    }
+
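+    // Exercises an LDS-style config replacement: the original version must be
+    // left draining while the replacement becomes the active one, so clients
+    // never see responses mixed from two live configurations.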
format!("Listener {name} version 1 exited with error: channel closed"); logs_assert(|lines: &[&str]| { let logs: Vec<_> = lines.iter().filter(|ln| ln.contains(&expected)).collect(); if logs.len() == 1 { @@ -277,69 +1569,383 @@ mod tests { let (_conf_tx, conf_rx) = mpsc::channel(chan); let (_route_tx, route_rx) = mpsc::channel(chan); - let mut man = ListenersManager::new(conf_rx, route_rx); + let config = ListenerManagerConfig { + max_versions_per_listener: 2, + cleanup_policy: CleanupPolicy::CountBasedOnly(2), + cleanup_interval: Duration::from_secs(60), + drain_config: ListenerDrainConfig::default(), + }; + let mut man = ListenersManager::new_with_config(conf_rx, route_rx, config); let (routeb_tx1, routeb_rx) = broadcast::channel(chan); let (_secb_tx1, secb_rx) = broadcast::channel(chan); let l1 = Listener::test_listener(name, routeb_rx, secb_rx); - let l1_info = ListenerConfig { - name: name.into(), - address: ListenerAddress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1234)), - filter_chains: HashMap::default(), - bind_device: None, - with_tls_inspector: false, - proxy_protocol_config: None, - with_tlv_listener_filter: false, - tlv_listener_filter_config: None, - }; + let l1_info = create_test_listener_config(name, 1234); man.start_listener(l1, l1_info).unwrap(); assert!(routeb_tx1.send(RouteConfigurationChange::Removed("n/a".into())).is_ok()); tokio::task::yield_now().await; - let (routeb_tx2, routeb_rx) = broadcast::channel(chan); - let (_secb_tx2, secb_rx) = broadcast::channel(chan); - let l2 = Listener::test_listener(name, routeb_rx, secb_rx); - let l2_info = ListenerConfig { + assert_eq!(man.listener_handles.get_vec(name).unwrap().len(), 1); + + for i in 1..=5 { + let (_routeb_tx, routeb_rx) = broadcast::channel(chan); + let (_secb_tx, secb_rx) = broadcast::channel(chan); + let listener = Listener::test_listener(name, routeb_rx, secb_rx); + let listener_info = create_test_listener_config(name, 1230 + i); + man.start_listener(listener, listener_info).unwrap(); + tokio::task::yield_now().await; + } + + let versions = man.listener_handles.get_vec(name).unwrap(); + assert!(versions.len() <= 2, "Expected at most 2 versions, got {}", versions.len()); + + man.stop_listener(name).unwrap(); + assert!(man.listener_handles.get_vec(name).is_none()); + + tokio::task::yield_now().await; + } + + #[traced_test] + #[tokio::test] + async fn test_drain_strategy_immediate() { + let chan = 10; + let name = "immediate-drain-listener"; + + let (_conf_tx, conf_rx) = mpsc::channel(chan); + let (_route_tx, route_rx) = mpsc::channel(chan); + let config = ListenerManagerConfig { + max_versions_per_listener: 2, + cleanup_policy: CleanupPolicy::CountBasedOnly(2), + cleanup_interval: Duration::from_secs(60), + drain_config: ListenerDrainConfig { + drain_time: Duration::from_secs(5), + drain_strategy: DrainStrategy::Immediate, + protocol_handling: ProtocolDrainBehavior::Auto, + }, + }; + let mut man = ListenersManager::new_with_config(conf_rx, route_rx, config); + + let (_routeb_tx, routeb_rx) = broadcast::channel(chan); + let (_secb_tx, secb_rx) = broadcast::channel(chan); + let listener = Listener::test_listener(name, routeb_rx, secb_rx); + let listener_info = create_test_listener_config(name, 1234); + man.start_listener(listener, listener_info).unwrap(); + + man.stop_listener(name).unwrap(); + + tokio::task::yield_now().await; + } + + #[traced_test] + #[tokio::test] + async fn test_drain_strategy_gradual() { + let chan = 10; + let name = "gradual-drain-listener"; + + let (_conf_tx, conf_rx) = 
+    #[traced_test]
+    #[tokio::test]
+    async fn test_drain_strategy_immediate() {
+        let chan = 10;
+        let name = "immediate-drain-listener";
+
+        let (_conf_tx, conf_rx) = mpsc::channel(chan);
+        let (_route_tx, route_rx) = mpsc::channel(chan);
+        let config = ListenerManagerConfig {
+            max_versions_per_listener: 2,
+            cleanup_policy: CleanupPolicy::CountBasedOnly(2),
+            cleanup_interval: Duration::from_secs(60),
+            drain_config: ListenerDrainConfig {
+                drain_time: Duration::from_secs(5),
+                drain_strategy: DrainStrategy::Immediate,
+                protocol_handling: ProtocolDrainBehavior::Auto,
+            },
+        };
+        let mut man = ListenersManager::new_with_config(conf_rx, route_rx, config);
+
+        let (_routeb_tx, routeb_rx) = broadcast::channel(chan);
+        let (_secb_tx, secb_rx) = broadcast::channel(chan);
+        let listener = Listener::test_listener(name, routeb_rx, secb_rx);
+        let listener_info = create_test_listener_config(name, 1234);
+        man.start_listener(listener, listener_info).unwrap();
+
+        man.stop_listener(name).unwrap();
+
+        tokio::task::yield_now().await;
+    }
+
+    #[traced_test]
+    #[tokio::test]
+    async fn test_drain_strategy_gradual() {
+        let chan = 10;
+        let name = "gradual-drain-listener";
+
+        let (_conf_tx, conf_rx) = mpsc::channel(chan);
+        let (_route_tx, route_rx) = mpsc::channel(chan);
+        let config = ListenerManagerConfig {
+            max_versions_per_listener: 2,
+            cleanup_policy: CleanupPolicy::CountBasedOnly(2),
+            cleanup_interval: Duration::from_secs(60),
+            drain_config: ListenerDrainConfig {
+                drain_time: Duration::from_secs(10),
+                drain_strategy: DrainStrategy::Gradual,
+                protocol_handling: ProtocolDrainBehavior::Http1 { connection_close: true },
+            },
+        };
+        let mut man = ListenersManager::new_with_config(conf_rx, route_rx, config);
+
+        let (_routeb_tx, routeb_rx) = broadcast::channel(chan);
+        let (_secb_tx, secb_rx) = broadcast::channel(chan);
+        let listener = Listener::test_listener(name, routeb_rx, secb_rx);
+        let listener_info = create_test_listener_config(name, 1234);
+        man.start_listener(listener, listener_info).unwrap();
+
+        man.stop_listener(name).unwrap();
+
+        tokio::task::yield_now().await;
+    }
+
+    #[traced_test]
+    #[tokio::test]
+    async fn test_protocol_specific_drain_behavior() {
+        let chan = 10;
+        let name = "protocol-drain-listener";
+
+        let (_conf_tx, conf_rx) = mpsc::channel(chan);
+        let (_route_tx, route_rx) = mpsc::channel(chan);
+        let config = ListenerManagerConfig {
+            max_versions_per_listener: 2,
+            cleanup_policy: CleanupPolicy::CountBasedOnly(2),
+            cleanup_interval: Duration::from_secs(60),
+            drain_config: ListenerDrainConfig {
+                drain_time: Duration::from_secs(30),
+                drain_strategy: DrainStrategy::Gradual,
+                protocol_handling: ProtocolDrainBehavior::Http2 { send_goaway: true },
+            },
+        };
+        let mut man = ListenersManager::new_with_config(conf_rx, route_rx, config);
+
+        let (_routeb_tx, routeb_rx) = broadcast::channel(chan);
+        let (_secb_tx, secb_rx) = broadcast::channel(chan);
+        let listener = Listener::test_listener(name, routeb_rx, secb_rx);
+        let listener_info = create_test_listener_config(name, 1234);
+        man.start_listener(listener, listener_info).unwrap();
+
+        let drain_status = man.get_listener_drain_status(name);
+        assert_eq!(drain_status.len(), 0);
+
+        man.stop_listener(name).unwrap();
+
+        tokio::task::yield_now().await;
+    }
+
+    #[traced_test]
+    #[tokio::test]
+    async fn test_drain_timeout_enforcement() {
+        let chan = 10;
+        let name = "timeout-test-listener";
+
+        let (_conf_tx, conf_rx) = mpsc::channel(chan);
+        let (_route_tx, route_rx) = mpsc::channel(chan);
+        let config = ListenerManagerConfig {
+            max_versions_per_listener: 2,
+            cleanup_policy: CleanupPolicy::CountBasedOnly(2),
+            cleanup_interval: Duration::from_secs(60),
+            drain_config: ListenerDrainConfig {
+                drain_time: Duration::from_millis(100),
+                drain_strategy: DrainStrategy::Gradual,
+                protocol_handling: ProtocolDrainBehavior::Auto,
+            },
+        };
+        let mut man = ListenersManager::new_with_config(conf_rx, route_rx, config);
+
+        let (_routeb_tx, routeb_rx) = broadcast::channel(chan);
+        let (_secb_tx, secb_rx) = broadcast::channel(chan);
+        let listener = Listener::test_listener(name, routeb_rx, secb_rx);
+        let listener_info = create_test_listener_config(name, 1234);
+        man.start_listener(listener, listener_info).unwrap();
+
+        man.stop_listener(name).unwrap();
+
+        tokio::time::sleep(Duration::from_millis(200)).await;
+        tokio::task::yield_now().await;
+    }
+
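+    // Two versions of one listener bound to the same socket address: conflict
+    // resolution should leave version 1 draining gracefully while version 2
+    // takes over, rather than aborting version 1 outright.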
+    #[traced_test]
+    #[tokio::test]
+    async fn test_address_conflict_resolution_graceful() {
+        let chan = 10;
+        let name = "conflict-test-listener";
+        let shared_address = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
+
+        let (_conf_tx, conf_rx) = mpsc::channel(chan);
+        let (_route_tx, route_rx) = mpsc::channel(chan);
+        let mut man = ListenersManager::new_with_config(
+            conf_rx,
+            route_rx,
+            ListenerManagerConfig {
+                max_versions_per_listener: 3,
+                cleanup_policy: CleanupPolicy::CountBasedOnly(2),
+                cleanup_interval: Duration::from_secs(60),
+                drain_config: ListenerDrainConfig {
+                    drain_time: Duration::from_secs(30),
+                    drain_strategy: DrainStrategy::Gradual,
+                    protocol_handling: ProtocolDrainBehavior::Auto,
+                },
+            },
+        );
+
+        let (routeb_tx1, routeb_rx1) = broadcast::channel(chan);
+        let (_secb_tx1, secb_rx1) = broadcast::channel(chan);
+        let l1 = Listener::test_listener(name, routeb_rx1, secb_rx1);
+        let l1_info = ListenerConfig {
             name: name.into(),
-            address: ListenerAddress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1235)), // Different port
+            address: ListenerAddress::Socket(shared_address),
             filter_chains: HashMap::default(),
             bind_device: None,
             with_tls_inspector: false,
             proxy_protocol_config: None,
             with_tlv_listener_filter: false,
             tlv_listener_filter_config: None,
+            drain_type: orion_configuration::config::listener::DrainType::Default,
+            version_info: None,
         };
-        man.start_listener(l2, l2_info).unwrap();
-        assert!(routeb_tx2.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
         tokio::task::yield_now().await;
-        let (routeb_tx3, routeb_rx) = broadcast::channel(chan);
-        let (_secb_tx3, secb_rx) = broadcast::channel(chan);
-        let l3 = Listener::test_listener(name, routeb_rx, secb_rx);
-        let l3_info = ListenerConfig {
+        man.start_listener_with_conflict_resolution(l1, l1_info.clone()).unwrap();
+
+        let versions = man.listener_handles.get_vec(name).unwrap();
+        assert_eq!(versions.len(), 1);
+        assert_eq!(versions[0].version, 1);
+        assert!(!versions[0].is_draining());
+
+        let (routeb_tx2, routeb_rx2) = broadcast::channel(chan);
+        let (_secb_tx2, secb_rx2) = broadcast::channel(chan);
+        let l2 = Listener::test_listener(name, routeb_rx2, secb_rx2);
+        let l2_info = ListenerConfig {
             name: name.into(),
-            address: ListenerAddress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1236)), // Different port
+            address: ListenerAddress::Socket(shared_address),
             filter_chains: HashMap::default(),
             bind_device: None,
-            with_tls_inspector: false,
+            with_tls_inspector: true,
             proxy_protocol_config: None,
             with_tlv_listener_filter: false,
             tlv_listener_filter_config: None,
+            drain_type: orion_configuration::config::listener::DrainType::Default,
+            version_info: None,
         };
-        man.start_listener(l3, l3_info).unwrap();
-        assert!(routeb_tx3.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
+
+        man.start_listener_with_conflict_resolution(l2, l2_info).unwrap();
+
+        let versions = man.listener_handles.get_vec(name).unwrap();
+        assert_eq!(versions.len(), 2);
+
+        assert_eq!(versions[0].version, 1);
+        assert!(versions[0].is_draining());
+        assert_eq!(versions[1].version, 2);
+        assert!(!versions[1].is_draining());
+
+        info!("Successfully demonstrated graceful address conflict resolution");
+        info!("PR #77 approach: would have immediately killed version 1 -> broken connections");
+        info!("Our approach: version 1 gracefully draining -> connections preserved");
+
+        drop(routeb_tx1);
+        drop(routeb_tx2);
         tokio::task::yield_now().await;
+    }
 
-        assert!(routeb_tx1.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
-        assert!(routeb_tx2.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
-        assert!(routeb_tx3.send(RouteConfigurationChange::Removed("n/a".into())).is_ok());
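+    // xDS version_info handling: a numeric string ("42") is expected to map
+    // straight onto the internal version counter, while non-numeric or absent
+    // values appear to fall back to generated version numbers.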
+    #[tokio::test]
+    async fn test_xds_version_handling() {
+        let chan = 16;
+        let (_conf_tx, conf_rx) = mpsc::channel(chan);
+        let (_route_tx, route_rx) = mpsc::channel(chan);
+        let config = ListenerManagerConfig {
+            max_versions_per_listener: 3,
+            cleanup_policy: CleanupPolicy::CountBasedOnly(3),
+            cleanup_interval: Duration::from_secs(60),
+            drain_config: ListenerDrainConfig::default(),
+        };
+        let mut man = ListenersManager::new_with_config(conf_rx, route_rx, config);
-        assert_eq!(man.listener_handles.get_vec(name).unwrap().len(), 3);
+        let name = "test-xds-version";
-        man.stop_listener(name).unwrap();
+        let (routeb_tx1, routeb_rx1) = broadcast::channel(chan);
+        let (_secb_tx1, secb_rx1) = broadcast::channel(chan);
+        let l1 = Listener::test_listener(name, routeb_rx1, secb_rx1);
+        let mut l1_info = create_test_listener_config(name, 8001);
+        l1_info.version_info = Some("42".to_string());
-        assert!(man.listener_handles.get_vec(name).is_none());
+        man.start_listener_with_conflict_resolution(l1, l1_info).unwrap();
+
+        let versions = man.listener_handles.get_vec(name).unwrap();
+        assert_eq!(versions.len(), 1);
+        assert_eq!(versions[0].version, 42);
+
+        let (routeb_tx2, routeb_rx2) = broadcast::channel(chan);
+        let (_secb_tx2, secb_rx2) = broadcast::channel(chan);
+        let l2 = Listener::test_listener(name, routeb_rx2, secb_rx2);
+        let mut l2_info = create_test_listener_config(name, 8001);
+        l2_info.version_info = Some("v1.2.3-alpha".to_string());
+        l2_info.with_tls_inspector = true;
+
+        println!("Starting second listener with version_info: {:?}", l2_info.version_info);
+        man.start_listener_with_conflict_resolution(l2, l2_info).unwrap();
+
+        let versions = man.listener_handles.get_vec(name).unwrap();
+        println!("After second listener, versions: {:?}", versions.iter().map(|v| v.version).collect::<Vec<_>>());
+        assert_eq!(versions.len(), 2);
+        assert_eq!(versions[0].version, 42);
+        assert_ne!(versions[1].version, 42);
+
+        let (routeb_tx3, routeb_rx3) = broadcast::channel(chan);
+        let (_secb_tx3, secb_rx3) = broadcast::channel(chan);
+        let l3 = Listener::test_listener(name, routeb_rx3, secb_rx3);
+        let mut l3_info = create_test_listener_config(name, 8001);
+        l3_info.with_tls_inspector = true;
+        l3_info.drain_type = orion_configuration::config::listener::DrainType::ModifyOnly;
+
+        println!(
+            "Starting third listener with version_info: {:?}, with_tls_inspector: {}, drain_type: {:?}",
+            l3_info.version_info, l3_info.with_tls_inspector, l3_info.drain_type
+        );
+        man.start_listener_with_conflict_resolution(l3, l3_info).unwrap();
+
+        let versions = man.listener_handles.get_vec(name).unwrap();
+        println!("After third listener, versions: {:?}", versions.iter().map(|v| v.version).collect::<Vec<_>>());
+        assert_eq!(versions.len(), 3);
+        let third_version = versions[2].version;
+        assert!(third_version > 0);
+        drop(routeb_tx1);
+        drop(routeb_tx2);
+        drop(routeb_tx3);
         tokio::task::yield_now().await;
     }
+
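+    // Ten tasks race to register listeners through a shared Arc<Mutex<_>>;
+    // after joining, every listener name must be present in the manager.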
"concurrent-listener-1", + 2 => "concurrent-listener-2", + 3 => "concurrent-listener-3", + 4 => "concurrent-listener-4", + 5 => "concurrent-listener-5", + 6 => "concurrent-listener-6", + 7 => "concurrent-listener-7", + 8 => "concurrent-listener-8", + _ => "concurrent-listener-9", + }; + let (routeb_tx, routeb_rx) = broadcast::channel(chan); + let (_secb_tx, secb_rx) = broadcast::channel(chan); + let listener = Listener::test_listener(name, routeb_rx, secb_rx); + let listener_info = create_test_listener_config(name, 8000 + i); + + let mut manager = man_clone.lock().await; + manager.start_listener(listener, listener_info).unwrap(); + drop(routeb_tx); + }); + handles.push(handle); + } + + for handle in handles { + handle.await.unwrap(); + } + + let manager = man.lock().await; + let expected_names = [ + "concurrent-listener-0", + "concurrent-listener-1", + "concurrent-listener-2", + "concurrent-listener-3", + "concurrent-listener-4", + "concurrent-listener-5", + "concurrent-listener-6", + "concurrent-listener-7", + "concurrent-listener-8", + "concurrent-listener-9", + ]; + for &name in expected_names.iter() { + assert!(manager.listener_handles.get_vec(name).is_some()); + } + } } diff --git a/orion-lib/src/listeners/mod.rs b/orion-lib/src/listeners/mod.rs index 3f6e2b6a..0194f9d8 100644 --- a/orion-lib/src/listeners/mod.rs +++ b/orion-lib/src/listeners/mod.rs @@ -16,9 +16,11 @@ // pub(crate) mod access_log; +pub(crate) mod drain_signaling; pub(crate) mod filter_state; pub(crate) mod filterchain; pub(crate) mod http_connection_manager; +pub(crate) mod lds_update; pub(crate) mod listener; pub(crate) mod listeners_manager; pub(crate) mod rate_limiter; diff --git a/orion-proxy/src/admin/config_dump.rs b/orion-proxy/src/admin/config_dump.rs index f7dbb1eb..ebc7c97d 100644 --- a/orion-proxy/src/admin/config_dump.rs +++ b/orion-proxy/src/admin/config_dump.rs @@ -247,6 +247,8 @@ mod config_dump_tests { let listener = Listener { name: CompactString::from("listener1"), address: ListenerAddress::Socket(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)), + drain_type: orion_configuration::config::listener::DrainType::Default, + version_info: None, filter_chains: { let mut map = HashMap::new(); map.insert( @@ -259,6 +261,7 @@ mod config_dump_tests { terminal_filter: MainFilter::Http(HttpConnectionManager { codec_type: CodecType::Http1, request_timeout: Some(Duration::from_secs(10)), + drain_timeout: Some(Duration::from_secs(5)), http_filters: vec![], enabled_upgrades: vec![], route_specifier: RouteSpecifier::RouteConfig(RouteConfiguration {