199 changes: 33 additions & 166 deletions python/tests/test_dataframe.py
@@ -37,6 +37,7 @@
     WindowFrame,
     column,
     literal,
+    udf,
 )
 from datafusion import (
     col as df_col,
@@ -3190,179 +3191,42 @@ def test_fill_null_all_null_column(ctx):
     assert result.column(1).to_pylist() == ["filled", "filled", "filled"]


-def test_collect_interrupted():
-    """Test that a long-running query can be interrupted with Ctrl-C.
+@udf([pa.int64()], pa.int64(), "immutable")
+def slow_udf(x: pa.Array) -> pa.Array:
+    # This must be longer than the check interval in wait_for_future
+    time.sleep(2.0)
+    return x

-    This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
-    exception in the main thread during a long-running query execution.
-    """
-    # Create a context and a DataFrame with a query that will run for a while
-    ctx = SessionContext()

-    # Create a recursive computation that will run for some time
-    batches = []
-    for i in range(10):
-        batch = pa.RecordBatch.from_arrays(
-            [
-                pa.array(list(range(i * 1000, (i + 1) * 1000))),
-                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
-            ],
-            names=["a", "b"],
-        )
-        batches.append(batch)

-    # Register tables
-    ctx.register_record_batches("t1", [batches])
-    ctx.register_record_batches("t2", [batches])

-    # Create a large join operation that will take time to process
-    df = ctx.sql("""
-        WITH t1_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) / 1.5 AS c,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
-            FROM t1
-            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
-        ),
-        t2_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) * 2.5 AS e,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
-            FROM t2
-            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
-        )
-        SELECT
-            t1.a, t1.b, t1.c, t1.d,
-            t2.a AS a2, t2.b AS b2, t2.e, t2.f
-        FROM t1_expanded t1
-        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
-        WHERE t1.a > 100 AND t2.a > 100
-    """)

-    # Flag to track if the query was interrupted
-    interrupted = False
-    interrupt_error = None
-    main_thread = threading.main_thread()

-    # Shared flag to indicate query execution has started
-    query_started = threading.Event()
-    max_wait_time = 5.0  # Maximum wait time in seconds

-    # This function will be run in a separate thread and will raise
-    # KeyboardInterrupt in the main thread
-    def trigger_interrupt():
-        """Poll for query start, then raise KeyboardInterrupt in the main thread"""
-        # Poll for query to start with small sleep intervals
-        start_time = time.time()
-        while not query_started.is_set():
-            time.sleep(0.1)  # Small sleep between checks
-            if time.time() - start_time > max_wait_time:
-                msg = f"Query did not start within {max_wait_time} seconds"
-                raise RuntimeError(msg)

-        # Check if thread ID is available
-        thread_id = main_thread.ident
-        if thread_id is None:
-            msg = "Cannot get main thread ID"
-            raise RuntimeError(msg)

-        # Use ctypes to raise exception in main thread
-        exception = ctypes.py_object(KeyboardInterrupt)
-        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
-            ctypes.c_long(thread_id), exception
-        )
-        if res != 1:
-            # If res is 0, the thread ID was invalid
-            # If res > 1, we modified multiple threads
-            ctypes.pythonapi.PyThreadState_SetAsyncExc(
-                ctypes.c_long(thread_id), ctypes.py_object(0)
-            )
-            msg = "Failed to raise KeyboardInterrupt in main thread"
-            raise RuntimeError(msg)

-    # Start a thread to trigger the interrupt
-    interrupt_thread = threading.Thread(target=trigger_interrupt)
-    # we mark as daemon so the test process can exit even if this thread doesn't finish
-    interrupt_thread.daemon = True
-    interrupt_thread.start()

-    # Execute the query and expect it to be interrupted
-    try:
-        # Signal that we're about to start the query
-        query_started.set()
-        df.collect()
-    except KeyboardInterrupt:
-        interrupted = True
-    except Exception as e:
-        interrupt_error = e

-    # Assert that the query was interrupted properly
-    if not interrupted:
-        pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")

-    # Make sure the interrupt thread has finished
-    interrupt_thread.join(timeout=1.0)

+@pytest.mark.parametrize(
+    ("slow_query", "as_c_stream"),
+    [
+        (True, True),
+        (True, False),
+        (False, True),
+        (False, False),
+    ],
+)
+def test_collect_or_stream_interrupted(slow_query, as_c_stream):  # noqa: C901 PLR0915
+    """Ensure collection responds to ``KeyboardInterrupt`` signals.

-def test_arrow_c_stream_interrupted():  # noqa: C901 PLR0915
-    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
+    This test issues a long-running query and consumes the results via
+    either a ``collect()`` call or ``__arrow_c_stream__``. It raises
+    ``KeyboardInterrupt`` in the main thread and verifies that the
+    query is interrupted promptly.

-    Similar to ``test_collect_interrupted`` this test issues a long running
-    query, but consumes the results via ``__arrow_c_stream__``. It then raises
-    ``KeyboardInterrupt`` in the main thread and verifies that the stream
-    iteration stops promptly with the appropriate exception.
+    The ``slow_query`` parameter determines whether the query itself is slow,
+    via a UDF that sleeps, or is a fast query that generates so many results
+    that iterating through them all takes a long time.
+    """

     ctx = SessionContext()
+    df = ctx.sql("select * from generate_series(1, 1000000000000000000)")
+    if slow_query:
+        df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))

-    batches = []
-    for i in range(10):
-        batch = pa.RecordBatch.from_arrays(
-            [
-                pa.array(list(range(i * 1000, (i + 1) * 1000))),
-                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
-            ],
-            names=["a", "b"],
-        )
-        batches.append(batch)

-    ctx.register_record_batches("t1", [batches])
-    ctx.register_record_batches("t2", [batches])

-    df = ctx.sql(
-        """
-        WITH t1_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) / 1.5 AS c,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
-            FROM t1
-            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
-        ),
-        t2_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) * 2.5 AS e,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
-            FROM t2
-            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
-        )
-        SELECT
-            t1.a, t1.b, t1.c, t1.d,
-            t2.a AS a2, t2.b AS b2, t2.e, t2.f
-        FROM t1_expanded t1
-        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
-        WHERE t1.a > 100 AND t2.a > 100
-        """
-    )

-    reader = pa.RecordBatchReader.from_stream(df)
+    if as_c_stream:
+        reader = pa.RecordBatchReader.from_stream(df)

     read_started = threading.Event()
     read_exception = []
Expand Down Expand Up @@ -3396,7 +3260,10 @@ def read_stream():
         read_thread_id = threading.get_ident()
         try:
             read_started.set()
-            reader.read_all()
+            if as_c_stream:
+                reader.read_all()
+            else:
+                df.collect()
             # If we get here, the read completed without interruption
             read_exception.append(RuntimeError("Read completed without interruption"))
         except KeyboardInterrupt:
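For context, the interrupt-injection trick both the old and new tests rely on is worth spelling out. Below is a minimal stand-alone sketch, modeled on the removed `trigger_interrupt` helper: a daemon thread asks CPython to raise `KeyboardInterrupt` in a target thread via `ctypes`. Note that exceptions set with `PyThreadState_SetAsyncExc` are delivered only when the target thread next executes Python bytecode, which is presumably why the `src/utils.rs` change below runs a no-op `pass` statement before checking signals. The helper name and fixed delay are illustrative; the real test synchronizes on a `threading.Event` instead.

```python
import ctypes
import threading
import time


def raise_keyboard_interrupt_in(thread_id: int, delay: float) -> threading.Thread:
    """Illustrative helper: schedule KeyboardInterrupt in another thread."""

    def trigger() -> None:
        time.sleep(delay)  # the real test polls an Event instead of sleeping
        # Ask CPython to raise KeyboardInterrupt in the target thread the next
        # time it executes bytecode; returns the number of threads modified.
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), ctypes.py_object(KeyboardInterrupt)
        )
        if res != 1:
            # 0 means the thread id was invalid; >1 means several thread
            # states were modified, so clear the pending exception again.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(
                ctypes.c_long(thread_id), ctypes.py_object(0)
            )
            msg = "Failed to raise KeyboardInterrupt in target thread"
            raise RuntimeError(msg)

    thread = threading.Thread(target=trigger, daemon=True)
    thread.start()
    return thread
```

With a helper like this, the reading thread can block in `reader.read_all()` or `df.collect()` and the test can still verify that the interrupt surfaces promptly, as long as the blocked native code periodically re-enters the interpreter.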
5 changes: 5 additions & 0 deletions src/utils.rs
@@ -77,6 +77,11 @@ where
     let runtime: &Runtime = &get_tokio_runtime().0;
     const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);

+    // Some fast-running processes that generate many `wait_for_future` calls,
+    // like `PartitionedDataFrameStreamReader::next`, need an early interrupt check.
+    py.run(cr"pass", None, None)?;
+    py.check_signals()?;
+
     py.detach(|| {
         runtime.block_on(async {
             tokio::pin!(fut);
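Finally, a hedged end-to-end sketch of the behavior this `src/utils.rs` change protects; the query and the one-second timing are illustrative. `_thread.interrupt_main()` sets the same pending-SIGINT flag that `py.check_signals()` polls, so the interrupt should surface while `collect()` is still blocked in Rust (which re-checks roughly every second, per `INTERVAL_CHECK_SIGNALS`) rather than after the effectively unbounded query completes.

```python
import _thread
import threading

from datafusion import SessionContext

ctx = SessionContext()
# Effectively unbounded query; collecting it to completion is infeasible.
df = ctx.sql("select * from generate_series(1, 1000000000000000000)")

# Simulate Ctrl-C one second from now.
threading.Timer(1.0, _thread.interrupt_main).start()

try:
    df.collect()  # blocks in Rust, waking periodically to poll for signals
except KeyboardInterrupt:
    print("query interrupted promptly")
```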