diff --git a/quixstreams/app.py b/quixstreams/app.py
index 58c69e284..3fdbf1c49 100644
--- a/quixstreams/app.py
+++ b/quixstreams/app.py
@@ -5,6 +5,7 @@
import time
import warnings
from collections import defaultdict
+from datetime import datetime
from pathlib import Path
from typing import Callable, List, Literal, Optional, Protocol, Tuple, Type, Union
@@ -28,6 +29,8 @@
from .logging import LogLevel, configure_logging
from .models import (
DeserializerType,
+ MessageContext,
+ Row,
SerializerType,
TimestampExtractor,
Topic,
@@ -151,6 +154,7 @@ def __init__(
topic_create_timeout: float = 60,
processing_guarantee: ProcessingGuarantee = "at-least-once",
max_partition_buffer_size: int = 10000,
+ heartbeat_interval: float = 0.0,
):
"""
:param broker_address: Connection settings for Kafka.
@@ -220,6 +224,11 @@ def __init__(
It is a soft limit, and the actual number of buffered messages can be up to x2 higher.
Lower value decreases the memory use, but increases the latency.
Default - `10000`.
+ :param heartbeat_interval: the interval (in seconds) at which heartbeat messages are sent.
+ The heartbeat timer starts when the `Application` instance is created.
+ A heartbeat is sent for every assigned partition of every non-changelog topic.
+ If the value is 0 or less, no heartbeat messages will be sent.
+ Default - `0.0`.
***Error Handlers***
To handle errors, `Application` accepts callbacks triggered when
@@ -363,6 +372,10 @@ def __init__(
recovery_manager=recovery_manager,
)
+ self._heartbeat_active = heartbeat_interval > 0
+ self._heartbeat_interval = heartbeat_interval
+ self._heartbeat_last_sent = datetime.now().timestamp()
+
self._source_manager = SourceManager()
self._sink_manager = SinkManager()
self._dataframe_registry = DataFrameRegistry()
@@ -879,6 +892,7 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None):
processing_context = self._processing_context
source_manager = self._source_manager
process_message = self._process_message
+ process_heartbeat = self._process_heartbeat
printer = self._processing_context.printer
run_tracker = self._run_tracker
consumer = self._consumer
@@ -902,6 +916,7 @@ def _run_dataframe(self, sink: Optional[VoidExecutor] = None):
run_tracker.timeout_refresh()
else:
process_message(dataframes_composed)
+ process_heartbeat(dataframes_composed)
processing_context.commit_checkpoint()
consumer.resume_backpressured()
source_manager.raise_for_error()
@@ -985,6 +1000,49 @@ def _process_message(self, dataframe_composed):
if self._on_message_processed is not None:
self._on_message_processed(topic_name, partition, offset)
+ def _process_heartbeat(self, dataframe_composed):
+     """
+     Push a heartbeat record through the composed dataframes once the
+     configured heartbeat interval has elapsed.
+
+     Does nothing when heartbeats are disabled or when less than
+     `heartbeat_interval` seconds have passed since the previous heartbeat.
+     Otherwise, for every assigned partition whose topic is one of the topic
+     manager's non-changelog topics, a record with `None` key and value and
+     the current time (in milliseconds) as its timestamp is executed inside
+     a copied context whose message context is flagged with `heartbeat=True`.
+
+     Exceptions raised during execution are passed to the processing error
+     callback and re-raised unless the callback asks to suppress them.
+
+     :param dataframe_composed: composed dataframe executors, looked up by topic name
+     """
+     if not self._heartbeat_active:
+         return
+
+     current_time = datetime.now().timestamp()
+     if current_time - self._heartbeat_interval < self._heartbeat_last_sent:
+         # The interval has not elapsed yet
+         return
+
+     # The same payload is used for every partition
+     heartbeat_value = None
+     heartbeat_key = None
+     heartbeat_timestamp = int(current_time * 1000)
+     heartbeat_headers = {}
+     eligible_topics = self._topic_manager.non_changelog_topics
+
+     for topic_partition in self._consumer.assignment():
+         topic = topic_partition.topic
+         if topic not in eligible_topics:
+             continue
+
+         # Heartbeats do not originate from Kafka, so offset and size are -1
+         row = Row(
+             value=heartbeat_value,
+             key=heartbeat_key,
+             timestamp=heartbeat_timestamp,
+             context=MessageContext(
+                 topic=topic,
+                 partition=topic_partition.partition,
+                 offset=-1,
+                 size=-1,
+                 heartbeat=True,
+             ),
+             headers=heartbeat_headers,
+         )
+         execution_context = copy_context()
+         execution_context.run(set_message_context, row.context)
+         try:
+             execution_context.run(
+                 dataframe_composed[topic],
+                 heartbeat_value,
+                 heartbeat_key,
+                 heartbeat_timestamp,
+                 heartbeat_headers,
+             )
+         except Exception as exc:
+             if not self._on_processing_error(exc, row, logger):
+                 raise
+
+     self._heartbeat_last_sent = current_time
+
def _on_assign(self, _, topic_partitions: List[TopicPartition]):
"""
Assign new topic partitions to consumer and state.
diff --git a/quixstreams/core/stream/functions/__init__.py b/quixstreams/core/stream/functions/__init__.py
index 1561500e3..1be9fd6a0 100644
--- a/quixstreams/core/stream/functions/__init__.py
+++ b/quixstreams/core/stream/functions/__init__.py
@@ -2,6 +2,7 @@
from .apply import *
from .base import *
from .filter import *
+from .heartbeat import *
from .transform import *
from .types import *
from .update import *
diff --git a/quixstreams/core/stream/functions/apply.py b/quixstreams/core/stream/functions/apply.py
index bdf493953..24f6fd92d 100644
--- a/quixstreams/core/stream/functions/apply.py
+++ b/quixstreams/core/stream/functions/apply.py
@@ -1,6 +1,7 @@
from typing import Any, Literal, Union, overload
from .base import StreamFunction
+from .heartbeat import is_heartbeat_message
from .types import (
ApplyCallback,
ApplyExpandedCallback,
@@ -48,6 +49,10 @@ def wrapper(
timestamp: int,
headers: Any,
) -> None:
+ # Pass heartbeat messages downstream untouched and stop here:
+ # they carry no payload, so the user callback must not be
+ # invoked with a `None` value.
+ if is_heartbeat_message(key, value):
+     child_executor(value, key, timestamp, headers)
+     return
+
# Execute a function on a single value and wrap results into a list
# to expand them downstream
result = func(value)
@@ -62,9 +67,11 @@ def wrapper(
timestamp: int,
headers: Any,
) -> None:
- # Execute a function on a single value and return its result
- result = func(value)
- child_executor(result, key, timestamp, headers)
+ # Pass heartbeat messages downstream or execute
+ # a function on a single value and return its result
+ if not is_heartbeat_message(key, value):
+ value = func(value)
+ child_executor(value, key, timestamp, headers)
return wrapper
@@ -110,6 +117,10 @@ def wrapper(
timestamp: int,
headers: Any,
):
+ # Pass heartbeat messages downstream untouched and stop here:
+ # they carry no payload, so the user callback must not be
+ # invoked with a `None` value.
+ if is_heartbeat_message(key, value):
+     child_executor(value, key, timestamp, headers)
+     return
+
# Execute a function on a single value and wrap results into a list
# to expand them downstream
result = func(value, key, timestamp, headers)
@@ -124,8 +135,10 @@ def wrapper(
timestamp: int,
headers: Any,
):
- # Execute a function on a single value and return its result
- result = func(value, key, timestamp, headers)
- child_executor(result, key, timestamp, headers)
+ # Pass heartbeat messages downstream or execute
+ # a function on a single value and return its result
+ if not is_heartbeat_message(key, value):
+ value = func(value, key, timestamp, headers)
+ child_executor(value, key, timestamp, headers)
return wrapper
diff --git a/quixstreams/core/stream/functions/filter.py b/quixstreams/core/stream/functions/filter.py
index e291880c7..ed1bf6038 100644
--- a/quixstreams/core/stream/functions/filter.py
+++ b/quixstreams/core/stream/functions/filter.py
@@ -1,5 +1,7 @@
from typing import Any
+from quixstreams.core.stream.functions.heartbeat import is_heartbeat_message
+
from .base import StreamFunction
from .types import FilterCallback, FilterWithMetadataCallback, VoidExecutor
@@ -30,7 +32,7 @@ def wrapper(
headers: Any,
):
# Filter a single value
- if func(value):
+ if is_heartbeat_message(key, value) or func(value):
child_executor(value, key, timestamp, headers)
return wrapper
@@ -62,7 +64,7 @@ def wrapper(
headers: Any,
):
# Filter a single value
- if func(value, key, timestamp, headers):
+ if is_heartbeat_message(key, value) or func(value, key, timestamp, headers):
child_executor(value, key, timestamp, headers)
return wrapper
diff --git a/quixstreams/core/stream/functions/heartbeat.py b/quixstreams/core/stream/functions/heartbeat.py
new file mode 100644
index 000000000..589745aa2
--- /dev/null
+++ b/quixstreams/core/stream/functions/heartbeat.py
@@ -0,0 +1,36 @@
+from typing import Any
+
+from quixstreams.context import message_context
+
+from .base import StreamFunction
+from .types import HeartbeatCallback, VoidExecutor
+
+__all__ = ("HeartbeatFunction", "is_heartbeat_message")
+
+
+class HeartbeatFunction(StreamFunction):
+    """
+    Stream function that reacts to heartbeat messages.
+
+    On a heartbeat message the callback is called with the heartbeat timestamp,
+    and every `(value, key, timestamp, headers)` tuple it returns is sent
+    downstream. The incoming record, heartbeat or not, is always forwarded
+    downstream afterwards.
+    """
+
+    def __init__(self, func: HeartbeatCallback) -> None:
+        super().__init__(func)
+        self.func: HeartbeatCallback
+
+    def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
+        """
+        Build the executor wrapping the heartbeat callback.
+
+        :param child_executors: downstream executors to resolve and call
+        :return: an executor accepting `(value, key, timestamp, headers)`
+        """
+        downstream = self._resolve_branching(*child_executors)
+
+        on_heartbeat = self.func
+
+        def wrapper(
+            value: Any,
+            key: Any,
+            timestamp: int,
+            headers: Any,
+        ):
+            if is_heartbeat_message(key, value):
+                produced = on_heartbeat(timestamp)
+                # The callback may return None or an empty result
+                if produced:
+                    for new_value, new_key, new_timestamp, new_headers in produced:
+                        downstream(new_value, new_key, new_timestamp, new_headers)
+            # Always forward the incoming record itself
+            downstream(value, key, timestamp, headers)
+
+        return wrapper
+
+
+def is_heartbeat_message(key: Any, value: Any) -> bool:
+    """
+    Tell whether the record being processed is a heartbeat message.
+
+    A record is a heartbeat when both its key and value are `None` and the
+    current message context has its `heartbeat` flag set.
+
+    :param key: record key
+    :param value: record value
+    :return: `True` for heartbeat messages, `False` otherwise
+    """
+    # Run the cheap identity checks first: ordinary records short-circuit
+    # here, so the message context is looked up only for empty records.
+    return key is None and value is None and message_context().heartbeat
diff --git a/quixstreams/core/stream/functions/transform.py b/quixstreams/core/stream/functions/transform.py
index 219662b6b..2ca23718a 100644
--- a/quixstreams/core/stream/functions/transform.py
+++ b/quixstreams/core/stream/functions/transform.py
@@ -1,5 +1,7 @@
from typing import Any, Literal, Union, cast, overload
+from quixstreams.core.stream.functions.heartbeat import is_heartbeat_message
+
from .base import StreamFunction
from .types import TransformCallback, TransformExpandedCallback, VoidExecutor
@@ -53,9 +55,13 @@ def wrapper(
timestamp: int,
headers: Any,
):
- result = expanded_func(value, key, timestamp, headers)
- for new_value, new_key, new_timestamp, new_headers in result:
- child_executor(new_value, new_key, new_timestamp, new_headers)
+ # Pass heartbeat messages downstream
+ if is_heartbeat_message(key, value):
+ child_executor(value, key, timestamp, headers)
+ else:
+ result = expanded_func(value, key, timestamp, headers)
+ for new_value, new_key, new_timestamp, new_headers in result:
+ child_executor(new_value, new_key, new_timestamp, new_headers)
else:
func = cast(TransformCallback, self.func)
@@ -66,10 +72,13 @@ def wrapper(
timestamp: int,
headers: Any,
):
- # Execute a function on a single value and return its result
- new_value, new_key, new_timestamp, new_headers = func(
- value, key, timestamp, headers
- )
- child_executor(new_value, new_key, new_timestamp, new_headers)
+ # Pass heartbeat messages downstream or execute
+ # a function on a single value and return its result
+ if not is_heartbeat_message(key, value):
+ value, key, timestamp, headers = func(
+ value, key, timestamp, headers
+ )
+
+ child_executor(value, key, timestamp, headers)
return wrapper
diff --git a/quixstreams/core/stream/functions/types.py b/quixstreams/core/stream/functions/types.py
index 504299b53..eceb8fdb3 100644
--- a/quixstreams/core/stream/functions/types.py
+++ b/quixstreams/core/stream/functions/types.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Iterable, Protocol, Tuple, Union
+from typing import Any, Callable, Iterable, Optional, Protocol, Tuple, Union
__all__ = (
"StreamCallback",
@@ -36,6 +36,8 @@ def __bool__(self) -> bool: ...
[Any, Any, int, Any], Iterable[Tuple[Any, Any, int, Any]]
]
+HeartbeatCallback = Callable[[int], Optional[Iterable[Tuple[Any, Any, int, Any]]]]
+
StreamCallback = Union[
ApplyCallback,
ApplyExpandedCallback,
diff --git a/quixstreams/core/stream/functions/update.py b/quixstreams/core/stream/functions/update.py
index b2d9a19bc..18b7928aa 100644
--- a/quixstreams/core/stream/functions/update.py
+++ b/quixstreams/core/stream/functions/update.py
@@ -1,5 +1,7 @@
from typing import Any
+from quixstreams.core.stream.functions.heartbeat import is_heartbeat_message
+
from .base import StreamFunction
from .types import UpdateCallback, UpdateWithMetadataCallback, VoidExecutor
@@ -28,7 +30,8 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
def wrapper(value: Any, key: Any, timestamp: int, headers: Any):
# Update a single value and forward it
- func(value)
+ if not is_heartbeat_message(key, value):
+ func(value)
child_executor(value, key, timestamp, headers)
return wrapper
@@ -56,7 +59,8 @@ def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor:
def wrapper(value: Any, key: Any, timestamp: int, headers: Any):
# Update a single value and forward it
- func(value, key, timestamp, headers)
+ if not is_heartbeat_message(key, value):
+ func(value, key, timestamp, headers)
child_executor(value, key, timestamp, headers)
return wrapper
diff --git a/quixstreams/core/stream/stream.py b/quixstreams/core/stream/stream.py
index f538f5307..115d66235 100644
--- a/quixstreams/core/stream/stream.py
+++ b/quixstreams/core/stream/stream.py
@@ -25,6 +25,7 @@
FilterFunction,
FilterWithMetadataCallback,
FilterWithMetadataFunction,
+ HeartbeatFunction,
ReturningExecutor,
StreamFunction,
TransformCallback,
@@ -280,6 +281,10 @@ def add_transform(
"""
return self._add(TransformFunction(func, expand=expand)) # type: ignore[call-overload]
+ def add_heartbeat(self, func) -> "Stream":
+     """
+     Add a heartbeat callback to the Stream.
+
+     The callback is wrapped into a `HeartbeatFunction` and added to the
+     Stream the same way as the other stream functions.
+
+     :param func: callback passed to `HeartbeatFunction`
+     :return: the Stream returned by `_add()`
+     """
+     return self._add(HeartbeatFunction(func))
+
def merge(self, other: "Stream") -> "Stream":
"""
Merge two Streams together and return a new Stream with two parents
diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py
index 7e9b09844..fe85f11ab 100644
--- a/quixstreams/dataframe/dataframe.py
+++ b/quixstreams/dataframe/dataframe.py
@@ -191,6 +191,10 @@ def stream_id(self) -> str:
def topics(self) -> tuple[Topic, ...]:
return self._topics
+ def heartbeat(self, func) -> "StreamingDataFrame":
+     """
+     Register a heartbeat callback on the underlying stream.
+
+     :param func: callback passed to `Stream.add_heartbeat()`
+     :return: a clone of this StreamingDataFrame built on the updated stream
+     """
+     return self.__dataframe_clone__(stream=self.stream.add_heartbeat(func))
+
@overload
def apply(
self,
diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py
index 8040b2774..4f326306d 100644
--- a/quixstreams/dataframe/windows/base.py
+++ b/quixstreams/dataframe/windows/base.py
@@ -69,6 +69,14 @@ def process_window(
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
pass
+ @abstractmethod
+ def process_heartbeat(
+ self,
+ timestamp_ms: int,
+ transaction: WindowedPartitionTransaction,
+ ) -> Iterable[WindowKeyResult]:
+ """
+ Handle a heartbeat for this window definition.
+
+ :param timestamp_ms: heartbeat timestamp in milliseconds
+ :param transaction: windowed store transaction to operate on
+ :return: `(key, window result)` pairs to emit in response to the heartbeat
+ """
+ pass
+
def register_store(self) -> None:
TopicManager.ensure_topics_copartitioned(*self._dataframe.topics)
# Create a config for the changelog topic based on the underlying SDF topics
@@ -83,6 +91,7 @@ def _apply_window(
self,
func: TransformRecordCallbackExpandedWindowed,
name: str,
+ heartbeat_func,
) -> "StreamingDataFrame":
self.register_store()
@@ -92,11 +101,18 @@ def _apply_window(
processing_context=self._dataframe.processing_context,
store_name=name,
)
+ heartbeat_func = _as_heartbeat(
+ func=heartbeat_func,
+ stream_id=self._dataframe.stream_id,
+ processing_context=self._dataframe.processing_context,
+ store_name=name,
+ )
# Manually modify the Stream and clone the source StreamingDataFrame
# to avoid adding "transform" API to it.
# Transform callbacks can modify record key and timestamp,
# and it's prone to misuse.
stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True)
+ stream = stream.add_heartbeat(func=heartbeat_func)
return self._dataframe.__dataframe_clone__(stream=stream)
def final(self) -> "StreamingDataFrame":
@@ -140,9 +156,17 @@ def window_callback(
for key, window in expired_windows:
yield (window, key, window["start"], None)
+ def heartbeat_callback(
+ timestamp: int, transaction: WindowedPartitionTransaction
+ ) -> Iterable[Message]:
+ # TODO: Check if this will work for sliding windows
+ for key, window in self.process_heartbeat(timestamp, transaction):
+ yield (window, key, window["start"], None)
+
return self._apply_window(
func=window_callback,
name=self._name,
+ heartbeat_func=heartbeat_callback,
)
def current(self) -> "StreamingDataFrame":
@@ -188,7 +212,17 @@ def window_callback(
for key, window in updated_windows:
yield (window, key, window["start"], None)
- return self._apply_window(func=window_callback, name=self._name)
+ def heartbeat_callback(
+ timestamp: int, transaction: WindowedPartitionTransaction
+ ) -> Iterable[Message]:
+ # TODO: Implement heartbeat callback
+ return []
+
+ return self._apply_window(
+ func=window_callback,
+ name=self._name,
+ heartbeat_func=heartbeat_callback,
+ )
# Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin
# Single aggregation and multi aggregation windows store aggregations and collections
@@ -424,6 +458,26 @@ def wrapper(
return wrapper
+def _as_heartbeat(
+    func,  # TODO: typing?
+    processing_context: "ProcessingContext",
+    store_name: str,
+    stream_id: str,
+):  # TODO: typing?
+    """
+    Adapt a windowed heartbeat callback to the single-argument form
+    that accepts only the heartbeat timestamp.
+
+    On each call the returned wrapper takes the partition from the current
+    message context, gets the matching store transaction from the checkpoint
+    and passes it to `func` together with the timestamp.
+
+    :param func: callback accepting `(timestamp, transaction)`
+    :param processing_context: processing context providing the checkpoint
+    :param store_name: name of the window store
+    :param stream_id: stream id used to look up the store transaction
+    :return: a callable accepting the heartbeat timestamp
+    """
+
+    @functools.wraps(func)
+    def wrapper(timestamp: int) -> Iterable[Message]:
+        current_context = message_context()
+        store_transaction = processing_context.checkpoint.get_store_transaction(
+            stream_id=stream_id,
+            partition=current_context.partition,
+            store_name=store_name,
+        )
+        return func(
+            timestamp, cast(WindowedPartitionTransaction, store_transaction)
+        )
+
+    return wrapper
+
+
class WindowOnLateCallback(Protocol):
def __call__(
self,
diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py
index 57c6b36e5..649ee0ee0 100644
--- a/quixstreams/dataframe/windows/count_based.py
+++ b/quixstreams/dataframe/windows/count_based.py
@@ -189,6 +189,14 @@ def process_window(
state.set(key=self.STATE_KEY, value=data)
return updated_windows, expired_windows
+ def process_heartbeat(
+ self,
+ timestamp_ms: int,
+ transaction: WindowedPartitionTransaction,
+ ) -> Iterable[WindowKeyResult]:
+ """
+ Handle a heartbeat for a count-based window.
+
+ Neither argument is used.
+
+ :param timestamp_ms: heartbeat timestamp in milliseconds
+ :param transaction: windowed store transaction
+ :return: always an empty list
+ """
+ # Count based windows cannot be expired by heartbeat
+ return []
+
def _get_collection_start_id(self, window: CountWindowData) -> int:
start_id = window.get("collection_start_id", _MISSING)
if start_id is _MISSING:
diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py
index c403cfdfa..171dae6d6 100644
--- a/quixstreams/dataframe/windows/time_based.py
+++ b/quixstreams/dataframe/windows/time_based.py
@@ -200,11 +200,27 @@ def process_window(
return updated_windows, expired_windows
+ def process_heartbeat(
+     self,
+     timestamp_ms: int,
+     transaction: WindowedPartitionTransaction,
+ ) -> Iterable[WindowKeyResult]:
+     """
+     Expire windows of the whole partition in response to a heartbeat.
+
+     The expiration threshold is the later of the heartbeat timestamp and
+     the latest expired timestamp stored in the transaction, minus the
+     grace period. The stored latest expired timestamp is not advanced.
+
+     :param timestamp_ms: heartbeat timestamp in milliseconds
+     :param transaction: windowed store transaction to expire windows from
+     :return: the result of `expire_by_partition()` for the threshold
+     """
+     watermark = max(timestamp_ms, transaction.get_latest_expired(prefix=b""))
+     expire_up_to = watermark - self._grace_ms
+     return self.expire_by_partition(
+         transaction,
+         expire_up_to,
+         self.collect,
+         advance_last_expired_timestamp=False,
+     )
+
def expire_by_partition(
self,
transaction: WindowedPartitionTransaction,
max_expired_end: int,
collect: bool,
+ advance_last_expired_timestamp: bool = True,
) -> Iterable[WindowKeyResult]:
for (
window_start,
@@ -214,6 +230,7 @@ def expire_by_partition(
step_ms=self._step_ms if self._step_ms else self._duration_ms,
collect=collect,
delete=True,
+ advance_last_expired_timestamp=advance_last_expired_timestamp,
):
yield key, self._results(aggregated, collected, window_start, window_end)
diff --git a/quixstreams/models/messagecontext.py b/quixstreams/models/messagecontext.py
index 351fe9157..a41c011bf 100644
--- a/quixstreams/models/messagecontext.py
+++ b/quixstreams/models/messagecontext.py
@@ -16,6 +16,7 @@ class MessageContext:
"_size",
"_headers",
"_leader_epoch",
+ "_heartbeat",
)
def __init__(
@@ -25,12 +26,14 @@ def __init__(
offset: int,
size: int,
leader_epoch: Optional[int] = None,
+ heartbeat: bool = False,
):
self._topic = topic
self._partition = partition
self._offset = offset
self._size = size
self._leader_epoch = leader_epoch
+ self._heartbeat = heartbeat
@property
def topic(self) -> str:
@@ -51,3 +54,7 @@ def size(self) -> int:
@property
def leader_epoch(self) -> Optional[int]:
return self._leader_epoch
+
+ @property
+ def heartbeat(self) -> bool:
+ """
+ The `heartbeat` flag this context was created with (`False` by default).
+ """
+ return self._heartbeat
diff --git a/quixstreams/state/rocksdb/windowed/transaction.py b/quixstreams/state/rocksdb/windowed/transaction.py
index 3779b3e29..dfa3fa12a 100644
--- a/quixstreams/state/rocksdb/windowed/transaction.py
+++ b/quixstreams/state/rocksdb/windowed/transaction.py
@@ -298,6 +298,7 @@ def expire_all_windows(
step_ms: int,
delete: bool = True,
collect: bool = False,
+ advance_last_expired_timestamp: bool = True,
) -> Iterable[ExpiredWindowDetail]:
"""
Get all expired windows for all prefix from RocksDB up to the specified `max_end_time` timestamp.
@@ -360,9 +361,12 @@ def expire_all_windows(
if collect:
self.delete_from_collection(end=start, prefix=prefix)
- self._set_timestamp(
- prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired
- )
+ if advance_last_expired_timestamp:
+ self._set_timestamp(
+ prefix=b"",
+ cache=self._last_expired_timestamps,
+ timestamp_ms=last_expired,
+ )
def delete_windows(
self, max_start_time: int, delete_values: bool, prefix: bytes
diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py
index c80c9e2ad..9f07b9499 100644
--- a/quixstreams/state/types.py
+++ b/quixstreams/state/types.py
@@ -378,6 +378,7 @@ def expire_all_windows(
step_ms: int,
delete: bool = True,
collect: bool = False,
+ advance_last_expired_timestamp: bool = True,
) -> Iterable[ExpiredWindowDetail[V]]:
"""
Get all expired windows for all prefix from RocksDB up to the specified `max_start_time` timestamp.
@@ -388,6 +389,7 @@ def expire_all_windows(
:param max_end_time: The timestamp up to which windows are considered expired, inclusive.
:param delete: If True, expired windows will be deleted.
:param collect: If True, values will be collected into windows.
+ :param advance_last_expired_timestamp: If True, the last expired timestamp will be persisted.
"""
...