Skip to content

Commit a96e03f

Browse files
authored
fix: update tests for cuDF 25.08.00a (#2700)
* remove xfail * added is_polars_exception * catch cudf exception * coverage happy? * coverage happier? * xfail nw.Date
1 parent ac74d8a commit a96e03f

File tree

4 files changed

+28
-22
lines changed

4 files changed

+28
-22
lines changed

narwhals/_polars/utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,19 @@ def narwhals_to_native_dtype( # noqa: C901, PLR0912
217217
return pl.Unknown() # pragma: no cover
218218

219219

220+
def _is_polars_exception(exception: Exception, backend_version: tuple[int, ...]) -> bool:
221+
if backend_version >= (1,):
222+
# Old versions of Polars didn't have PolarsError.
223+
return isinstance(exception, pl.exceptions.PolarsError)
224+
# Last attempt, for old Polars versions.
225+
return "polars.exceptions" in str(type(exception)) # pragma: no cover
226+
227+
228+
def _is_cudf_exception(exception: Exception) -> bool:
229+
# These exceptions are raised when running polars on GPUs via cuDF
230+
return str(exception).startswith("CUDF failure")
231+
232+
220233
def catch_polars_exception(
221234
exception: Exception, backend_version: tuple[int, ...]
222235
) -> NarwhalsError | Exception:
@@ -230,13 +243,7 @@ def catch_polars_exception(
230243
return DuplicateError(str(exception))
231244
elif isinstance(exception, pl.exceptions.ComputeError):
232245
return ComputeError(str(exception))
233-
if backend_version >= (1,) and isinstance(exception, pl.exceptions.PolarsError):
234-
# Old versions of Polars didn't have PolarsError.
246+
if _is_polars_exception(exception, backend_version) or _is_cudf_exception(exception):
235247
return NarwhalsError(str(exception)) # pragma: no cover
236-
elif backend_version < (1,) and "polars.exceptions" in str(
237-
type(exception)
238-
): # pragma: no cover
239-
# Last attempt, for old Polars versions.
240-
return NarwhalsError(str(exception))
241248
# Just return exception as-is.
242249
return exception

tests/expr_and_series/is_in_test.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import os
4-
53
import pytest
64

75
import narwhals as nw
@@ -18,16 +16,7 @@ def test_expr_is_in(constructor: Constructor) -> None:
1816
assert_equal_data(result, expected)
1917

2018

21-
def test_expr_is_in_empty_list(
22-
constructor: Constructor, request: pytest.FixtureRequest
23-
) -> None:
24-
if "polars_lazy" in str(constructor) and os.environ.get("NARWHALS_POLARS_GPU"):
25-
# Traceback:
26-
# narwhals.exceptions.ComputeError: RuntimeError:
27-
# CUDF failure at:/__w/cudf/cudf/cpp/include/cudf/utilities/type_dispatcher.hpp:567:
28-
# Invalid type_id.
29-
request.applymarker(pytest.mark.xfail)
30-
19+
def test_expr_is_in_empty_list(constructor: Constructor) -> None:
3120
df = nw.from_native(constructor(data))
3221
result = df.select(nw.col("a").is_in([]))
3322
expected = {"a": [False, False, False, False]}

tests/expr_and_series/lit_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
import pytest
99

1010
import narwhals as nw
11-
from tests.utils import DASK_VERSION, PANDAS_VERSION, Constructor, assert_equal_data
11+
from tests.utils import (
12+
CUDF_VERSION,
13+
DASK_VERSION,
14+
PANDAS_VERSION,
15+
Constructor,
16+
assert_equal_data,
17+
)
1218

1319
if TYPE_CHECKING:
1420
from narwhals.dtypes import DType
@@ -109,8 +115,11 @@ def test_lit_operation_in_with_columns(
109115

110116
@pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")
111117
def test_date_lit(constructor: Constructor, request: pytest.FixtureRequest) -> None:
112-
if "dask" in str(constructor):
113-
# https://github.com/dask/dask/issues/11637
118+
# https://github.com/dask/dask/issues/11637
119+
if "dask" in str(constructor) or (
120+
# https://github.com/rapidsai/cudf/pull/18832
121+
"cudf" in str(constructor) and CUDF_VERSION >= (25, 8, 0)
122+
):
114123
request.applymarker(pytest.mark.xfail)
115124
df = nw.from_native(constructor({"a": [1]}))
116125
result = df.with_columns(nw.lit(date(2020, 1, 1), dtype=nw.Date)).collect_schema()

tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]:
3838
DASK_VERSION: tuple[int, ...] = get_module_version_as_tuple("dask")
3939
PYARROW_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyarrow")
4040
PYSPARK_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyspark")
41+
CUDF_VERSION: tuple[int, ...] = get_module_version_as_tuple("cudf")
4142

4243
Constructor: TypeAlias = Callable[[Any], "NativeLazyFrame | NativeFrame | DataFrameLike"]
4344
ConstructorEager: TypeAlias = Callable[[Any], "NativeFrame | DataFrameLike"]

0 commit comments

Comments
 (0)