Skip to content

Commit 8cd4f82

Browse files
committed
Correct heartbeat handling in windowing
1 parent 9b40b99 commit 8cd4f82

File tree

1 file changed

+27
-6
lines changed
  • quixstreams/dataframe/windows

1 file changed

+27
-6
lines changed

quixstreams/dataframe/windows/base.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from quixstreams.context import message_context
2020
from quixstreams.core.stream import TransformExpandedCallback
2121
from quixstreams.core.stream.exceptions import InvalidOperation
22-
from quixstreams.core.stream.functions.heartbeat import is_heartbeat_message
2322
from quixstreams.models.topics.manager import TopicManager
2423
from quixstreams.state import WindowedPartitionTransaction
2524

@@ -93,13 +92,19 @@ def _apply_window(
9392
stream_id=self._dataframe.stream_id,
9493
processing_context=self._dataframe.processing_context,
9594
store_name=name,
96-
heartbeat_func=heartbeat_func,
95+
)
96+
heartbeat_func = _as_heartbeat(
97+
func=heartbeat_func,
98+
stream_id=self._dataframe.stream_id,
99+
processing_context=self._dataframe.processing_context,
100+
store_name=name,
97101
)
98102
# Manually modify the Stream and clone the source StreamingDataFrame
99103
# to avoid adding "transform" API to it.
100104
# Transform callbacks can modify record key and timestamp,
101105
# and it's prone to misuse.
102106
stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True)
107+
stream = stream.add_heartbeat(func=heartbeat_func)
103108
return self._dataframe.__dataframe_clone__(stream=stream)
104109

105110
def final(self) -> "StreamingDataFrame":
@@ -417,7 +422,6 @@ def _as_windowed(
417422
processing_context: "ProcessingContext",
418423
store_name: str,
419424
stream_id: str,
420-
heartbeat_func,
421425
) -> TransformExpandedCallback:
422426
@functools.wraps(func)
423427
def wrapper(
@@ -430,9 +434,6 @@ def wrapper(
430434
stream_id=stream_id, partition=ctx.partition, store_name=store_name
431435
),
432436
)
433-
if is_heartbeat_message(key, value):
434-
return heartbeat_func(timestamp, transaction)
435-
436437
if key is None:
437438
logger.warning(
438439
f"Skipping window processing for a message because the key is None, "
@@ -444,6 +445,26 @@ def wrapper(
444445
return wrapper
445446

446447

448+
def _as_heartbeat(
449+
func, # TODO: typing?
450+
processing_context: "ProcessingContext",
451+
store_name: str,
452+
stream_id: str,
453+
): # TODO: typing?
454+
@functools.wraps(func)
455+
def wrapper(timestamp: int) -> Iterable[Message]:
456+
ctx = message_context()
457+
transaction = cast(
458+
WindowedPartitionTransaction,
459+
processing_context.checkpoint.get_store_transaction(
460+
stream_id=stream_id, partition=ctx.partition, store_name=store_name
461+
),
462+
)
463+
return func(timestamp, transaction)
464+
465+
return wrapper
466+
467+
447468
class WindowOnLateCallback(Protocol):
448469
def __call__(
449470
self,

0 commit comments

Comments
 (0)