Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 264 additions & 22 deletions bigframes/core/compile/sqlglot/expressions/datetime_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,14 @@
from bigframes import dtypes
from bigframes import operations as ops
from bigframes.core.compile.constants import UNIT_TO_US_CONVERSION_FACTORS
from bigframes.core.compile.sqlglot import sqlglot_types
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler

register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op


def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the expression for the resample starting point chosen by ``origin``.

    Args:
        y: Typed expression whose ``expr`` supplies the reference datetime.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        The literal ``0`` for ``"epoch"``. Otherwise a ``UNIX_MICROS`` call on
        ``y.expr`` cast to TIMESTAMPTZ; for ``"start_day"`` the value is cast
        to DATE first.

    Raises:
        ValueError: If ``origin`` is not one of the supported values.
    """
    if origin == "epoch":
        return sge.convert(0)

    if origin not in ("start_day", "start"):
        raise ValueError(f"Origin {origin} not supported")

    anchor = y.expr
    if origin == "start_day":
        # Going through DATE drops the time-of-day part of the reference value.
        anchor = sge.Cast(this=anchor, to=sge.DataType(this=sge.DataType.Type.DATE))
    timestamp_type = sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
    return sge.func("UNIX_MICROS", sge.Cast(this=anchor, to=timestamp_type))


@register_binary_op(ops.DatetimeToIntegerLabelOp, pass_op=True)
def datetime_to_integer_label_op(
x: TypedExpr, y: TypedExpr, op: ops.DatetimeToIntegerLabelOp
Expand Down Expand Up @@ -317,6 +296,20 @@ def _(expr: TypedExpr, op: ops.FloorDtOp) -> sge.Expression:
return sge.TimestampTrunc(this=expr.expr, unit=sge.Identifier(this=bq_freq))


def _calculate_resample_first(y: TypedExpr, origin: str) -> sge.Expression:
    """Build the expression for the resample starting point chosen by ``origin``.

    Args:
        y: Typed expression whose ``expr`` supplies the reference datetime.
        origin: One of ``"epoch"``, ``"start_day"`` or ``"start"``.

    Returns:
        The literal ``0`` for ``"epoch"``. Otherwise ``UNIX_MICROS`` applied to
        ``y.expr`` cast to TIMESTAMP; for ``"start_day"`` the value is cast to
        DATE first.

    Raises:
        ValueError: If ``origin`` is not one of the supported values.
    """
    if origin == "epoch":
        return sge.convert(0)

    if origin == "start_day":
        # Going through DATE drops the time-of-day part of the reference value.
        day_start = sge.Cast(this=y.expr, to="DATE")
        return sge.func("UNIX_MICROS", sge.Cast(this=day_start, to="TIMESTAMP"))

    if origin == "start":
        as_timestamp = sge.Cast(this=y.expr, to="TIMESTAMP")
        return sge.func("UNIX_MICROS", as_timestamp)

    raise ValueError(f"Origin {origin} not supported")


@register_unary_op(ops.hour_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.hour_op`` into an ``EXTRACT`` of the HOUR part of ``expr``."""
    hour_part = sge.Identifier(this="HOUR")
    return sge.Extract(this=hour_part, expression=expr.expr)
Expand Down Expand Up @@ -436,3 +429,252 @@ def _(expr: TypedExpr, op: ops.UnixSeconds) -> sge.Expression:
@register_unary_op(ops.year_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.year_op`` into an ``EXTRACT`` of the YEAR part of ``expr``."""
    year_part = sge.Identifier(this="YEAR")
    return sge.Extract(this=year_part, expression=expr.expr)


@register_binary_op(ops.IntegerLabelToDatetimeOp, pass_op=True)
def integer_label_to_datetime_op(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """Compile ``IntegerLabelToDatetimeOp`` by dispatching on the frequency kind.

    Args:
        x: Typed expression holding the integer label.
        y: Typed expression holding the reference datetime.
        op: The operation; ``op.freq`` selects the fixed or non-fixed path.

    Returns:
        The expression produced by the fixed-frequency builder when
        ``op.freq.nanos`` is available, otherwise by the non-fixed builder.

    Raises:
        ValueError: Propagated from whichever builder is selected, e.g. for
            an unsupported origin or an unsupported rule code.
    """
    # The frequency is treated as fixed when `op.freq.nanos` can be read
    # without raising ValueError. Only that probe sits inside the `try`, so a
    # ValueError raised while building the fixed-frequency expression is
    # reported as-is instead of being replaced by the fallback's error.
    try:
        _ = op.freq.nanos
    except ValueError:
        return _integer_label_to_datetime_op_non_fixed_frequency(x, y, op)
    return _integer_label_to_datetime_op_fixed_frequency(x, y, op)


def _integer_label_to_datetime_op_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """
    Handle fixed frequency conversions, where the unit can range from
    microseconds (us) to days.

    The result is ``TIMESTAMP_MICROS(x * step_us + first)`` cast to the SQL
    type of ``y.dtype``, where ``step_us`` is ``op.freq.nanos`` expressed in
    microseconds and ``first`` comes from ``_calculate_resample_first``.
    """
    step_us = int(op.freq.nanos / 1000)
    origin_us = _calculate_resample_first(y, op.origin)  # type: ignore

    # The arithmetic is done in BIGNUMERIC and cast back to INT64 afterwards.
    scaled_label = sge.Mul(
        this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
        expression=sge.convert(step_us),
    )
    total_us = sge.Add(
        this=scaled_label,
        expression=sge.Cast(this=origin_us, to="BIGNUMERIC"),
    )
    as_timestamp = sge.func("TIMESTAMP_MICROS", sge.Cast(this=total_us, to="INT64"))
    return sge.Cast(
        this=as_timestamp, to=sqlglot_types.from_bigframes_dtype(y.dtype)
    )


def _integer_label_to_datetime_op_non_fixed_frequency(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """
    Handle non-fixed frequency conversions, for units ranging from weeks to
    years, by dispatching on ``op.freq.rule_code``.

    Raises:
        ValueError: With the rule code as its argument when no builder is
            registered for it.
    """
    rule_code = op.freq.rule_code

    # Each entry pairs the accepted rule-code spellings with their builder.
    builders = (
        (("W-SUN",), _integer_label_to_datetime_op_weekly_freq),
        (("ME", "M"), _integer_label_to_datetime_op_monthly_freq),
        (("QE-DEC", "Q-DEC"), _integer_label_to_datetime_op_quarterly_freq),
        (("YE-DEC", "A-DEC", "Y-DEC"), _integer_label_to_datetime_op_yearly_freq),
    )
    for accepted_codes, build in builders:
        if rule_code in accepted_codes:
            return build(x, y, op)

    raise ValueError(rule_code)


def _integer_label_to_datetime_op_weekly_freq(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """
    Build the label expression for a weekly frequency.

    The origin is ``y`` truncated to ``WEEK(MONDAY)`` plus six days; each
    label step then adds ``op.freq.n`` weeks' worth of microseconds. The
    result is cast to the SQL type of ``y.dtype``.
    """
    # Microseconds covered by one step of `op.freq.n` weeks.
    step_us = op.freq.n * 7 * 24 * 60 * 60 * 1000000

    week_start = sge.TimestampTrunc(
        this=sge.Cast(this=y.expr, to="TIMESTAMP"),
        unit=sge.Var(this="WEEK(MONDAY)"),
    )
    six_days = sge.Interval(this=sge.convert(6), unit=sge.Identifier(this="DAY"))
    origin_us = sge.func(
        "UNIX_MICROS", sge.Add(this=week_start, expression=six_days)
    )

    # The arithmetic is done in BIGNUMERIC and cast back to INT64 afterwards.
    scaled_label = sge.Mul(
        this=sge.Cast(this=x.expr, to="BIGNUMERIC"),
        expression=sge.convert(step_us),
    )
    total_us = sge.Add(
        this=scaled_label,
        expression=sge.Cast(this=origin_us, to="BIGNUMERIC"),
    )
    as_timestamp = sge.func("TIMESTAMP_MICROS", sge.Cast(this=total_us, to="INT64"))
    return sge.Cast(
        this=as_timestamp, to=sqlglot_types.from_bigframes_dtype(y.dtype)
    )


def _integer_label_to_datetime_op_monthly_freq(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """
    Build the label expression for a monthly frequency.

    Months are counted as ``year * 12 + month - 1``. The label is the first
    instant of the month after the target month, minus one day, cast to the
    SQL type of ``y.dtype``.
    """
    step = op.freq.n
    one = sge.convert(1)
    twelve = sge.convert(12)

    # Month index of the reference value: YEAR * 12 + MONTH - 1.
    years_as_months = sge.Mul(
        this=sge.Extract(this="YEAR", expression=y.expr),
        expression=twelve,
    )
    month_of_year = sge.Extract(this="MONTH", expression=y.expr)
    origin_index = sge.Sub(  # type: ignore
        this=sge.Add(this=years_as_months, expression=month_of_year),
        expression=one,
    )

    # Month index of the requested label.
    target_index = sge.Add(
        this=sge.Mul(this=x.expr, expression=sge.convert(step)),
        expression=origin_index,
    )
    year = sge.Cast(
        this=sge.Floor(this=sge.func("IEEE_DIVIDE", target_index, twelve)),
        to="INT64",
    )
    month = sge.Add(this=sge.Mod(this=target_index, expression=twelve), expression=one)

    # Step to the following month, rolling into January of the next year
    # when the target month is December.
    is_december = sge.EQ(this=month, expression=twelve)
    year_after = sge.Add(this=year, expression=one)
    next_year = sge.Case(
        ifs=[sge.If(this=is_december, true=year_after)],
        default=year,
    )
    wrap_to_january = sge.If(this=sge.EQ(this=month, expression=twelve), true=one)
    month_after = sge.Add(this=month, expression=one)
    next_month = sge.Case(ifs=[wrap_to_january], default=month_after)

    zero_hour = sge.convert(0)
    zero_minute = sge.convert(0)
    zero_second = sge.convert(0)
    start_of_next_month = sge.func(
        "TIMESTAMP",
        sge.Anonymous(
            this="DATETIME",
            expressions=[
                next_year,
                next_month,
                one,
                zero_hour,
                zero_minute,
                zero_second,
            ],
        ),
    )
    month_end = sge.Sub(  # type: ignore
        this=start_of_next_month, expression=sge.Interval(this=one, unit="DAY")
    )
    return sge.Cast(this=month_end, to=sqlglot_types.from_bigframes_dtype(y.dtype))


def _integer_label_to_datetime_op_quarterly_freq(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """
    Build the label expression for a quarterly frequency.

    Quarters are counted as ``year * 4 + quarter - 1``. The label is the
    first instant of the month after the target quarter's last month, minus
    one day, cast to the SQL type of ``y.dtype``.
    """
    step = op.freq.n
    one = sge.convert(1)
    three = sge.convert(3)
    four = sge.convert(4)
    twelve = sge.convert(12)

    # Quarter index of the reference value: YEAR * 4 + QUARTER - 1.
    years_as_quarters = sge.Mul(
        this=sge.Extract(this="YEAR", expression=y.expr),
        expression=four,
    )
    quarter_of_year = sge.Extract(this="QUARTER", expression=y.expr)
    origin_index = sge.Sub(  # type: ignore
        this=sge.Add(this=years_as_quarters, expression=quarter_of_year),
        expression=one,
    )

    # Quarter index of the requested label.
    target_index = sge.Add(
        this=sge.Mul(this=x.expr, expression=sge.convert(step)),
        expression=origin_index,
    )
    year = sge.Cast(
        this=sge.Floor(this=sge.func("IEEE_DIVIDE", target_index, four)),
        to="INT64",
    )
    # Last month of the target quarter: (index mod 4 + 1) * 3.
    quarter_number = sge.Paren(
        this=sge.Add(this=sge.Mod(this=target_index, expression=four), expression=one)
    )
    month = sge.Mul(this=quarter_number, expression=three)  # type: ignore

    # Step to the following month, rolling into January of the next year
    # when the quarter ends in December.
    is_december = sge.EQ(this=month, expression=twelve)
    year_after = sge.Add(this=year, expression=one)
    next_year = sge.Case(
        ifs=[sge.If(this=is_december, true=year_after)],
        default=year,
    )
    wrap_to_january = sge.If(this=sge.EQ(this=month, expression=twelve), true=one)
    month_after = sge.Add(this=month, expression=one)
    next_month = sge.Case(ifs=[wrap_to_january], default=month_after)

    # The DATETIME value is used directly here (no TIMESTAMP(...) wrapper);
    # the day is subtracted from it and the final cast sets the result type.
    zero_hour = sge.convert(0)
    zero_minute = sge.convert(0)
    zero_second = sge.convert(0)
    start_of_next_month = sge.Anonymous(
        this="DATETIME",
        expressions=[
            next_year,
            next_month,
            one,
            zero_hour,
            zero_minute,
            zero_second,
        ],
    )
    quarter_end = sge.Sub(  # type: ignore
        this=start_of_next_month, expression=sge.Interval(this=one, unit="DAY")
    )
    return sge.Cast(this=quarter_end, to=sqlglot_types.from_bigframes_dtype(y.dtype))


def _integer_label_to_datetime_op_yearly_freq(
    x: TypedExpr, y: TypedExpr, op: ops.IntegerLabelToDatetimeOp
) -> sge.Expression:
    """
    Build the label expression for a yearly frequency.

    The target year is ``x * op.freq.n + YEAR(y)``. The label is January 1st
    of the following year minus one day, cast to the SQL type of ``y.dtype``.
    """
    step = op.freq.n
    one = sge.convert(1)

    origin_year = sge.Extract(this="YEAR", expression=y.expr)
    target_year = sge.Add(
        this=sge.Mul(this=x.expr, expression=sge.convert(step)),
        expression=origin_year,
    )
    next_year = sge.Add(this=target_year, expression=one)  # type: ignore

    # January 1st, 00:00:00 of the year after the target year.
    zero_hour = sge.convert(0)
    zero_minute = sge.convert(0)
    zero_second = sge.convert(0)
    start_of_next_year = sge.func(
        "TIMESTAMP",
        sge.Anonymous(
            this="DATETIME",
            expressions=[
                next_year,
                one,
                one,
                zero_hour,
                zero_minute,
                zero_second,
            ],
        ),
    )
    year_end = sge.Sub(  # type: ignore
        this=start_of_next_year, expression=sge.Interval(this=one, unit="DAY")
    )
    return sge.Cast(this=year_end, to=sqlglot_types.from_bigframes_dtype(y.dtype))
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
-- Derives two timestamps from the integer label in `rowindex` and the
-- reference value in `timestamp_col`: `fixed_freq` and `non_fixed_freq`.
-- NOTE(review): this looks like generated expected output for the compiler;
-- confirm whether it is compared verbatim before hand-editing.
WITH `bfcte_0` AS (
SELECT
`rowindex`,
`timestamp_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
-- Fixed frequency: label * 86400000000 microseconds (one day) added to the
-- reference timestamp's microseconds.
CAST(TIMESTAMP_MICROS(
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
) AS TIMESTAMP) AS `bfcol_2`,
-- Non-fixed (quarterly) frequency: first day of the month after the target
-- quarter's last month, minus one day.
CAST(DATETIME(
-- Year argument: bumped by one when the quarter's last month is 12.
CASE
WHEN (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 = 12
THEN CAST(FLOOR(
IEEE_DIVIDE(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
)
) AS INT64) + 1
ELSE CAST(FLOOR(
IEEE_DIVIDE(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
)
) AS INT64)
END,
-- Month argument: 1 when the quarter's last month is 12, otherwise that
-- month plus one.
CASE
WHEN (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 = 12
THEN 1
ELSE (
MOD(
`rowindex` * 1 + EXTRACT(YEAR FROM `timestamp_col`) * 4 + EXTRACT(QUARTER FROM `timestamp_col`) - 1,
4
) + 1
) * 3 + 1
END,
1,
0,
0,
0
) - INTERVAL 1 DAY AS TIMESTAMP) AS `bfcol_3`
FROM `bfcte_0`
)
SELECT
`bfcol_2` AS `fixed_freq`,
`bfcol_3` AS `non_fixed_freq`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- Derives the `fixed_freq` timestamp from the integer label in `rowindex`
-- and the reference value in `timestamp_col`.
-- NOTE(review): this looks like generated expected output for the compiler;
-- confirm whether it is compared verbatim before hand-editing.
WITH `bfcte_0` AS (
SELECT
`rowindex`,
`timestamp_col`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
-- label * 86400000000 microseconds (one day) added to the reference
-- timestamp's microseconds, converted back to a timestamp.
CAST(TIMESTAMP_MICROS(
CAST(CAST(`rowindex` AS BIGNUMERIC) * 86400000000 + CAST(UNIX_MICROS(CAST(`timestamp_col` AS TIMESTAMP)) AS BIGNUMERIC) AS INT64)
) AS TIMESTAMP) AS `bfcol_2`
FROM `bfcte_0`
)
SELECT
`bfcol_2` AS `fixed_freq`
FROM `bfcte_1`
Loading
Loading