19
19
from quixstreams .context import message_context
20
20
from quixstreams .core .stream import TransformExpandedCallback
21
21
from quixstreams .core .stream .exceptions import InvalidOperation
22
- from quixstreams .core .stream .functions .heartbeat import is_heartbeat_message
23
22
from quixstreams .models .topics .manager import TopicManager
24
23
from quixstreams .state import WindowedPartitionTransaction
25
24
@@ -93,13 +92,19 @@ def _apply_window(
93
92
stream_id = self ._dataframe .stream_id ,
94
93
processing_context = self ._dataframe .processing_context ,
95
94
store_name = name ,
96
- heartbeat_func = heartbeat_func ,
95
+ )
96
+ heartbeat_func = _as_heartbeat (
97
+ func = heartbeat_func ,
98
+ stream_id = self ._dataframe .stream_id ,
99
+ processing_context = self ._dataframe .processing_context ,
100
+ store_name = name ,
97
101
)
98
102
# Manually modify the Stream and clone the source StreamingDataFrame
99
103
# to avoid adding "transform" API to it.
100
104
# Transform callbacks can modify record key and timestamp,
101
105
# and it's prone to misuse.
102
106
stream = self ._dataframe .stream .add_transform (func = windowed_func , expand = True )
107
+ stream = stream .add_heartbeat (func = heartbeat_func )
103
108
return self ._dataframe .__dataframe_clone__ (stream = stream )
104
109
105
110
def final (self ) -> "StreamingDataFrame" :
@@ -417,7 +422,6 @@ def _as_windowed(
417
422
processing_context : "ProcessingContext" ,
418
423
store_name : str ,
419
424
stream_id : str ,
420
- heartbeat_func ,
421
425
) -> TransformExpandedCallback :
422
426
@functools .wraps (func )
423
427
def wrapper (
@@ -430,9 +434,6 @@ def wrapper(
430
434
stream_id = stream_id , partition = ctx .partition , store_name = store_name
431
435
),
432
436
)
433
- if is_heartbeat_message (key , value ):
434
- return heartbeat_func (timestamp , transaction )
435
-
436
437
if key is None :
437
438
logger .warning (
438
439
f"Skipping window processing for a message because the key is None, "
@@ -444,6 +445,26 @@ def wrapper(
444
445
return wrapper
445
446
446
447
448
+ def _as_heartbeat (
449
+ func , # TODO: typing?
450
+ processing_context : "ProcessingContext" ,
451
+ store_name : str ,
452
+ stream_id : str ,
453
+ ): # TODO: typing?
454
+ @functools .wraps (func )
455
+ def wrapper (timestamp : int ) -> Iterable [Message ]:
456
+ ctx = message_context ()
457
+ transaction = cast (
458
+ WindowedPartitionTransaction ,
459
+ processing_context .checkpoint .get_store_transaction (
460
+ stream_id = stream_id , partition = ctx .partition , store_name = store_name
461
+ ),
462
+ )
463
+ return func (timestamp , transaction )
464
+
465
+ return wrapper
466
+
467
+
447
468
class WindowOnLateCallback (Protocol ):
448
469
def __call__ (
449
470
self ,
0 commit comments