diff --git a/quixstreams/app.py b/quixstreams/app.py index 58c69e284..3fdbf1c49 100644 --- a/quixstreams/app.py +++ b/quixstreams/app.py @@ -5,6 +5,7 @@ import time import warnings from collections import defaultdict +from datetime import datetime from pathlib import Path from typing import Callable, List, Literal, Optional, Protocol, Tuple, Type, Union @@ -28,6 +29,8 @@ from .logging import LogLevel, configure_logging from .models import ( DeserializerType, + MessageContext, + Row, SerializerType, TimestampExtractor, Topic, @@ -151,6 +154,7 @@ def __init__( topic_create_timeout: float = 60, processing_guarantee: ProcessingGuarantee = "at-least-once", max_partition_buffer_size: int = 10000, + heartbeat_interval: float = 0.0, ): """ :param broker_address: Connection settings for Kafka. @@ -220,6 +224,11 @@ def __init__( It is a soft limit, and the actual number of buffered messages can be up to x2 higher. Lower value decreases the memory use, but increases the latency. Default - `10000`. + :param heartbeat_interval: the interval (seconds) at which to send heartbeat messages. + The heartbeat timing starts counting from application start. + The heartbeat is sent for every partition on every topic. + If the value is 0, no heartbeat messages will be sent. + Default - `0.0`.

***Error Handlers***
To handle errors, `Application` accepts callbacks triggered when @@ -363,6 +372,10 @@ def __init__( recovery_manager=recovery_manager, ) + self._heartbeat_active = heartbeat_interval > 0 + self._heartbeat_interval = heartbeat_interval + self._heartbeat_last_sent = datetime.now().timestamp() + self._source_manager = SourceManager() self._sink_manager = SinkManager() self._dataframe_registry = DataFrameRegistry() @@ -879,6 +892,7 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None): processing_context = self._processing_context source_manager = self._source_manager process_message = self._process_message + process_heartbeat = self._process_heartbeat printer = self._processing_context.printer run_tracker = self._run_tracker consumer = self._consumer @@ -902,6 +916,7 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None): run_tracker.timeout_refresh() else: process_message(dataframes_composed) + process_heartbeat(dataframes_composed) processing_context.commit_checkpoint() consumer.resume_backpressured() source_manager.raise_for_error() @@ -985,6 +1000,49 @@ def _process_message(self, dataframe_composed): if self._on_message_processed is not None: self._on_message_processed(topic_name, partition, offset) + def _process_heartbeat(self, dataframe_composed): + if not self._heartbeat_active: + return + + now = datetime.now().timestamp() + if self._heartbeat_last_sent > now - self._heartbeat_interval: + return + + value, key, timestamp, headers = None, None, int(now * 1000), {} + non_changelog_topics = self._topic_manager.non_changelog_topics + + for tp in self._consumer.assignment(): + if (topic := tp.topic) in non_changelog_topics: + row = Row( + value=value, + key=key, + timestamp=timestamp, + context=MessageContext( + topic=topic, + partition=tp.partition, + offset=-1, + size=-1, + heartbeat=True, + ), + headers=headers, + ) + context = copy_context() + context.run(set_message_context, row.context) + try: + context.run( + dataframe_composed[topic], + value, + key, + timestamp, + headers, + ) + except Exception as exc: + to_suppress = self._on_processing_error(exc, row, logger) + if not to_suppress: + raise + + self._heartbeat_last_sent = now + def _on_assign(self, _, topic_partitions: List[TopicPartition]): """ Assign new topic partitions to consumer and state. diff --git a/quixstreams/core/stream/functions/__init__.py b/quixstreams/core/stream/functions/__init__.py index 1561500e3..1be9fd6a0 100644 --- a/quixstreams/core/stream/functions/__init__.py +++ b/quixstreams/core/stream/functions/__init__.py @@ -2,6 +2,7 @@ from .apply import * from .base import * from .filter import * +from .heartbeat import * from .transform import * from .types import * from .update import * diff --git a/quixstreams/core/stream/functions/apply.py b/quixstreams/core/stream/functions/apply.py index bdf493953..24f6fd92d 100644 --- a/quixstreams/core/stream/functions/apply.py +++ b/quixstreams/core/stream/functions/apply.py @@ -1,6 +1,7 @@ from typing import Any, Literal, Union, overload from .base import StreamFunction +from .heartbeat import is_heartbeat_message from .types import ( ApplyCallback, ApplyExpandedCallback, @@ -48,6 +49,10 @@ def wrapper( timestamp: int, headers: Any, ) -> None: + # Pass heartbeat messages downstream + if is_heartbeat_message(key, value): + child_executor(value, key, timestamp, headers) + # Execute a function on a single value and wrap results into a list # to expand them downstream result = func(value) @@ -62,9 +67,11 @@ def wrapper( timestamp: int, headers: Any, ) -> None: - # Execute a function on a single value and return its result - result = func(value) - child_executor(result, key, timestamp, headers) + # Pass heartbeat messages downstream or execute + # a function on a single value and return its result + if not is_heartbeat_message(key, value): + value = func(value) + child_executor(value, key, timestamp, headers) return wrapper @@ -110,6 +117,10 @@ def wrapper( timestamp: int, headers: Any, ): + # Pass heartbeat messages downstream + if is_heartbeat_message(key, value): + child_executor(value, key, timestamp, headers) + # Execute a function on a single value and wrap results into a list # to expand them downstream result = func(value, key, timestamp, headers) @@ -124,8 +135,10 @@ def wrapper( timestamp: int, headers: Any, ): - # Execute a function on a single value and return its result - result = func(value, key, timestamp, headers) - child_executor(result, key, timestamp, headers) + # Pass heartbeat messages downstream or execute + # a function on a single value and return its result + if not is_heartbeat_message(key, value): + value = func(value, key, timestamp, headers) + child_executor(value, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/functions/filter.py b/quixstreams/core/stream/functions/filter.py index e291880c7..ed1bf6038 100644 --- a/quixstreams/core/stream/functions/filter.py +++ b/quixstreams/core/stream/functions/filter.py @@ -1,5 +1,7 @@ from typing import Any +from quixstreams.core.stream.functions.heartbeat import is_heartbeat_message + from .base import StreamFunction from .types import FilterCallback, FilterWithMetadataCallback, VoidExecutor @@ -30,7 +32,7 @@ def wrapper( headers: Any, ): # Filter a single value - if func(value): + if is_heartbeat_message(key, value) or func(value): child_executor(value, key, timestamp, headers) return wrapper @@ -62,7 +64,7 @@ def wrapper( headers: Any, ): # Filter a single value - if func(value, key, timestamp, headers): + if is_heartbeat_message(key, value) or func(value, key, timestamp, headers): child_executor(value, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/functions/heartbeat.py b/quixstreams/core/stream/functions/heartbeat.py new file mode 100644 index 000000000..589745aa2 --- /dev/null +++ b/quixstreams/core/stream/functions/heartbeat.py @@ -0,0 +1,36 @@ +from typing import Any + +from quixstreams.context import message_context + +from .base import StreamFunction +from .types import HeartbeatCallback, VoidExecutor + +__all__ = ("HeartbeatFunction", "is_heartbeat_message") + + +class HeartbeatFunction(StreamFunction): + def __init__(self, func: HeartbeatCallback) -> None: + super().__init__(func) + self.func: HeartbeatCallback + + def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: + child_executor = self._resolve_branching(*child_executors) + + func = self.func + + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + ): + if is_heartbeat_message(key, value) and (result := func(timestamp)): + for new_value, new_key, new_timestamp, new_headers in result: + child_executor(new_value, new_key, new_timestamp, new_headers) + child_executor(value, key, timestamp, headers) + + return wrapper + + +def is_heartbeat_message(key: Any, value: Any) -> bool: + return message_context().heartbeat and key is None and value is None diff --git a/quixstreams/core/stream/functions/transform.py b/quixstreams/core/stream/functions/transform.py index 219662b6b..2ca23718a 100644 --- a/quixstreams/core/stream/functions/transform.py +++ b/quixstreams/core/stream/functions/transform.py @@ -1,5 +1,7 @@ from typing import Any, Literal, Union, cast, overload +from quixstreams.core.stream.functions.heartbeat import is_heartbeat_message + from .base import StreamFunction from .types import TransformCallback, TransformExpandedCallback, VoidExecutor @@ -53,9 +55,13 @@ def wrapper( timestamp: int, headers: Any, ): - result = expanded_func(value, key, timestamp, headers) - for new_value, new_key, new_timestamp, new_headers in result: - child_executor(new_value, new_key, new_timestamp, new_headers) + # Pass heartbeat messages downstream + if is_heartbeat_message(key, value): + child_executor(value, key, timestamp, headers) + else: + result = expanded_func(value, key, timestamp, headers) + for new_value, new_key, new_timestamp, new_headers in result: + child_executor(new_value, new_key, new_timestamp, new_headers) else: func = cast(TransformCallback, self.func) @@ -66,10 +72,13 @@ def wrapper( timestamp: int, headers: Any, ): - # Execute a function on a single value and return its result - new_value, new_key, new_timestamp, new_headers = func( - value, key, timestamp, headers - ) - child_executor(new_value, new_key, new_timestamp, new_headers) + # Pass heartbeat messages downstream or execute + # a function on a single value and return its result + if not is_heartbeat_message(key, value): + value, key, timestamp, headers = func( + value, key, timestamp, headers + ) + + child_executor(value, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/functions/types.py b/quixstreams/core/stream/functions/types.py index 504299b53..eceb8fdb3 100644 --- a/quixstreams/core/stream/functions/types.py +++ b/quixstreams/core/stream/functions/types.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable, Protocol, Tuple, Union +from typing import Any, Callable, Iterable, Optional, Protocol, Tuple, Union __all__ = ( "StreamCallback", @@ -36,6 +36,8 @@ def __bool__(self) -> bool: ... [Any, Any, int, Any], Iterable[Tuple[Any, Any, int, Any]] ] +HeartbeatCallback = Callable[[int], Optional[Iterable[Tuple[Any, Any, int, Any]]]] + StreamCallback = Union[ ApplyCallback, ApplyExpandedCallback, diff --git a/quixstreams/core/stream/functions/update.py b/quixstreams/core/stream/functions/update.py index b2d9a19bc..18b7928aa 100644 --- a/quixstreams/core/stream/functions/update.py +++ b/quixstreams/core/stream/functions/update.py @@ -1,5 +1,7 @@ from typing import Any +from quixstreams.core.stream.functions.heartbeat import is_heartbeat_message + from .base import StreamFunction from .types import UpdateCallback, UpdateWithMetadataCallback, VoidExecutor @@ -28,7 +30,8 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: def wrapper(value: Any, key: Any, timestamp: int, headers: Any): # Update a single value and forward it - func(value) + if not is_heartbeat_message(key, value): + func(value) child_executor(value, key, timestamp, headers) return wrapper @@ -56,7 +59,8 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: def wrapper(value: Any, key: Any, timestamp: int, headers: Any): # Update a single value and forward it - func(value, key, timestamp, headers) + if not is_heartbeat_message(key, value): + func(value, key, timestamp, headers) child_executor(value, key, timestamp, headers) return wrapper diff --git a/quixstreams/core/stream/stream.py b/quixstreams/core/stream/stream.py index f538f5307..115d66235 100644 --- a/quixstreams/core/stream/stream.py +++ b/quixstreams/core/stream/stream.py @@ -25,6 +25,7 @@ FilterFunction, FilterWithMetadataCallback, FilterWithMetadataFunction, + HeartbeatFunction, ReturningExecutor, StreamFunction, TransformCallback, @@ -280,6 +281,10 @@ def add_transform( """ return self._add(TransformFunction(func, expand=expand)) # type: ignore[call-overload] + def add_heartbeat(self, func) -> "Stream": + heartbeat_func = HeartbeatFunction(func) + return self._add(heartbeat_func) + def merge(self, other: "Stream") -> "Stream": """ Merge two Streams together and return a new Stream with two parents diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 7e9b09844..fe85f11ab 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -191,6 +191,10 @@ def stream_id(self) -> str: def topics(self) -> tuple[Topic, ...]: return self._topics + def heartbeat(self, func) -> "StreamingDataFrame": + stream = self.stream.add_heartbeat(func) + return self.__dataframe_clone__(stream=stream) + @overload def apply( self, diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py index 8040b2774..4f326306d 100644 --- a/quixstreams/dataframe/windows/base.py +++ b/quixstreams/dataframe/windows/base.py @@ -69,6 +69,14 @@ def process_window( ) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]: pass + @abstractmethod + def process_heartbeat( + self, + timestamp_ms: int, + transaction: WindowedPartitionTransaction, + ) -> Iterable[WindowKeyResult]: + pass + def register_store(self) -> None: TopicManager.ensure_topics_copartitioned(*self._dataframe.topics) # Create a config for the changelog topic based on the underlying SDF topics @@ -83,6 +91,7 @@ def _apply_window( self, func: TransformRecordCallbackExpandedWindowed, name: str, + heartbeat_func, ) -> "StreamingDataFrame": self.register_store() @@ -92,11 +101,18 @@ def _apply_window( processing_context=self._dataframe.processing_context, store_name=name, ) + heartbeat_func = _as_heartbeat( + func=heartbeat_func, + stream_id=self._dataframe.stream_id, + processing_context=self._dataframe.processing_context, + store_name=name, + ) # Manually modify the Stream and clone the source StreamingDataFrame # to avoid adding "transform" API to it. # Transform callbacks can modify record key and timestamp, # and it's prone to misuse. stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True) + stream = stream.add_heartbeat(func=heartbeat_func) return self._dataframe.__dataframe_clone__(stream=stream) def final(self) -> "StreamingDataFrame": @@ -140,9 +156,17 @@ def window_callback( for key, window in expired_windows: yield (window, key, window["start"], None) + def heartbeat_callback( + timestamp: int, transaction: WindowedPartitionTransaction + ) -> Iterable[Message]: + # TODO: Check if this will work for sliding windows + for key, window in self.process_heartbeat(timestamp, transaction): + yield (window, key, window["start"], None) + return self._apply_window( func=window_callback, name=self._name, + heartbeat_func=heartbeat_callback, ) def current(self) -> "StreamingDataFrame": @@ -188,7 +212,17 @@ def window_callback( for key, window in updated_windows: yield (window, key, window["start"], None) - return self._apply_window(func=window_callback, name=self._name) + def heartbeat_callback( + timestamp: int, transaction: WindowedPartitionTransaction + ) -> Iterable[Message]: + # TODO: Implement heartbeat callback + return [] + + return self._apply_window( + func=window_callback, + name=self._name, + heartbeat_func=heartbeat_callback, + ) # Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin # Single aggregation and multi aggregation windows store aggregations and collections @@ -424,6 +458,26 @@ def wrapper( return wrapper +def _as_heartbeat( + func, # TODO: typing? + processing_context: "ProcessingContext", + store_name: str, + stream_id: str, +): # TODO: typing? + @functools.wraps(func) + def wrapper(timestamp: int) -> Iterable[Message]: + ctx = message_context() + transaction = cast( + WindowedPartitionTransaction, + processing_context.checkpoint.get_store_transaction( + stream_id=stream_id, partition=ctx.partition, store_name=store_name + ), + ) + return func(timestamp, transaction) + + return wrapper + + class WindowOnLateCallback(Protocol): def __call__( self, diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py index 57c6b36e5..649ee0ee0 100644 --- a/quixstreams/dataframe/windows/count_based.py +++ b/quixstreams/dataframe/windows/count_based.py @@ -189,6 +189,14 @@ def process_window( state.set(key=self.STATE_KEY, value=data) return updated_windows, expired_windows + def process_heartbeat( + self, + timestamp_ms: int, + transaction: WindowedPartitionTransaction, + ) -> Iterable[WindowKeyResult]: + # Count based windows cannot be expired by heartbeat + return [] + def _get_collection_start_id(self, window: CountWindowData) -> int: start_id = window.get("collection_start_id", _MISSING) if start_id is _MISSING: diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py index c403cfdfa..171dae6d6 100644 --- a/quixstreams/dataframe/windows/time_based.py +++ b/quixstreams/dataframe/windows/time_based.py @@ -200,11 +200,27 @@ def process_window( return updated_windows, expired_windows + def process_heartbeat( + self, + timestamp_ms: int, + transaction: WindowedPartitionTransaction, + ) -> Iterable[WindowKeyResult]: + latest_expired_window_end = transaction.get_latest_expired(prefix=b"") + latest_timestamp = max(timestamp_ms, latest_expired_window_end) + max_expired_window_end = latest_timestamp - self._grace_ms + return self.expire_by_partition( + transaction, + max_expired_window_end, + self.collect, + advance_last_expired_timestamp=False, + ) + def expire_by_partition( self, transaction: WindowedPartitionTransaction, max_expired_end: int, collect: bool, + advance_last_expired_timestamp: bool = True, ) -> Iterable[WindowKeyResult]: for ( window_start, @@ -214,6 +230,7 @@ def expire_by_partition( step_ms=self._step_ms if self._step_ms else self._duration_ms, collect=collect, delete=True, + advance_last_expired_timestamp=advance_last_expired_timestamp, ): yield key, self._results(aggregated, collected, window_start, window_end) diff --git a/quixstreams/models/messagecontext.py b/quixstreams/models/messagecontext.py index 351fe9157..a41c011bf 100644 --- a/quixstreams/models/messagecontext.py +++ b/quixstreams/models/messagecontext.py @@ -16,6 +16,7 @@ class MessageContext: "_size", "_headers", "_leader_epoch", + "_heartbeat", ) def __init__( @@ -25,12 +26,14 @@ def __init__( offset: int, size: int, leader_epoch: Optional[int] = None, + heartbeat: bool = False, ): self._topic = topic self._partition = partition self._offset = offset self._size = size self._leader_epoch = leader_epoch + self._heartbeat = heartbeat @property def topic(self) -> str: @@ -51,3 +54,7 @@ def size(self) -> int: @property def leader_epoch(self) -> Optional[int]: return self._leader_epoch + + @property + def heartbeat(self) -> bool: + return self._heartbeat diff --git a/quixstreams/state/rocksdb/windowed/transaction.py b/quixstreams/state/rocksdb/windowed/transaction.py index 3779b3e29..dfa3fa12a 100644 --- a/quixstreams/state/rocksdb/windowed/transaction.py +++ b/quixstreams/state/rocksdb/windowed/transaction.py @@ -298,6 +298,7 @@ def expire_all_windows( step_ms: int, delete: bool = True, collect: bool = False, + advance_last_expired_timestamp: bool = True, ) -> Iterable[ExpiredWindowDetail]: """ Get all expired windows for all prefix from RocksDB up to the specified `max_end_time` timestamp. @@ -360,9 +361,12 @@ def expire_all_windows( if collect: self.delete_from_collection(end=start, prefix=prefix) - self._set_timestamp( - prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired - ) + if advance_last_expired_timestamp: + self._set_timestamp( + prefix=b"", + cache=self._last_expired_timestamps, + timestamp_ms=last_expired, + ) def delete_windows( self, max_start_time: int, delete_values: bool, prefix: bytes diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py index c80c9e2ad..9f07b9499 100644 --- a/quixstreams/state/types.py +++ b/quixstreams/state/types.py @@ -378,6 +378,7 @@ def expire_all_windows( step_ms: int, delete: bool = True, collect: bool = False, + advance_last_expired_timestamp: bool = True, ) -> Iterable[ExpiredWindowDetail[V]]: """ Get all expired windows for all prefix from RocksDB up to the specified `max_start_time` timestamp. @@ -388,6 +389,7 @@ def expire_all_windows( :param max_end_time: The timestamp up to which windows are considered expired, inclusive. :param delete: If True, expired windows will be deleted. :param collect: If True, values will be collected into windows. + :param advance_last_expired_timestamp: If True, the last expired timestamp will be persisted. """ ...