From b56627d3e3b210389aff61a8ab7e1bb475b11bdf Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 11 Jan 2026 10:27:20 -0500 Subject: [PATCH 1/3] Use an explicit wait in a dataframe query during testing to check for keyboard interrupts --- python/tests/test_dataframe.py | 104 +++++---------------------------- 1 file changed, 15 insertions(+), 89 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index f481f31f6..882c854d1 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -37,6 +37,7 @@ WindowFrame, column, literal, + udf, ) from datafusion import ( col as df_col, @@ -3190,6 +3191,13 @@ def test_fill_null_all_null_column(ctx): assert result.column(1).to_pylist() == ["filled", "filled", "filled"] +@udf([pa.int64()], pa.int64(), "immutable") +def slow_udf(x: pa.Array) -> pa.Array: + # This must be longer than the check interval in wait_for_future + time.sleep(2.0) + return x + + def test_collect_interrupted(): """Test that a long-running query can be interrupted with Ctrl-C. 
@@ -3198,50 +3206,7 @@ def test_collect_interrupted(): """ # Create a context and a DataFrame with a query that will run for a while ctx = SessionContext() - - # Create a recursive computation that will run for some time - batches = [] - for i in range(10): - batch = pa.RecordBatch.from_arrays( - [ - pa.array(list(range(i * 1000, (i + 1) * 1000))), - pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]), - ], - names=["a", "b"], - ) - batches.append(batch) - - # Register tables - ctx.register_record_batches("t1", [batches]) - ctx.register_record_batches("t2", [batches]) - - # Create a large join operation that will take time to process - df = ctx.sql(""" - WITH t1_expanded AS ( - SELECT - a, - b, - CAST(a AS DOUBLE) / 1.5 AS c, - CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d - FROM t1 - CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5) - ), - t2_expanded AS ( - SELECT - a, - b, - CAST(a AS DOUBLE) * 2.5 AS e, - CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f - FROM t2 - CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5) - ) - SELECT - t1.a, t1.b, t1.c, t1.d, - t2.a AS a2, t2.b AS b2, t2.e, t2.f - FROM t1_expanded t1 - JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100 - WHERE t1.a > 100 AND t2.a > 100 - """) + df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a"))) # Flag to track if the query was interrupted interrupted = False @@ -3298,7 +3263,10 @@ def trigger_interrupt(): except KeyboardInterrupt: interrupted = True except Exception as e: - interrupt_error = e + if "KeyboardInterrupt" in str(e): + interrupted = True + else: + interrupt_error = e # Assert that the query was interrupted properly if not interrupted: @@ -3308,7 +3276,7 @@ def trigger_interrupt(): interrupt_thread.join(timeout=1.0) -def test_arrow_c_stream_interrupted(): # noqa: C901 PLR0915 +def test_arrow_c_stream_interrupted(): # noqa: C901 """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals. 
Similar to ``test_collect_interrupted`` this test issues a long running @@ -3318,49 +3286,7 @@ def test_arrow_c_stream_interrupted(): # noqa: C901 PLR0915 """ ctx = SessionContext() - - batches = [] - for i in range(10): - batch = pa.RecordBatch.from_arrays( - [ - pa.array(list(range(i * 1000, (i + 1) * 1000))), - pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]), - ], - names=["a", "b"], - ) - batches.append(batch) - - ctx.register_record_batches("t1", [batches]) - ctx.register_record_batches("t2", [batches]) - - df = ctx.sql( - """ - WITH t1_expanded AS ( - SELECT - a, - b, - CAST(a AS DOUBLE) / 1.5 AS c, - CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d - FROM t1 - CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5) - ), - t2_expanded AS ( - SELECT - a, - b, - CAST(a AS DOUBLE) * 2.5 AS e, - CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f - FROM t2 - CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5) - ) - SELECT - t1.a, t1.b, t1.c, t1.d, - t2.a AS a2, t2.b AS b2, t2.e, t2.f - FROM t1_expanded t1 - JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100 - WHERE t1.a > 100 AND t2.a > 100 - """ - ) + df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a"))) reader = pa.RecordBatchReader.from_stream(df) From 121d4246a6ef21016600ece940e443ac48e6e727 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 11 Jan 2026 14:20:50 -0500 Subject: [PATCH 2/3] Add interrupt check when spawning futures --- src/utils.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/utils.rs b/src/utils.rs index cbc3d6d9b..6038c77b1 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -77,6 +77,11 @@ where let runtime: &Runtime = &get_tokio_runtime().0; const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000); + // Some fast running processes that generate many `wait_for_future` calls like + // PartitionedDataFrameStreamReader::next require checking for interrupts early + py.run(cr"pass", None, None)?; + py.check_signals()?; + py.detach(|| { 
runtime.block_on(async { tokio::pin!(fut); From 9215d0889c8201d0f43e4bcd5ac6595b9beaaeea Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 11 Jan 2026 14:24:54 -0500 Subject: [PATCH 3/3] Update unit test to do four variations of fast/slow queries and interrupt either collect or stream --- python/tests/test_dataframe.py | 113 ++++++++------------------------- 1 file changed, 27 insertions(+), 86 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 882c854d1..30f9ab903 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -3198,97 +3198,35 @@ def slow_udf(x: pa.Array) -> pa.Array: return x -def test_collect_interrupted(): - """Test that a long-running query can be interrupted with Ctrl-C. - - This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt - exception in the main thread during a long-running query execution. - """ - # Create a context and a DataFrame with a query that will run for a while - ctx = SessionContext() - df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a"))) - - # Flag to track if the query was interrupted - interrupted = False - interrupt_error = None - main_thread = threading.main_thread() - - # Shared flag to indicate query execution has started - query_started = threading.Event() - max_wait_time = 5.0 # Maximum wait time in seconds - - # This function will be run in a separate thread and will raise - # KeyboardInterrupt in the main thread - def trigger_interrupt(): - """Poll for query start, then raise KeyboardInterrupt in the main thread""" - # Poll for query to start with small sleep intervals - start_time = time.time() - while not query_started.is_set(): - time.sleep(0.1) # Small sleep between checks - if time.time() - start_time > max_wait_time: - msg = f"Query did not start within {max_wait_time} seconds" - raise RuntimeError(msg) - - # Check if thread ID is available - thread_id = main_thread.ident - if thread_id is None: -
msg = "Cannot get main thread ID" - raise RuntimeError(msg) - - # Use ctypes to raise exception in main thread - exception = ctypes.py_object(KeyboardInterrupt) - res = ctypes.pythonapi.PyThreadState_SetAsyncExc( - ctypes.c_long(thread_id), exception - ) - if res != 1: - # If res is 0, the thread ID was invalid - # If res > 1, we modified multiple threads - ctypes.pythonapi.PyThreadState_SetAsyncExc( - ctypes.c_long(thread_id), ctypes.py_object(0) - ) - msg = "Failed to raise KeyboardInterrupt in main thread" - raise RuntimeError(msg) - - # Start a thread to trigger the interrupt - interrupt_thread = threading.Thread(target=trigger_interrupt) - # we mark as daemon so the test process can exit even if this thread doesn't finish - interrupt_thread.daemon = True - interrupt_thread.start() - - # Execute the query and expect it to be interrupted - try: - # Signal that we're about to start the query - query_started.set() - df.collect() - except KeyboardInterrupt: - interrupted = True - except Exception as e: - if "KeyboardInterrupt" in str(e): - interrupted = True - else: - interrupt_error = e - - # Assert that the query was interrupted properly - if not interrupted: - pytest.fail(f"Query was not interrupted; got error: {interrupt_error}") - - # Make sure the interrupt thread has finished - interrupt_thread.join(timeout=1.0) - +@pytest.mark.parametrize( + ("slow_query", "as_c_stream"), + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], +) +def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 PLR0915 + """Ensure collection responds to ``KeyboardInterrupt`` signals. -def test_arrow_c_stream_interrupted(): # noqa: C901 - """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals. + This test issues a long-running query, and consumes the results via + either a collect() call or ``__arrow_c_stream__``. It raises + ``KeyboardInterrupt`` in the main thread and verifies that the + process has interrupted. 
- Similar to ``test_collect_interrupted`` this test issues a long running - query, but consumes the results via ``__arrow_c_stream__``. It then raises - ``KeyboardInterrupt`` in the main thread and verifies that the stream - iteration stops promptly with the appropriate exception. + The `slow_query` determines if the query itself is slow via a + UDF that sleeps or if it is a fast query that generates many + results so it takes a long time to iterate through them all. """ ctx = SessionContext() - df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a"))) + df = ctx.sql("select * from generate_series(1, 1000000000000000000)") + if slow_query: + df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a"))) - reader = pa.RecordBatchReader.from_stream(df) + if as_c_stream: + reader = pa.RecordBatchReader.from_stream(df) read_started = threading.Event() read_exception = [] @@ -3322,7 +3260,10 @@ def read_stream(): read_thread_id = threading.get_ident() try: read_started.set() - reader.read_all() + if as_c_stream: + reader.read_all() + else: + df.collect() # If we get here, the read completed without interruption read_exception.append(RuntimeError("Read completed without interruption")) except KeyboardInterrupt: