Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
Type,
TypeVar,
Union,
cast,
)
from typing import Literal as TypingLiteral

from pydantic import Field
from pydantic import ConfigDict, Field, field_validator

from pyiceberg.expressions.literals import (
AboveMax,
Expand All @@ -52,8 +53,14 @@
ConfigDict = dict


def _to_unbound_term(term: Union[str, UnboundTerm[Any]]) -> UnboundTerm[Any]:
return Reference(term) if isinstance(term, str) else term
def _to_unbound_term(term: Union[str, UnboundTerm[Any], BoundReference[Any]]) -> UnboundTerm[Any]:
if isinstance(term, str):
return Reference(term)
if isinstance(term, UnboundTerm):
return term
if isinstance(term, BoundReference):
return Reference(term.field.name)
raise ValueError(f"Expected UnboundTerm | BoundReference | str, got {type(term).__name__}")


def _to_literal_set(values: Union[Iterable[L], Iterable[Literal[L]]]) -> Set[Literal[L]]:
Expand Down Expand Up @@ -743,12 +750,52 @@ def as_bound(self) -> Type[BoundNotIn[L]]:
return BoundNotIn[L]


class LiteralPredicate(UnboundPredicate[L], ABC):
literal: Literal[L]
class LiteralPredicate(IcebergBaseModel, UnboundPredicate[L], ABC):
type: TypingLiteral["lt", "lt-eq", "gt", "gt-eq", "eq", "not-eq", "starts-with", "not-starts-with"] = Field(alias="type")
term: UnboundTerm[Any]
value: Literal[L] = Field(alias="literal", serialization_alias="value")

model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True)

def __init__(
self,
term: Union[str, UnboundTerm[Any], BoundReference[Any]],
literal: Union[L, Literal[L], None] = None,
**data: Any,
) -> None: # pylint: disable=W0621
extra = dict(data)

literal_candidates = []
if literal is not None:
literal_candidates.append(literal)
if "literal" in extra:
literal_candidates.append(extra.pop("literal"))
if "value" in extra:
literal_candidates.append(extra.pop("value"))

literal_candidates = [candidate for candidate in literal_candidates if candidate is not None]

if not literal_candidates:
raise TypeError("LiteralPredicate requires a literal or value argument")
if len(literal_candidates) > 1:
raise TypeError("literal/value provided multiple times")

init = cast("Callable[..., None]", IcebergBaseModel.__init__)
init(self, term=_to_unbound_term(term), literal=_to_literal(literal_candidates[0]), **extra)

@field_validator("term", mode="before")
@classmethod
def _convert_term(cls, value: Any) -> UnboundTerm[Any]:
return _to_unbound_term(value)

@field_validator("value", mode="before")
@classmethod
def _convert_value(cls, value: Any) -> Literal[Any]:
return _to_literal(value)

def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal[L]]): # pylint: disable=W0621
super().__init__(term)
self.literal = _to_literal(literal) # pylint: disable=W0621
@property
def literal(self) -> Literal[L]:
return self.value

def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
bound_term = self.term.bind(schema, case_sensitive)
Expand All @@ -773,6 +820,10 @@ def __eq__(self, other: Any) -> bool:
return self.term == other.term and self.literal == other.literal
return False

def __str__(self) -> str:
"""Return the string representation of the LiteralPredicate class."""
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"

def __repr__(self) -> str:
"""Return the string representation of the LiteralPredicate class."""
return f"{str(self.__class__.__name__)}(term={repr(self.term)}, literal={repr(self.literal)})"
Expand Down Expand Up @@ -886,6 +937,8 @@ def as_unbound(self) -> Type[NotStartsWith[L]]:


class EqualTo(LiteralPredicate[L]):
type: TypingLiteral["eq"] = Field(default="eq", alias="type")

def __invert__(self) -> NotEqualTo[L]:
"""Transform the Expression into its negated version."""
return NotEqualTo[L](self.term, self.literal)
Expand All @@ -896,6 +949,8 @@ def as_bound(self) -> Type[BoundEqualTo[L]]:


class NotEqualTo(LiteralPredicate[L]):
type: TypingLiteral["not-eq"] = Field(default="not-eq", alias="type")

def __invert__(self) -> EqualTo[L]:
"""Transform the Expression into its negated version."""
return EqualTo[L](self.term, self.literal)
Expand All @@ -906,6 +961,8 @@ def as_bound(self) -> Type[BoundNotEqualTo[L]]:


class LessThan(LiteralPredicate[L]):
type: TypingLiteral["lt"] = Field(default="lt", alias="type")

def __invert__(self) -> GreaterThanOrEqual[L]:
"""Transform the Expression into its negated version."""
return GreaterThanOrEqual[L](self.term, self.literal)
Expand All @@ -916,6 +973,8 @@ def as_bound(self) -> Type[BoundLessThan[L]]:


class GreaterThanOrEqual(LiteralPredicate[L]):
type: TypingLiteral["gt-eq"] = Field(default="gt-eq", alias="type")

def __invert__(self) -> LessThan[L]:
"""Transform the Expression into its negated version."""
return LessThan[L](self.term, self.literal)
Expand All @@ -926,6 +985,8 @@ def as_bound(self) -> Type[BoundGreaterThanOrEqual[L]]:


class GreaterThan(LiteralPredicate[L]):
type: TypingLiteral["gt"] = Field(default="gt", alias="type")

def __invert__(self) -> LessThanOrEqual[L]:
"""Transform the Expression into its negated version."""
return LessThanOrEqual[L](self.term, self.literal)
Expand All @@ -936,6 +997,8 @@ def as_bound(self) -> Type[BoundGreaterThan[L]]:


class LessThanOrEqual(LiteralPredicate[L]):
type: TypingLiteral["lt-eq"] = Field(default="lt-eq", alias="type")

def __invert__(self) -> GreaterThan[L]:
"""Transform the Expression into its negated version."""
return GreaterThan[L](self.term, self.literal)
Expand All @@ -946,6 +1009,8 @@ def as_bound(self) -> Type[BoundLessThanOrEqual[L]]:


class StartsWith(LiteralPredicate[L]):
type: TypingLiteral["starts-with"] = Field(default="starts-with", alias="type")

def __invert__(self) -> NotStartsWith[L]:
"""Transform the Expression into its negated version."""
return NotStartsWith[L](self.term, self.literal)
Expand All @@ -956,6 +1021,8 @@ def as_bound(self) -> Type[BoundStartsWith[L]]:


class NotStartsWith(LiteralPredicate[L]):
type: TypingLiteral["not-starts-with"] = Field(default="not-starts-with", alias="type")

def __invert__(self) -> StartsWith[L]:
"""Transform the Expression into its negated version."""
return StartsWith[L](self.term, self.literal)
Expand Down
2 changes: 1 addition & 1 deletion pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _try_import(module_name: str, extras_name: Optional[str] = None) -> types.Mo
raise NotInstalledError(msg) from None


def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
def _transform_literal(func: Callable[[Any], Any], lit: Literal[L]) -> Literal[L]:
"""Small helper to upwrap the value from the literal, and wrap it again."""
return literal(func(lit.value))

Expand Down
46 changes: 25 additions & 21 deletions tests/expressions/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pyiceberg.conversions import to_bytes
from pyiceberg.expressions import (
And,
BooleanExpression,
EqualTo,
GreaterThan,
GreaterThanOrEqual,
Expand All @@ -30,6 +31,7 @@
IsNull,
LessThan,
LessThanOrEqual,
LiteralPredicate,
Not,
NotEqualTo,
NotIn,
Expand Down Expand Up @@ -301,7 +303,7 @@ def test_missing_stats() -> None:
upper_bounds=None,
)

expressions = [
expressions: list[BooleanExpression] = [
LessThan("no_stats", 5),
LessThanOrEqual("no_stats", 30),
EqualTo("no_stats", 70),
Expand All @@ -324,7 +326,7 @@ def test_zero_record_file_stats(schema_data_file: Schema) -> None:
file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition=Record(), record_count=0
)

expressions = [
expressions: list[BooleanExpression] = [
LessThan("no_stats", 5),
LessThanOrEqual("no_stats", 30),
EqualTo("no_stats", 70),
Expand Down Expand Up @@ -683,26 +685,27 @@ def data_file_nan() -> DataFile:


def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None:
for operator in [LessThan, LessThanOrEqual]: # type: ignore
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
operators: tuple[type[LiteralPredicate[Any]], ...] = (LessThan, LessThanOrEqual)
for operator in operators:
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan)
assert not should_read, "Should not match: all nan column doesn't contain number"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan)
assert not should_read, "Should not match: 1 is smaller than lower bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan)
assert should_read, "Should match: 10 is larger than lower bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan)
assert should_read, "Should match: no visibility"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan)
assert not should_read, "Should not match: all nan column doesn't contain number"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan)
assert not should_read, "Should not match: 1 is smaller than lower bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval(
data_file_nan
)
assert should_read, "Should match: 10 larger than lower bound"
Expand All @@ -711,31 +714,32 @@ def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_f
def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal(
schema_data_file_nan: Schema, data_file_nan: DataFile
) -> None:
for operator in [GreaterThan, GreaterThanOrEqual]: # type: ignore
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
operators: tuple[type[LiteralPredicate[Any]], ...] = (GreaterThan, GreaterThanOrEqual)
for operator in operators:
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan)
assert not should_read, "Should not match: all nan column doesn't contain number"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan)
assert should_read, "Should match: upper bound is larger than 1"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan)
assert should_read, "Should match: upper bound is larger than 10"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan)
assert should_read, "Should match: no visibility"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan)
assert not should_read, "Should not match: all nan column doesn't contain number"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan)
assert should_read, "Should match: 1 is smaller than upper bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval(
data_file_nan
)
assert should_read, "Should match: 10 is smaller than upper bound"

should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan) # type: ignore[arg-type]
should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan)
assert not should_read, "Should not match: 30 is greater than upper bound"


Expand Down Expand Up @@ -1162,7 +1166,7 @@ def test_strict_missing_stats(strict_data_file_schema: Schema, strict_data_file_
upper_bounds=None,
)

expressions = [
expressions: list[BooleanExpression] = [
LessThan("no_stats", 5),
LessThanOrEqual("no_stats", 30),
EqualTo("no_stats", 70),
Expand All @@ -1185,7 +1189,7 @@ def test_strict_zero_record_file_stats(strict_data_file_schema: Schema) -> None:
file_path="file_1.parquet", file_format=FileFormat.PARQUET, partition=Record(), record_count=0
)

expressions = [
expressions: list[BooleanExpression] = [
LessThan("no_stats", 5),
LessThanOrEqual("no_stats", 30),
EqualTo("no_stats", 70),
Expand Down
Loading