diff --git a/.github/workflows/ci-docs.yaml b/.github/workflows/ci-docs.yaml index 4b13d9ef9..a3a1d6e9e 100644 --- a/.github/workflows/ci-docs.yaml +++ b/.github/workflows/ci-docs.yaml @@ -21,6 +21,8 @@ jobs: run: | python -m pip install ".[all]" - uses: quarto-dev/quarto-actions/setup@v2 + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tinytex: true - name: Build docs diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 98bfd44d3..4307f9a5b 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -168,6 +168,13 @@ quartodoc: - GT.cols_move_to_end - GT.cols_hide - GT.cols_unhide + - title: Adding rows + desc: > + The [`grand_summary_rows()`](`great_tables.GT.grand_summary_rows`) function will add rows to + summarize data in your table, such as totals or averages. + contents: + - GT.grand_summary_rows + - title: Location Targeting and Styling Classes desc: > Location targeting is a powerful feature of **Great Tables**. It allows for the precise @@ -182,8 +189,10 @@ quartodoc: - loc.column_header - loc.spanner_labels - loc.column_labels + - loc.grand_summary_stub - loc.stub - loc.row_groups + - loc.grand_summary - loc.body - loc.footer - loc.source_notes diff --git a/docs/get-started/loc-selection.qmd b/docs/get-started/loc-selection.qmd index 6222b54db..e1ead79c2 100644 --- a/docs/get-started/loc-selection.qmd +++ b/docs/get-started/loc-selection.qmd @@ -23,7 +23,11 @@ data = [ ["", "loc.column_labels()", "columns"], ["row stub", "loc.stub()", "rows"], ["", "loc.row_groups()", "rows"], + # ["", "loc.summary_stub()", "rows"], + ["", "loc.grand_summary_stub()", "rows"], ["table body", "loc.body()", "columns and rows"], + # ["", "loc.summary_rows()", "columns and rows"], + ["", "loc.grand_summary()", "columns and rows"], ["footer", "loc.footer()", "composite"], ["", "loc.source_notes()", ""], ] diff --git a/docs/get-started/table-theme-options.qmd b/docs/get-started/table-theme-options.qmd index ab7474444..e377546a2 100644 --- 
a/docs/get-started/table-theme-options.qmd +++ b/docs/get-started/table-theme-options.qmd @@ -25,6 +25,7 @@ gt_ex = ( .tab_header("THE HEADING", "(a subtitle)") .tab_stubhead("THE STUBHEAD") .tab_source_note("THE SOURCE NOTE") + .grand_summary_rows(fns={"GRAND SUMMARY ROW": lambda df: df.sum(numeric_only=True)}) ) gt_ex @@ -48,6 +49,7 @@ The code below illustrates the table parts `~~GT.tab_options()` can target, by s row_group_background_color="lightyellow", stub_background_color="lightgreen", source_notes_background_color="#f1e2af", + grand_summary_row_background_color="lightpink", ) ) ``` diff --git a/docs/get-started/targeted-styles.qmd b/docs/get-started/targeted-styles.qmd index eb3cf6239..a594beba0 100644 --- a/docs/get-started/targeted-styles.qmd +++ b/docs/get-started/targeted-styles.qmd @@ -18,7 +18,7 @@ Below is a big example that shows all possible `loc` specifiers being used. ```{python} from great_tables import GT, exibble, loc, style -# https://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12 +# https://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12 and grey brewer_colors = [ "#a6cee3", "#1f78b4", @@ -32,6 +32,7 @@ brewer_colors = [ "#6a3d9a", "#ffff99", "#b15928", + "#808080", ] c = iter(brewer_colors) @@ -43,6 +44,7 @@ gt = ( .tab_source_note("yo") .tab_spanner("spanner", ["char", "fctr"]) .tab_stubhead("stubhead") + .grand_summary_rows(fns={"Total": lambda x: x.sum(numeric_only=True)}) ) ( @@ -64,6 +66,9 @@ gt = ( .tab_style(style.borders(weight="3px"), loc.stub(rows=1)) .tab_style(style.fill(next(c)), loc.stub()) .tab_style(style.fill(next(c)), loc.stubhead()) + # Summary Rows -------------- + .tab_style(style.fill(next(c)), loc.grand_summary()) + .tab_style(style.fill(next(c)), loc.grand_summary_stub()) ) ``` @@ -129,3 +134,17 @@ gt.tab_style(style.fill("yellow"), loc.body()) ```{python} gt.tab_style(style.fill("yellow"), loc.stubhead()) ``` + +## Grand Summary Rows + +```{python} +( + gt.tab_style( + style.fill("yellow"), + 
loc.grand_summary_stub(), + ).tab_style( + style.fill("lightblue"), + loc.grand_summary(), + ) +) +``` diff --git a/docs/get-started/technical-notes.qmd b/docs/get-started/technical-notes.qmd new file mode 100644 index 000000000..163c16fac --- /dev/null +++ b/docs/get-started/technical-notes.qmd @@ -0,0 +1,71 @@ +--- +title: Technical notes +jupyter: python3 +--- + +This document holds technical notes on how the dataclasses behind GT interact. + +```{python} +import polars as pl +from great_tables import GT, exibble + +``` + +## Boxhead, Stub, and Summary Rows + +```{python} + +lil_ex = pl.from_pandas(exibble).select("num", "row", "group").head(6) + +gt = GT(lil_ex).grand_summary_rows(fns={"Total": pl.col("num").sum()}) +``` + +:::{.grid} + +:::{.g-col-4} + +### Placeholder stub + +```{python} +(gt + + +) +``` + +::: + +:::{.g-col-4} + +### Rowname stub + +```{python} +(gt + .tab_stub(rowname_col="row") + +) +``` + +::: + +:::{.g-col-4} + +### Group stub + +```{python} +(gt + .tab_stub(groupname_col="group") + .tab_options(row_group_as_column=True) +) +``` + +::: + +::: + +One more case not shown above: when both a rowname and a groupname column are used in the stub, the +summary spans both of those columns: + +```{python} +(gt.tab_stub(rowname_col="row", groupname_col="group").tab_options(row_group_as_column=True)) +``` diff --git a/great_tables/_gt_data.py b/great_tables/_gt_data.py index 1456d9454..5e55dbe19 100644 --- a/great_tables/_gt_data.py +++ b/great_tables/_gt_data.py @@ -2,7 +2,7 @@ import copy import re -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field, replace from enum import Enum, auto from itertools import chain, product @@ -75,6 +75,8 @@ class GTData: _spanners: Spanners _heading: Heading _stubhead: Stubhead + _summary_rows: SummaryRows + _summary_rows_grand: SummaryRows _source_notes: SourceNotes _footnotes: Footnotes _styles: Styles @@ -122,6 +124,8 @@ def from_data( 
_spanners=Spanners([]), _heading=Heading(), _stubhead=None, + _summary_rows=SummaryRows(), + _summary_rows_grand=SummaryRows(_is_grand_summary=True), _source_notes=[], _footnotes=[], _styles=[], @@ -220,6 +224,11 @@ class ColInfoTypeEnum(Enum): row_group = auto() hidden = auto() + # A placeholder column created when there is no table data in the stub, + # but summary rows need to create one for their labels. e.g. "mean" to indicate + # the row is mean summaries + summary_placeholder = auto() + @dataclass(frozen=True) class ColInfo: @@ -249,7 +258,11 @@ def visible(self) -> bool: @property def is_stub(self) -> bool: - return self.type in (ColInfoTypeEnum.stub, ColInfoTypeEnum.row_group) + return self.type in ( + ColInfoTypeEnum.stub, + ColInfoTypeEnum.row_group, + ColInfoTypeEnum.summary_placeholder, + ) @property def defaulted_align(self) -> str: @@ -510,10 +523,12 @@ def _get_number_of_visible_data_columns(self) -> int: # Obtain the number of visible columns in the built table; this should # account for the size of the stub in the final, built table - def _get_effective_number_of_columns(self, stub: Stub, options: Options) -> int: + def _get_effective_number_of_columns( + self, stub: Stub, has_summary_rows: bool, options: Options + ) -> int: n_data_cols = self._get_number_of_visible_data_columns() - stub_layout = stub._get_stub_layout(options=options) + stub_layout = stub._get_stub_layout(has_summary_rows=has_summary_rows, options=options) # Once the stub is defined in the package, we need to account # for the width of the stub at build time to fully obtain the number # of visible columns in the built table @@ -674,7 +689,7 @@ def _stub_group_names_has_column(self, options: Options) -> bool: return row_group_as_column - def _get_stub_layout(self, options: Options) -> list[str]: + def _get_stub_layout(self, has_summary_rows: bool, options: Options) -> list[str]: # Determine which stub components are potentially present as columns stub_rownames_is_column = "row_id" 
in self._get_stub_components() stub_groupnames_is_column = self._stub_group_names_has_column(options=options) @@ -684,14 +699,12 @@ def _get_stub_layout(self, options: Options) -> list[str]: # Resolve the layout of the stub (i.e., the roles of columns if present) if n_stub_cols == 0: - # TODO: If summary rows are present, we will use the `rowname` column - # # for the summary row labels - # if _summary_exists(data=data): - # stub_layout = ["rowname"] - # else: - # stub_layout = [] - - stub_layout = [] + # If summary rows are present, we will use the `rowname` column + # for the summary row labels + if has_summary_rows: + stub_layout = ["rowname"] + else: + stub_layout = [] else: stub_layout = [ @@ -719,7 +732,7 @@ class GroupRowInfo: indices: list[int] = field(default_factory=list) # row_start: int | None = None # row_end: int | None = None - has_summary_rows: bool = False + # has_summary_rows: bool = False # TODO: remove summary_row_side: str | None = None def defaulted_label(self) -> str: @@ -972,6 +985,162 @@ def __init__(self, func: FormatFns, cols: list[str], rows: list[int]): Formats = list +# Summary Rows --- + +# This can't conflict with actual group ids since we have a +# separate data structure for grand summary row infos + + +@dataclass(frozen=True) +class SummaryRowInfo: + """Information about a single summary row""" + + id: str + label: str # For now, label and id are identical + # The motivation for values as a dict is to ensure cols_* functions don't have to consider + # the implications on existing SummaryRowInfo objects + values: dict[str, Any] # TODO: consider datatype, series? + side: Literal["top", "bottom"] # TODO: consider enum + + +# TODO: refactor into a collection/dataclass wrapping the list part +# put most of the methods for filtering, adding, replacing there. +# Make immutable to avoid potential bugs. 
+class SummaryRows(Mapping[str, list[SummaryRowInfo]]): + """A mapping from group id to that group's list of summary rows + + The following invariants should always hold for summary rows: + - The id is also the label (often the same as the function name) + - There is at most 1 row for each group and id pairing + - If a summary row is added and no row exists for that group and id, add it + - If a summary row is added and a row exists for that group and id pairing, + then overwrite every cell (in values) that the new version supplies + """ + + _d: dict[str, list[SummaryRowInfo]] + _is_grand_summary: bool + + LIST_CLS = list + GRAND_SUMMARY_KEY = "grand" + + def __init__( + self, + entries: dict[str, list[SummaryRowInfo]] | None = None, + _is_grand_summary: bool = False, + ): + if entries is None: + self._d = {} + else: + self._d = entries + self._is_grand_summary = _is_grand_summary + + def __bool__(self) -> bool: + """Return True if there are any summary rows, False otherwise.""" + return len(self._d) > 0 + + def __getitem__(self, key: str | None) -> list[SummaryRowInfo]: + if self._is_grand_summary: + key = SummaryRows.GRAND_SUMMARY_KEY + + if not key: + raise KeyError("Summary row group key must not be None for group summary rows.") + + if key not in self._d: + raise KeyError(f"Group '{key}' not found in summary rows.") + + return self.LIST_CLS(self._d[key]) + + def define(self, **kwargs: list[SummaryRowInfo]) -> Self: + """Define multiple summary row groups at once, replacing any existing groups.""" + + new_d = dict(self._d) + for group_id, summary_rows in kwargs.items(): + new_d[group_id] = summary_rows + + return self.__class__(new_d, _is_grand_summary=self._is_grand_summary) + + def add_summary_row(self, summary_row: SummaryRowInfo, group_id: str | None = None) -> Self: + """Add a summary row following the merging rules in the class docstring.""" + + # Grand summary rows always live under GRAND_SUMMARY_KEY, so group_id is ignored there + # (as in __getitem__ and get_summary_rows); regular summary rows need a 
group_id. + if self._is_grand_summary: + group_id = SummaryRows.GRAND_SUMMARY_KEY + elif group_id is None: + raise TypeError("group_id must be provided for group summary rows.") + + existing_group = self.get(group_id) + + if not existing_group: + return self.define(**{group_id: [summary_row]}) + + else: + existing_index = next( + (ii for ii, crnt_row in enumerate(existing_group) if crnt_row.id == summary_row.id), + None, + ) + + new_rows = self.LIST_CLS(existing_group) + + if existing_index is None: + # No existing row for this group and id, add it + new_rows.append(summary_row) + else: + # Replace existing row, merging in the values from the new version + existing_row = new_rows[existing_index] + + # Start with existing values + merged_values = existing_row.values.copy() + + # Overwrite with every value supplied by the new row + for key, new_value in summary_row.values.items(): + merged_values[key] = new_value + + # Create merged row with new row's properties but merged values + merged_row = SummaryRowInfo( + id=summary_row.id, + label=summary_row.label, + values=merged_values, + # Setting this to existing row instead of summary_row means original + # side is fixed by whatever side is first assigned to this row + side=existing_row.side, + ) + + new_rows[existing_index] = merged_row + + return self.define(**{group_id: new_rows}) + + def get_summary_rows( + self, group_id: str | None = None, side: str | None = None + ) -> list[SummaryRowInfo]: + """Get list of summary rows for that group. If side is None, do not filter by side. 
+ Sorts result with 'top' side first, then 'bottom'.""" + + result: list[SummaryRowInfo] = [] + + if self._is_grand_summary: + group_id = SummaryRows.GRAND_SUMMARY_KEY + elif group_id is None: + raise TypeError("group_id must be provided for group summary rows.") + + summary_row_group = self.get(group_id) + + if summary_row_group: + for summary_row in summary_row_group: + if side is None or summary_row.side == side: + result.append(summary_row) + + # Sort: 'top' first, then 'bottom' + result.sort(key=lambda r: 0 if r.side == "top" else 1) # TODO: modify if enum for side + return result + + def __iter__(self): + return iter(self._d) + + def __len__(self): + return len(self._d) + + # Options ---- default_fonts_list = [ @@ -1130,25 +1299,25 @@ class Options: # summary_row_border_style: OptionsInfo = OptionsInfo(True, "summary_row", "value", "solid") # summary_row_border_width: OptionsInfo = OptionsInfo(True, "summary_row", "px", "2px") # summary_row_border_color: OptionsInfo = OptionsInfo(True, "summary_row", "value", "#D3D3D3") - # grand_summary_row_padding: OptionsInfo = OptionsInfo(True, "grand_summary_row", "px", "8px") - # grand_summary_row_padding_horizontal: OptionsInfo = OptionsInfo( - # True, "grand_summary_row", "px", "5px" - # ) - # grand_summary_row_background_color: OptionsInfo = OptionsInfo( - # True, "grand_summary_row", "value", None - # ) - # grand_summary_row_text_transform: OptionsInfo = OptionsInfo( - # True, "grand_summary_row", "value", "inherit" - # ) - # grand_summary_row_border_style: OptionsInfo = OptionsInfo( - # True, "grand_summary_row", "value", "double" - # ) - # grand_summary_row_border_width: OptionsInfo = OptionsInfo( - # True, "grand_summary_row", "px", "6px" - # ) - # grand_summary_row_border_color: OptionsInfo = OptionsInfo( - # True, "grand_summary_row", "value", "#D3D3D3" - # ) + grand_summary_row_padding: OptionsInfo = OptionsInfo(True, "grand_summary_row", "px", "8px") + grand_summary_row_padding_horizontal: 
OptionsInfo = OptionsInfo( + True, "grand_summary_row", "px", "5px" + ) + grand_summary_row_background_color: OptionsInfo = OptionsInfo( + True, "grand_summary_row", "value", None + ) + grand_summary_row_text_transform: OptionsInfo = OptionsInfo( + True, "grand_summary_row", "value", "inherit" + ) + grand_summary_row_border_style: OptionsInfo = OptionsInfo( + True, "grand_summary_row", "value", "double" + ) + grand_summary_row_border_width: OptionsInfo = OptionsInfo( + True, "grand_summary_row", "px", "6px" + ) + grand_summary_row_border_color: OptionsInfo = OptionsInfo( + True, "grand_summary_row", "value", "#D3D3D3" + ) # footnotes_font_size: OptionsInfo = OptionsInfo(True, "footnotes", "px", "90%") # footnotes_padding: OptionsInfo = OptionsInfo(True, "footnotes", "px", "4px") # footnotes_padding_horizontal: OptionsInfo = OptionsInfo(True, "footnotes", "px", "5px") diff --git a/great_tables/_locations.py b/great_tables/_locations.py index 246966303..8e91ef449 100644 --- a/great_tables/_locations.py +++ b/great_tables/_locations.py @@ -480,8 +480,66 @@ class LocRowGroups(Loc): rows: RowSelectExpr = None +# @dataclass +# class LocSummaryStub(Loc): +# rows: RowSelectExpr = None + + @dataclass -class LocSummaryLabel(Loc): +class LocGrandSummaryStub(Loc): + """Target the grand summary stub. + + With `loc.grand_summary_stub()` we can target the cells containing the grand summary row labels, + which reside in the table stub. This is useful for applying custom styling with the + [`tab_style()`](`great_tables.GT.tab_style`) method. That method has a `locations=` argument and + this class should be used there to perform the targeting. + + Parameters + ---------- + rows + The rows to target within the grand summary stub. Can either be a single row name or a + series of row names provided in a list. If no rows are specified, all grand summary rows + are targeted. 
Note that if rows are targeted by index, top and bottom grand summary rows + are indexed as one combined list starting with the top rows. + + Returns + ------- + LocGrandSummaryStub + A LocGrandSummaryStub object, which is used for a `locations=` argument if specifying the + table's grand summary rows' labels. + + Examples + -------- + Let's use a subset of the `gtcars` dataset in a new table. We will style the entire table grand + summary stub (the row labels) by using `locations=loc.grand_summary_stub()` within + [`tab_style()`](`great_tables.GT.tab_style`). + + ```{python} + from great_tables import GT, style, loc, vals + from great_tables.data import gtcars + + ( + GT( + gtcars[["mfr", "model", "hp", "trq", "mpg_c"]].head(6), + rowname_col="model", + ) + .fmt_integer(columns=["hp", "trq", "mpg_c"]) + .grand_summary_rows( + fns={ + "Min": lambda df: df.min(numeric_only=True), + "Max": lambda x: x.max(numeric_only=True), + }, + side="top", + fmt=vals.fmt_integer, + ) + .tab_style( + style=[style.text(color="crimson", weight="bold"), style.fill(color="lightgray")], + locations=loc.grand_summary_stub(), + ) + ) + ``` + """ + rows: RowSelectExpr = None @@ -550,11 +608,75 @@ class LocBody(Loc): mask: PlExpr | None = None +# @dataclass +# class LocSummary(Loc): +# # TODO: these can be tidyselectors +# columns: SelectExpr = None +# rows: RowSelectExpr = None +# mask: PlExpr | None = None + + @dataclass -class LocSummary(Loc): +class LocGrandSummary(Loc): + """Target the data cells in grand summary rows. + + With `loc.grand_summary()` we can target the cells containing the grand summary data. + This is useful for applying custom styling with the [`tab_style()`](`great_tables.GT.tab_style`) + method. That method has a `locations=` argument and this class should be used there to perform + the targeting. + + Parameters + ---------- + columns + The columns to target. Can either be a single column name or a series of column names + provided in a list. 
+ rows + The rows to target. Can either be a single row name or a series of row names provided in a + list. Note that if rows are targeted by index, top and bottom grand summary rows are indexed + as one combined list starting with the top rows. + + Returns + ------- + LocGrandSummary + A LocGrandSummary object, which is used for a `locations=` argument if specifying the + table's grand summary rows. + + Examples + -------- + Let's use a subset of the `gtcars` dataset in a new table. We will style all of the grand + summary cells by using `locations=loc.grand_summary()` within + [`tab_style()`](`great_tables.GT.tab_style`). + + ```{python} + from great_tables import GT, style, loc, vals + from great_tables.data import gtcars + + ( + GT( + gtcars[["mfr", "model", "hp", "trq", "mpg_c"]].head(6), + rowname_col="model", + ) + .fmt_integer(columns=["hp", "trq", "mpg_c"]) + .grand_summary_rows( + fns={ + "Min": lambda df: df.min(numeric_only=True), + "Max": lambda x: x.max(numeric_only=True), + }, + side="top", + fmt=vals.fmt_integer, + ) + .tab_style( + style=[style.text(color="crimson", weight="bold"), style.fill(color="lightgray")], + locations=loc.grand_summary(), + ) + ) + ``` + """ + # TODO: these can be tidyselectors columns: SelectExpr = None rows: RowSelectExpr = None + mask: PlExpr | None = None @dataclass @@ -910,6 +1032,22 @@ def _(loc: LocRowGroups, data: GTData) -> set[str]: return group_pos +@resolve.register +def _(loc: LocGrandSummaryStub, data: GTData) -> set[int]: + # Select just grand summary rows + grand_summary_rows = data._summary_rows_grand.get_summary_rows() + grand_summary_rows_ids = [row.id for row in grand_summary_rows] + + rows = resolve_rows_i(data=grand_summary_rows_ids, expr=loc.rows) + + cell_pos = set(row[1] for row in rows) + return cell_pos + + +# @resolve.register(LocSummaryStub) +# Also target by groupname in styleinfo + + @resolve.register def _(loc: LocStub, data: GTData) -> set[int]: # TODO: what are the rules for matching row 
groups? @@ -918,6 +1056,35 @@ def _(loc: LocStub, data: GTData) -> set[int]: return cell_pos +@resolve.register +def _(loc: LocGrandSummary, data: GTData) -> list[CellPos]: + if (loc.columns is not None or loc.rows is not None) and loc.mask is not None: + raise ValueError( + "Cannot specify the `mask` argument along with `columns` or `rows` in `loc.grand_summary()`." + ) + + grand_summary_rows = data._summary_rows_grand.get_summary_rows() + grand_summary_rows_ids = [row.id for row in grand_summary_rows] + + if loc.mask is None: + rows = resolve_rows_i(data=grand_summary_rows_ids, expr=loc.rows) + cols = resolve_cols_i(data=data, expr=loc.columns) + # TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same + # thing multiple times + cell_pos = [ + CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows) + ] + else: + # I am not sure how to approach this, since GTData._summary_rows is not a frame + # We could convert to a frame, but I don't think that's a simple step + raise NotImplementedError("Masked selection is not yet implemented for Grand Summary Rows") + return cell_pos + + +# @resolve.register(LocSummary) +# Also target by groupname in styleinfo + + @resolve.register def _(loc: LocBody, data: GTData) -> list[CellPos]: if (loc.columns is not None or loc.rows is not None) and loc.mask is not None: @@ -953,9 +1120,11 @@ def _(loc: LocBody, data: GTData) -> list[CellPos]: # LocStub # LocRowGroupLabel # LocRowLabel -# LocSummaryLabel +# LocSummaryStub +# LocGrandSummaryStub # LocBody # LocSummary +# LocGrandSummary # LocFooter # LocFootnotes # LocSourceNotes @@ -1039,8 +1208,10 @@ def _(loc: LocRowGroups, data: GTData, style: list[CellStyle]) -> GTData: ) -@set_style.register -def _(loc: LocStub, data: GTData, style: list[CellStyle]) -> GTData: +# @set_style.register(LocSummaryStub) +@set_style.register(LocStub) +@set_style.register(LocGrandSummaryStub) +def _(loc: (LocStub | LocGrandSummaryStub), data: GTData, 
style: list[CellStyle]) -> GTData: # validate ---- for entry in style: entry._raise_if_requires_data(loc) @@ -1051,8 +1222,10 @@ def _(loc: LocStub, data: GTData, style: list[CellStyle]) -> GTData: return data._replace(_styles=data._styles + new_styles) -@set_style.register -def _(loc: LocBody, data: GTData, style: list[CellStyle]) -> GTData: +# @set_style.register(LocSummary) +@set_style.register(LocBody) +@set_style.register(LocGrandSummary) +def _(loc: (LocBody | LocGrandSummary), data: GTData, style: list[CellStyle]) -> GTData: positions: list[CellPos] = resolve(loc, data) # evaluate any column expressions in styles diff --git a/great_tables/_modify_rows.py b/great_tables/_modify_rows.py index 4ee30ae74..f9fd39069 100644 --- a/great_tables/_modify_rows.py +++ b/great_tables/_modify_rows.py @@ -1,8 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -from ._gt_data import Locale, RowGroups, Styles +from typing import TYPE_CHECKING, Any, Callable, Literal + +from ._gt_data import ( + FormatFn, + GTData, + Locale, + RowGroups, + Styles, + SummaryRowInfo, +) +from ._tbl_data import ( + PlExpr, + SelectExpr, + TblData, + eval_aggregate, +) if TYPE_CHECKING: from ._types import GTSelf @@ -16,8 +29,8 @@ def row_group_order(self: GTSelf, groups: RowGroups) -> GTSelf: def _remove_from_body_styles(styles: Styles, column: str) -> Styles: # TODO: refactor - from ._utils_render_html import _is_loc from ._locations import LocBody + from ._utils_render_html import _is_loc new_styles = [ info for info in styles if not (_is_loc(info.locname, LocBody) and info.colname == column) @@ -178,3 +191,179 @@ def with_id(self: GTSelf, id: str | None = None) -> GTSelf: ``` """ return self._replace(_options=self._options._set_option_value("table_id", id)) + + +def grand_summary_rows( + self: GTSelf, + *, + fns: dict[str, PlExpr] | dict[str, Callable[[TblData], Any]], + fmt: FormatFn | None = None, + columns: SelectExpr = None, + side: Literal["bottom", "top"] = 
"bottom", + missing_text: str = "---", +) -> GTSelf: + """Add grand summary rows to the table. + + Add grand summary rows by using the table data and any suitable aggregation functions. With + grand summary rows, all of the available data in the gt table is incorporated (regardless of + whether some of the data are part of row groups). Multiple grand summary rows can be added via + expressions given to fns. You can selectively format the values in the resulting grand summary + cells by use of formatting expressions from the `vals.fmt_*` class of functions. + + Note that currently all arguments are keyword-only, since the final positions may change. + + Parameters + ---------- + fns + A dictionary mapping row labels to aggregation expressions. Can be either Polars + expressions or callable functions that take the entire DataFrame and return aggregated + results. Each key becomes the label for a grand summary row. + fmt + A formatting function from the `vals.fmt_*` family (e.g., `vals.fmt_number`, + `vals.fmt_currency`) to apply to the summary row values. If `None`, no formatting + is applied. + columns + Currently, this function does not support selection by columns. If you would like to choose + which columns to summarize, you can select columns within the functions given to `fns=`. + See examples below for more explicit cases. + side + Should the grand summary rows be placed at the `"bottom"` (the default) or the `"top"` of + the table? + missing_text + The text to be used in summary cells with no data outputs. + + Returns + ------- + GT + The GT object is returned. This is the same object that the method is called on so that we + can facilitate method chaining. + + Examples + -------- + Let's use a subset of the `sp500` dataset to create a table with grand summary rows. We'll + calculate min, max, and mean values for the numeric columns. 
Notice the different + approaches to selecting columns to apply the aggregations to: we can use polars selectors + or select the columns directly. + + ```{python} + import polars as pl + import polars.selectors as cs + from great_tables import GT, vals, style, loc + from great_tables.data import sp500 + + sp500_mini = ( + pl.from_pandas(sp500) + .slice(0, 7) + .drop(["volume", "adj_close"]) + ) + + ( + GT(sp500_mini, rowname_col="date") + .grand_summary_rows( + fns={ + "Minimum": pl.min("open", "high", "low", "close"), + "Maximum": pl.col("open", "high", "low", "close").max(), + "Average": cs.numeric().mean(), + }, + fmt=vals.fmt_currency, + ) + .tab_style( + style=[ + style.text(color="crimson"), + style.fill(color="lightgray"), + ], + locations=loc.grand_summary(), + ) + ) + ``` + + We can also use custom callable functions to create more complex summary calculations. + Notice here that grand summary rows can be placed at the top of the table and formatted + as integers, by passing a formatter from the `vals.fmt_*` class of functions. + + ```{python} + from great_tables import GT, style, loc, vals + from great_tables.data import gtcars + + def pd_median(df): + return df.median(numeric_only=True) + + + ( + GT( + gtcars[["mfr", "model", "hp", "trq", "mpg_c"]].head(6), + rowname_col="model", + ) + .fmt_integer(columns=["hp", "trq", "mpg_c"]) + .grand_summary_rows( + fns={ + "Min": lambda df: df.min(numeric_only=True), + "Max": lambda df: df.max(numeric_only=True), + "Median": pd_median, + }, + side="top", + fmt=vals.fmt_integer, + ) + .tab_style( + style=[style.text(color="crimson", weight="bold"), style.fill(color="lightgray")], + locations=loc.grand_summary_stub(), + ) + ) + ``` + + """ + if columns is not None: + raise NotImplementedError( + "Currently, grand_summary_rows() does not support column selection." 
+ ) + + # summary_col_names = resolve_cols_c(data=self, expr=columns) + + new_summary = self._summary_rows_grand + for label, fn in fns.items(): + row_values_dict = _calculate_summary_row(self, fn, fmt, missing_text) + + summary_row_info = SummaryRowInfo( + id=label, + label=label, + values=row_values_dict, + side=side, + ) + + new_summary = new_summary.add_summary_row(summary_row_info) + + return self._replace(_summary_rows_grand=new_summary) + + +def _calculate_summary_row( + data: GTData, + fn: PlExpr | Callable[[TblData], Any], + fmt: FormatFn | None, + # summary_col_names: list[str], + missing_text: str, +) -> dict[str, Any]: + """Calculate a summary row using eval_aggregate.""" + original_columns = data._boxhead._get_columns() + summary_row = {} + + # Use eval_aggregate to apply the function/expression to the data + result_df = eval_aggregate(data._tbl_data, fn) + + # Extract results for each column + for col in original_columns: + if col in result_df: + res = result_df[col] + + if fmt is not None: + formatted = fmt([res]) + res = formatted[0] + + summary_row[col] = res + else: + summary_row[col] = missing_text + + return summary_row + + +# TODO: delegate to group by agg instead (group_by for summary row case) +# TODO: validates after diff --git a/great_tables/_options.py b/great_tables/_options.py index f4ab0fc8a..7aab82624 100644 --- a/great_tables/_options.py +++ b/great_tables/_options.py @@ -128,13 +128,13 @@ def tab_options( # summary_row_border_style: str | None = None, # summary_row_border_width: str | None = None, # summary_row_border_color: str | None = None, - # grand_summary_row_background_color: str | None = None, - # grand_summary_row_text_transform: str | None = None, - # grand_summary_row_padding: str | None = None, - # grand_summary_row_padding_horizontal: str | None = None, - # grand_summary_row_border_style: str | None = None, - # grand_summary_row_border_width: str | None = None, - # grand_summary_row_border_color: str | None = None, + 
grand_summary_row_background_color: str | None = None, + grand_summary_row_text_transform: str | None = None, + grand_summary_row_padding: str | None = None, + grand_summary_row_padding_horizontal: str | None = None, + grand_summary_row_border_style: str | None = None, + grand_summary_row_border_width: str | None = None, + grand_summary_row_border_color: str | None = None, # footnotes_background_color: str | None = None, # footnotes_font_size: str | None = None, # footnotes_padding: str | None = None, @@ -1386,10 +1386,8 @@ def opt_stylize( # Omit keys that are not needed for the `tab_options()` method # TODO: the omitted keys are for future use when: # (1) summary rows are implemented - # (2) grand summary rows are implemented omit_keys = { "summary_row_background_color", - "grand_summary_row_background_color", } def dict_omit_keys(dict: dict[str, str], omit_keys: set[str]) -> dict[str, str]: @@ -1440,6 +1438,8 @@ class StyleMapper: data_vlines_style: str data_vlines_color: str row_striping_background_color: str + grand_summary_row_background_color: str + # summary_row_background_color: str mappings: ClassVar[dict[str, list[str]]] = { "table_hlines_color": ["table_border_top_color", "table_border_bottom_color"], @@ -1461,6 +1461,8 @@ class StyleMapper: "data_vlines_style": ["table_body_vlines_style"], "data_vlines_color": ["table_body_vlines_color"], "row_striping_background_color": ["row_striping_background_color"], + "grand_summary_row_background_color": ["grand_summary_row_background_color"], + # "summary_row_background_color": ["summary_row_background_color"], } def map_entry(self, name: str) -> dict[str, list[str]]: diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 43798e364..0d3418209 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -874,3 +874,65 @@ def _(ser: PyArrowChunkedArray, name: Optional[str] = None) -> PyArrowTable: import pyarrow as pa return pa.table({name: ser}) + + +# eval_aggregate ---- + + 
@singledispatch
def eval_aggregate(df, expr) -> dict[str, Any]:
    """Aggregate a data frame down to a single row, returned as a dictionary.

    This is the generic entry point; the actual work is done by the implementations registered
    for each supported frame type. Each of them evaluates `expr` against `df` and converts the
    one-row result into a plain dictionary.

    Parameters
    ----------
    df
        The input data (DataFrame). Its type selects the implementation.
    expr
        The aggregation to evaluate (a Polars expression or a callable, depending on the
        frame type).

    Returns
    -------
    dict[str, Any]
        A dictionary mapping column names to their aggregated values.

    Raises
    ------
    NotImplementedError
        If no implementation is registered for the type of `df`.
    """
    raise NotImplementedError(f"eval_aggregate not implemented for type: {type(df)}")


@eval_aggregate.register
def _(df: PdDataFrame, expr: Callable[[PdDataFrame], PdSeries]) -> dict[str, Any]:
    # pandas: the callable must reduce the frame to a Series (index -> aggregated value)
    aggregated = expr(df)

    if isinstance(aggregated, PdSeries):
        return aggregated.to_dict()

    raise ValueError(f"Result must be a pandas Series. Received {type(aggregated)}")


@eval_aggregate.register
def _(df: PlDataFrame, expr: PlExpr) -> dict[str, Any]:
    # polars: the expression is evaluated with select() and must collapse to one row
    selected = df.select(expr)
    n_rows = len(selected)

    if n_rows == 1:
        return selected.to_dicts()[0]

    raise ValueError(
        f"Expression must produce exactly 1 row (aggregation). Got {n_rows} rows."
    )


@eval_aggregate.register
def _(df: PyArrowTable, expr: Callable[[PyArrowTable], PyArrowTable]) -> dict[str, Any]:
    # pyarrow: the callable must hand back a one-row Table
    aggregated = expr(df)

    if not isinstance(aggregated, PyArrowTable):
        raise ValueError(f"Result must be a PyArrow Table. Received {type(aggregated)}")

    n_rows = aggregated.num_rows
    if n_rows != 1:
        raise ValueError(
            f"Expression must produce exactly 1 row (aggregation). Got {n_rows} rows."
        )

    # Pull the single value out of every column and convert it to a Python object
    row: dict[str, Any] = {}
    for name in aggregated.column_names:
        row[name] = aggregated.column(name)[0].as_py()

    return row
data._stub._get_stub_layout( + has_summary_rows=has_summary_rows, options=data._options + ) # Determine the finalized number of spanner rows spanner_row_count = _get_spanners_matrix_height(data=data, omit_columns_row=True) @@ -426,27 +444,43 @@ def create_body_component_h(data: GTData) -> str: # Filter list of StyleInfo to only those that apply to the stub styles_row_group_label = [x for x in data._styles if _is_loc(x.locname, loc.LocRowGroups)] styles_row_label = [x for x in data._styles if _is_loc(x.locname, loc.LocStub)] - # styles_summary_label = [x for x in data._styles if _is_loc(x.locname, loc.LocSummaryLabel)] + # styles_summary_label = [x for x in data._styles if _is_loc(x.locname, loc.LocSummaryStub)] + styles_grand_summary_label = [ + x for x in data._styles if _is_loc(x.locname, loc.LocGrandSummaryStub) + ] # Filter list of StyleInfo to only those that apply to the body styles_cells = [x for x in data._styles if _is_loc(x.locname, loc.LocBody)] # styles_body = [x for x in data._styles if _is_loc(x.locname, loc.LocBody2)] # styles_summary = [x for x in data._styles if _is_loc(x.locname, loc.LocSummary)] + styles_grand_summary = [x for x in data._styles if _is_loc(x.locname, loc.LocGrandSummary)] # Get the default column vars column_vars = data._boxhead._get_default_columns() row_stub_var = data._boxhead._get_stub_column() - stub_layout = data._stub._get_stub_layout(options=data._options) + has_summary_rows = bool(data._summary_rows or data._summary_rows_grand) + stub_layout = data._stub._get_stub_layout( + has_summary_rows=has_summary_rows, options=data._options + ) has_row_stub_column = "rowname" in stub_layout has_group_stub_column = "group_label" in stub_layout has_groups = data._stub.group_ids is not None and len(data._stub.group_ids) > 0 # If there is a stub, then prepend that to the `column_vars` list - if row_stub_var is not None: - column_vars = [row_stub_var] + column_vars + if has_row_stub_column: + # There is already a column assigned to the 
rownames + if row_stub_var: + column_vars = [row_stub_var] + column_vars + # Else we have summary rows but no stub yet + else: + # TODO: this naming is not ideal + summary_row_stub_var = ColInfo( + "__do_not_use__", ColInfoTypeEnum.summary_placeholder, column_align="left" + ) + column_vars = [summary_row_stub_var] + column_vars # Is the stub to be striped? table_stub_striped = data._options.row_striping_include_stub.value @@ -456,6 +490,24 @@ def create_body_component_h(data: GTData) -> str: body_rows: list[str] = [] + # Add grand summary rows at top + top_g_summary_rows = data._summary_rows_grand.get_summary_rows(side="top") + for i, summary_row in enumerate(top_g_summary_rows): + row_html = _create_row_component_h( + column_vars=column_vars, + row_stub_var=row_stub_var, # Should probably include group stub? + has_row_stub_column=has_row_stub_column, # Should probably include group stub? + has_group_stub_column=has_group_stub_column, # Add this parameter + apply_stub_striping=False, # No striping for summary rows + apply_body_striping=False, # No striping for summary rows + styles_cells=styles_grand_summary, + styles_labels=styles_grand_summary_label, + row_index=i, + summary_row=summary_row, + css_class="gt_last_grand_summary_row_top" if i == len(top_g_summary_rows) - 1 else None, + ) + body_rows.append(row_html) + # iterate over rows (ordered by groupings) prev_group_info = None @@ -468,7 +520,7 @@ def create_body_component_h(data: GTData) -> str: odd_j_row = j % 2 == 1 - body_cells: list[str] = [] + leading_cell = None # Create table row or label in the stub specifically for group (if applicable) if has_groups: @@ -487,15 +539,13 @@ def create_body_component_h(data: GTData) -> str: if has_group_stub_column: rowspan_value = len(group_info.indices) - body_cells.append( - f""" {group_label}""" - ) # Append a table row for the group heading else: colspan_value = data._boxhead._get_effective_number_of_columns( - stub=data._stub, options=data._options + 
stub=data._stub, has_summary_rows=has_summary_rows, options=data._options ) group_class = ( @@ -508,65 +558,158 @@ def create_body_component_h(data: GTData) -> str: body_rows.append(group_row) - # Create row cells - for colinfo in column_vars: - cell_content: Any = _get_cell(tbl_data, i, colinfo.var) - cell_str: str = str(cell_content) - - # Determine whether the current cell is the stub cell - if has_row_stub_column: - is_stub_cell = colinfo.var == row_stub_var.var - else: - is_stub_cell = False + # Create data row + row_html = _create_row_component_h( + column_vars=column_vars, + row_stub_var=row_stub_var, + has_row_stub_column=has_row_stub_column, + has_group_stub_column=has_group_stub_column, + leading_cell=leading_cell, + apply_stub_striping=table_stub_striped and odd_j_row, + apply_body_striping=table_body_striped and odd_j_row, + styles_cells=styles_cells, + styles_labels=styles_row_label, + row_index=i, + tbl_data=tbl_data, + ) + body_rows.append(row_html) - # Get alignment for the current column from the `col_alignment` list - # by using the `name` value to obtain the index of the alignment value - cell_alignment = colinfo.defaulted_align + prev_group_info = group_info - # Get the style attributes for the current cell by filtering the - # `styles_cells` list for the current row and column - _body_styles = [x for x in styles_cells if x.rownum == i and x.colname == colinfo.var] + ## after the last row in the group, we need to append the summary rows for the group + ## if this table has summary rows + + # Add grand summary rows at bottom + bottom_g_summary_rows = data._summary_rows_grand.get_summary_rows(side="bottom") + for i, summary_row in enumerate(bottom_g_summary_rows): + row_html = _create_row_component_h( + column_vars=column_vars, + row_stub_var=row_stub_var, + has_row_stub_column=has_row_stub_column, + has_group_stub_column=has_group_stub_column, # Add this parameter + apply_stub_striping=False, + apply_body_striping=False, + 
styles_cells=styles_grand_summary, + styles_labels=styles_grand_summary_label, + row_index=i + len(top_g_summary_rows), + summary_row=summary_row, + css_class="gt_first_grand_summary_row_bottom" if i == 0 else None, + ) + body_rows.append(row_html) - if is_stub_cell: - el_name = "th" + all_body_rows = "\n".join(body_rows) - classes = ["gt_row", "gt_left", "gt_stub"] + return f""" +{all_body_rows} +""" - _rowname_styles = [x for x in styles_row_label if x.rownum == i] - if table_stub_striped and odd_j_row: - classes.append("gt_striped") +def _create_row_component_h( + column_vars: list[ColInfo], + row_stub_var: ColInfo | None, + has_row_stub_column: bool, + has_group_stub_column: bool, + apply_stub_striping: bool, + apply_body_striping: bool, + styles_cells: list[StyleInfo], # Either styles_cells OR styles_grand_summary + styles_labels: list[StyleInfo], # Either styles_row_label OR styles_grand_summary_label + leading_cell: str | None = None, # For group label when row_group_as_column = True + row_index: int | None = None, + summary_row: SummaryRowInfo | None = None, # For summary rows + tbl_data: TblData | None = None, + css_class: str | None = None, +) -> str: + """Create a single table row (either data row or summary row)""" + + is_summary_row = summary_row is not None + body_cells: list[str] = [] + + if leading_cell: + body_cells.append(leading_cell) + + # Handle special cases for summary rows with group stub columns + if is_summary_row and has_group_stub_column: + if has_row_stub_column: + # Case 1: Both row_stub_column and group_stub_column + # Create a single cell that spans both columns for summary row label (id) + colspan = 2 + else: + # Case 2: Only group_stub_column, no row_stub_column + colspan = 1 - else: - el_name = "td" + cell_styles = _flatten_styles( + [x for x in styles_labels if x.rownum == row_index], wrap=True + ) - classes = ["gt_row", f"gt_{cell_alignment}"] + classes = ["gt_row", "gt_left", "gt_stub", "gt_grand_summary_row"] + if css_class: + 
classes.append(css_class) + classes_str = " ".join(classes) - _rowname_styles = [] + body_cells.append( + f""" {summary_row.id}""" + ) - if table_body_striped and odd_j_row: - classes.append("gt_striped") + # Skip the first column in column_vars since we've already handled the stub + column_vars_to_process = [column for column in column_vars if not column.is_stub] - # Ensure that `classes` becomes a space-separated string - classes = " ".join(classes) - cell_styles = _flatten_styles( - _body_styles + _rowname_styles, - wrap=True, - ) + else: + # Normal case: process all column_vars + column_vars_to_process = column_vars + + for colinfo in column_vars_to_process: + # Get cell content + if is_summary_row: + if colinfo == row_stub_var or colinfo.is_stub: + cell_content = summary_row.id + else: + cell_content = summary_row.values.get(colinfo.var) + elif colinfo.type == ColInfoTypeEnum.summary_placeholder: + # TODO: this row is technically a summary row, but is_summary_row is False here + cell_content = " " + else: + cell_content = _get_cell(tbl_data, row_index, colinfo.var) - body_cells.append( - f""" <{el_name}{cell_styles} class="{classes}">{cell_str}""" - ) + if css_class: + classes = [css_class] + else: + classes = [] - prev_group_info = group_info + cell_str = str(cell_content) + cell_alignment = colinfo.defaulted_align - body_rows.append(" \n" + "\n".join(body_cells) + "\n ") + # Get styles + _body_styles = [ + x for x in styles_cells if x.rownum == row_index and x.colname == colinfo.var + ] + _rowname_styles = ( + [x for x in styles_labels if x.rownum == row_index] if colinfo.is_stub else [] + ) - all_body_rows = "\n".join(body_rows) + # Build classes and element + if colinfo.is_stub: + el_name = "th" + classes += ["gt_row", "gt_left", "gt_stub"] + if is_summary_row: + classes.append("gt_grand_summary_row") + if apply_stub_striping: + classes.append("gt_striped") + else: + el_name = "td" + classes += ["gt_row", f"gt_{cell_alignment}"] + if is_summary_row: + 
classes.append("gt_grand_summary_row") + if apply_body_striping: + classes.append("gt_striped") + + classes_str = " ".join(classes) + cell_styles = _flatten_styles(_body_styles + _rowname_styles, wrap=True) + + body_cells.append( + f""" <{el_name}{cell_styles} class="{classes_str}">{cell_str}""" + ) - return f""" -{all_body_rows} -""" + return " \n" + "\n".join(body_cells) + "\n " def create_source_notes_component_h(data: GTData) -> str: @@ -586,8 +729,9 @@ def create_source_notes_component_h(data: GTData) -> str: # Get the effective number of columns, which is number of columns # that will finally be rendered accounting for the stub layout + has_summary_rows = bool(data._summary_rows or data._summary_rows_grand) n_cols_total = data._boxhead._get_effective_number_of_columns( - stub=data._stub, options=data._options + stub=data._stub, has_summary_rows=has_summary_rows, options=data._options ) # Handle the multiline source notes case (each note takes up one line) diff --git a/great_tables/_utils_render_latex.py b/great_tables/_utils_render_latex.py index eb8bef7b6..5b7806337 100644 --- a/great_tables/_utils_render_latex.py +++ b/great_tables/_utils_render_latex.py @@ -554,7 +554,10 @@ def _render_as_latex(data: GTData, use_longtable: bool = False, tbl_pos: str | N _not_implemented("Styles are not yet supported in LaTeX output.") # Get list representation of stub layout - stub_layout = data._stub._get_stub_layout(options=data._options) + has_summary_rows = bool(data._summary_rows or data._summary_rows_grand) + stub_layout = data._stub._get_stub_layout( + has_summary_rows=has_summary_rows, options=data._options + ) # Throw exception if a stub is present in the table if "rowname" in stub_layout or "group_label" in stub_layout: diff --git a/great_tables/css/gt_styles_default.scss b/great_tables/css/gt_styles_default.scss index 692d7886f..8a82596c7 100644 --- a/great_tables/css/gt_styles_default.scss +++ b/great_tables/css/gt_styles_default.scss @@ -277,6 +277,28 @@ p { 
border-bottom-color: $table_body_border_bottom_color; } +.gt_grand_summary_row { + color: $font_color_grand_summary_row_background_color; + background-color: $grand_summary_row_background_color; + text-transform: $grand_summary_row_text_transform; + padding-top: $grand_summary_row_padding; + padding-bottom: $grand_summary_row_padding; + padding-left: $grand_summary_row_padding_horizontal; + padding-right: $grand_summary_row_padding_horizontal; +} + +.gt_first_grand_summary_row_bottom { + border-top-style: $grand_summary_row_border_style; + border-top-width: $grand_summary_row_border_width; + border-top-color: $grand_summary_row_border_color; +} + +.gt_last_grand_summary_row_top { + border-bottom-style: $grand_summary_row_border_style; + border-bottom-width: $grand_summary_row_border_width; + border-bottom-color: $grand_summary_row_border_color; +} + .gt_sourcenotes { color: $font_color_source_notes_background_color; background-color: $source_notes_background_color; diff --git a/great_tables/gt.py b/great_tables/gt.py index e38875950..db227822f 100644 --- a/great_tables/gt.py +++ b/great_tables/gt.py @@ -32,7 +32,7 @@ from ._gt_data import GTData from ._heading import tab_header from ._helpers import random_id -from ._modify_rows import row_group_order, tab_stub, with_id, with_locale +from ._modify_rows import grand_summary_rows, row_group_order, tab_stub, with_id, with_locale from ._options import ( opt_align_table_header, opt_all_caps, @@ -277,6 +277,7 @@ def __init__( tab_stub = tab_stub with_id = with_id with_locale = with_locale + grand_summary_rows = grand_summary_rows save = save show = show diff --git a/great_tables/loc.py b/great_tables/loc.py index e463ab132..86df0ca5e 100644 --- a/great_tables/loc.py +++ b/great_tables/loc.py @@ -17,10 +17,16 @@ # Stub ---- LocStub as stub, LocRowGroups as row_groups, + # LocSummaryStub as summary_stub, + LocGrandSummaryStub as grand_summary_stub, # # Body ---- LocBody as body, # + # Summary ---- + # LocSummary as 
summary, + LocGrandSummary as grand_summary, + # # Footer ---- LocFooter as footer, LocSourceNotes as source_notes, @@ -36,7 +42,11 @@ "column_labels", "stub", "row_groups", + # "summary_stub", + "grand_summary_stub", "body", + # "summary", + "grand_summary", "footer", "source_notes", ) diff --git a/tests/__snapshots__/test_export.ambr b/tests/__snapshots__/test_export.ambr index e8277b9ef..b658eeeb2 100644 --- a/tests/__snapshots__/test_export.ambr +++ b/tests/__snapshots__/test_export.ambr @@ -36,6 +36,9 @@ #test_table .gt_row_group_first th { border-top-width: 2px; } #test_table .gt_striped { color: #333333; background-color: #F4F4F4; } #test_table .gt_table_body { border-top-style: solid; border-top-width: 2px; border-top-color: #D3D3D3; border-bottom-style: solid; border-bottom-width: 2px; border-bottom-color: #D3D3D3; } + #test_table .gt_grand_summary_row { color: #333333; background-color: #FFFFFF; text-transform: inherit; padding-top: 8px; padding-bottom: 8px; padding-left: 5px; padding-right: 5px; } + #test_table .gt_first_grand_summary_row_bottom { border-top-style: double; border-top-width: 6px; border-top-color: #D3D3D3; } + #test_table .gt_last_grand_summary_row_top { border-bottom-style: double; border-bottom-width: 6px; border-bottom-color: #D3D3D3; } #test_table .gt_sourcenotes { color: #333333; background-color: #FFFFFF; border-bottom-style: none; border-bottom-width: 2px; border-bottom-color: #D3D3D3; border-left-style: none; border-left-width: 2px; border-left-color: #D3D3D3; border-right-style: none; border-right-width: 2px; border-right-color: #D3D3D3; } #test_table .gt_sourcenote { font-size: 90%; padding-top: 4px; padding-bottom: 4px; padding-left: 5px; padding-right: 5px; text-align: left; } #test_table .gt_left { text-align: left; } diff --git a/tests/__snapshots__/test_modify_rows.ambr b/tests/__snapshots__/test_modify_rows.ambr index 2480ef86a..78df6a3ca 100644 --- a/tests/__snapshots__/test_modify_rows.ambr +++ 
b/tests/__snapshots__/test_modify_rows.ambr @@ -1,4 +1,79 @@ # serializer version: 1 +# name: test_grand_summary_rows_snap[pd_and_pl] + ''' + + +   + 1 + 4 + + +   + 2 + 5 + + +   + 3 + 6 + + + Average + 2.0 + 5.0 + + + Maximum + 3 + 6 + + + ''' +# --- +# name: test_grand_summary_rows_with_group_as_col_snap + ''' + + + x + 1 + 4 + + + y + 2 + 5 + + + Average + 1.5 + 4.5 + + + ''' +# --- +# name: test_grand_summary_rows_with_rowname_snap + ''' + + + x + 1 + 4 + + + y + 2 + 5 + + + Average + 1.5 + 4.5 + + + ''' +# --- # name: test_row_group_order ''' diff --git a/tests/__snapshots__/test_options.ambr b/tests/__snapshots__/test_options.ambr index 128c77049..75caf6010 100644 --- a/tests/__snapshots__/test_options.ambr +++ b/tests/__snapshots__/test_options.ambr @@ -1024,6 +1024,28 @@ border-bottom-color: #0076BA; } + #abc .gt_grand_summary_row { + color: #333333; + background-color: #89D3FE; + text-transform: inherit; + padding-top: 8px; + padding-bottom: 8px; + padding-left: 5px; + padding-right: 5px; + } + + #abc .gt_first_grand_summary_row_bottom { + border-top-style: double; + border-top-width: 6px; + border-top-color: #D3D3D3; + } + + #abc .gt_last_grand_summary_row_top { + border-bottom-style: double; + border-bottom-width: 6px; + border-bottom-color: #D3D3D3; + } + #abc .gt_sourcenotes { color: #333333; background-color: #FFFFFF; @@ -1376,6 +1398,28 @@ border-bottom-color: #0076BA; } + #abc .gt_grand_summary_row { + color: #333333; + background-color: #89D3FE; + text-transform: inherit; + padding-top: 8px; + padding-bottom: 8px; + padding-left: 5px; + padding-right: 5px; + } + + #abc .gt_first_grand_summary_row_bottom { + border-top-style: double; + border-top-width: 6px; + border-top-color: #D3D3D3; + } + + #abc .gt_last_grand_summary_row_top { + border-bottom-style: double; + border-bottom-width: 6px; + border-bottom-color: #D3D3D3; + } + #abc .gt_sourcenotes { color: #333333; background-color: #FFFFFF; @@ -1836,6 +1880,28 @@ border-bottom-color: red; } + #abc 
.gt_grand_summary_row { + color: #000000; + background-color: red; + text-transform: inherit; + padding-top: 8px; + padding-bottom: 8px; + padding-left: 5px; + padding-right: 5px; + } + + #abc .gt_first_grand_summary_row_bottom { + border-top-style: double; + border-top-width: 6px; + border-top-color: #D3D3D3; + } + + #abc .gt_last_grand_summary_row_top { + border-bottom-style: double; + border-bottom-width: 6px; + border-bottom-color: #D3D3D3; + } + #abc .gt_sourcenotes { color: #000000; background-color: red; @@ -2188,6 +2254,28 @@ border-bottom-color: #D3D3D3; } + #abc .gt_grand_summary_row { + color: #333333; + background-color: #FFFFFF; + text-transform: inherit; + padding-top: 8px; + padding-bottom: 8px; + padding-left: 5px; + padding-right: 5px; + } + + #abc .gt_first_grand_summary_row_bottom { + border-top-style: double; + border-top-width: 6px; + border-top-color: #D3D3D3; + } + + #abc .gt_last_grand_summary_row_top { + border-bottom-style: double; + border-bottom-width: 6px; + border-bottom-color: #D3D3D3; + } + #abc .gt_sourcenotes { color: #333333; background-color: #FFFFFF; diff --git a/tests/__snapshots__/test_repr.ambr b/tests/__snapshots__/test_repr.ambr index d8ded59af..649d5c4a0 100644 --- a/tests/__snapshots__/test_repr.ambr +++ b/tests/__snapshots__/test_repr.ambr @@ -36,6 +36,9 @@ #test .gt_row_group_first th { border-top-width: 2px; } #test .gt_striped { color: #333333; background-color: #F4F4F4; } #test .gt_table_body { border-top-style: solid; border-top-width: 2px; border-top-color: #D3D3D3; border-bottom-style: solid; border-bottom-width: 2px; border-bottom-color: #D3D3D3; } + #test .gt_grand_summary_row { color: #333333; background-color: #FFFFFF; text-transform: inherit; padding-top: 8px; padding-bottom: 8px; padding-left: 5px; padding-right: 5px; } + #test .gt_first_grand_summary_row_bottom { border-top-style: double; border-top-width: 6px; border-top-color: #D3D3D3; } + #test .gt_last_grand_summary_row_top { border-bottom-style: 
double; border-bottom-width: 6px; border-bottom-color: #D3D3D3; } #test .gt_sourcenotes { color: #333333; background-color: #FFFFFF; border-bottom-style: none; border-bottom-width: 2px; border-bottom-color: #D3D3D3; border-left-style: none; border-left-width: 2px; border-left-color: #D3D3D3; border-right-style: none; border-right-width: 2px; border-right-color: #D3D3D3; } #test .gt_sourcenote { font-size: 90%; padding-top: 4px; padding-bottom: 4px; padding-left: 5px; padding-right: 5px; text-align: left; } #test .gt_left { text-align: left; } @@ -112,6 +115,9 @@ #test .gt_row_group_first th { border-top-width: 2px; } #test .gt_striped { color: #333333; background-color: #F4F4F4; } #test .gt_table_body { border-top-style: solid; border-top-width: 2px; border-top-color: #D3D3D3; border-bottom-style: solid; border-bottom-width: 2px; border-bottom-color: #D3D3D3; } + #test .gt_grand_summary_row { color: #333333; background-color: #FFFFFF; text-transform: inherit; padding-top: 8px; padding-bottom: 8px; padding-left: 5px; padding-right: 5px; } + #test .gt_first_grand_summary_row_bottom { border-top-style: double; border-top-width: 6px; border-top-color: #D3D3D3; } + #test .gt_last_grand_summary_row_top { border-bottom-style: double; border-bottom-width: 6px; border-bottom-color: #D3D3D3; } #test .gt_sourcenotes { color: #333333; background-color: #FFFFFF; border-bottom-style: none; border-bottom-width: 2px; border-bottom-color: #D3D3D3; border-left-style: none; border-left-width: 2px; border-left-color: #D3D3D3; border-right-style: none; border-right-width: 2px; border-right-color: #D3D3D3; } #test .gt_sourcenote { font-size: 90%; padding-top: 4px; padding-bottom: 4px; padding-left: 5px; padding-right: 5px; text-align: left; } #test .gt_left { text-align: left; } @@ -194,6 +200,9 @@ #test .gt_row_group_first th { border-top-width: 2px !important; } #test .gt_striped { color: #333333 !important; background-color: #F4F4F4 !important; } #test .gt_table_body { 
border-top-style: solid !important; border-top-width: 2px !important; border-top-color: #D3D3D3 !important; border-bottom-style: solid !important; border-bottom-width: 2px !important; border-bottom-color: #D3D3D3 !important; } + #test .gt_grand_summary_row { color: #333333 !important; background-color: #FFFFFF !important; text-transform: inherit !important; padding-top: 8px !important; padding-bottom: 8px !important; padding-left: 5px !important; padding-right: 5px !important; } + #test .gt_first_grand_summary_row_bottom { border-top-style: double !important; border-top-width: 6px !important; border-top-color: #D3D3D3 !important; } + #test .gt_last_grand_summary_row_top { border-bottom-style: double !important; border-bottom-width: 6px !important; border-bottom-color: #D3D3D3 !important; } #test .gt_sourcenotes { color: #333333 !important; background-color: #FFFFFF !important; border-bottom-style: none !important; border-bottom-width: 2px !important; border-bottom-color: #D3D3D3 !important; border-left-style: none !important; border-left-width: 2px !important; border-left-color: #D3D3D3 !important; border-right-style: none !important; border-right-width: 2px !important; border-right-color: #D3D3D3 !important; } #test .gt_sourcenote { font-size: 90% !important; padding-top: 4px !important; padding-bottom: 4px !important; padding-left: 5px !important; padding-right: 5px !important; text-align: left !important; } #test .gt_left { text-align: left !important; } @@ -273,6 +282,9 @@ #test .gt_row_group_first th { border-top-width: 2px; } #test .gt_striped { color: #333333; background-color: #F4F4F4; } #test .gt_table_body { border-top-style: solid; border-top-width: 2px; border-top-color: #D3D3D3; border-bottom-style: solid; border-bottom-width: 2px; border-bottom-color: #D3D3D3; } + #test .gt_grand_summary_row { color: #333333; background-color: #FFFFFF; text-transform: inherit; padding-top: 8px; padding-bottom: 8px; padding-left: 5px; padding-right: 5px; } + #test 
.gt_first_grand_summary_row_bottom { border-top-style: double; border-top-width: 6px; border-top-color: #D3D3D3; } + #test .gt_last_grand_summary_row_top { border-bottom-style: double; border-bottom-width: 6px; border-bottom-color: #D3D3D3; } #test .gt_sourcenotes { color: #333333; background-color: #FFFFFF; border-bottom-style: none; border-bottom-width: 2px; border-bottom-color: #D3D3D3; border-left-style: none; border-left-width: 2px; border-left-color: #D3D3D3; border-right-style: none; border-right-width: 2px; border-right-color: #D3D3D3; } #test .gt_sourcenote { font-size: 90%; padding-top: 4px; padding-bottom: 4px; padding-left: 5px; padding-right: 5px; text-align: left; } #test .gt_left { text-align: left; } @@ -349,6 +361,9 @@ #test .gt_row_group_first th { border-top-width: 2px !important; } #test .gt_striped { color: #333333 !important; background-color: #F4F4F4 !important; } #test .gt_table_body { border-top-style: solid !important; border-top-width: 2px !important; border-top-color: #D3D3D3 !important; border-bottom-style: solid !important; border-bottom-width: 2px !important; border-bottom-color: #D3D3D3 !important; } + #test .gt_grand_summary_row { color: #333333 !important; background-color: #FFFFFF !important; text-transform: inherit !important; padding-top: 8px !important; padding-bottom: 8px !important; padding-left: 5px !important; padding-right: 5px !important; } + #test .gt_first_grand_summary_row_bottom { border-top-style: double !important; border-top-width: 6px !important; border-top-color: #D3D3D3 !important; } + #test .gt_last_grand_summary_row_top { border-bottom-style: double !important; border-bottom-width: 6px !important; border-bottom-color: #D3D3D3 !important; } #test .gt_sourcenotes { color: #333333 !important; background-color: #FFFFFF !important; border-bottom-style: none !important; border-bottom-width: 2px !important; border-bottom-color: #D3D3D3 !important; border-left-style: none !important; border-left-width: 2px !important; 
border-left-color: #D3D3D3 !important; border-right-style: none !important; border-right-width: 2px !important; border-right-color: #D3D3D3 !important; } #test .gt_sourcenote { font-size: 90% !important; padding-top: 4px !important; padding-bottom: 4px !important; padding-left: 5px !important; padding-right: 5px !important; text-align: left !important; } #test .gt_left { text-align: left !important; } diff --git a/tests/__snapshots__/test_scss.ambr b/tests/__snapshots__/test_scss.ambr index 59494100e..8b567693d 100644 --- a/tests/__snapshots__/test_scss.ambr +++ b/tests/__snapshots__/test_scss.ambr @@ -286,6 +286,28 @@ border-bottom-color: #D3D3D3; } + #abc .gt_grand_summary_row { + color: #333333; + background-color: #FFFFFF; + text-transform: inherit; + padding-top: 8px; + padding-bottom: 8px; + padding-left: 5px; + padding-right: 5px; + } + + #abc .gt_first_grand_summary_row_bottom { + border-top-style: double; + border-top-width: 6px; + border-top-color: #D3D3D3; + } + + #abc .gt_last_grand_summary_row_top { + border-bottom-style: double; + border-bottom-width: 6px; + border-bottom-color: #D3D3D3; + } + #abc .gt_sourcenotes { color: #333333; background-color: #FFFFFF; diff --git a/tests/test_locations.py b/tests/test_locations.py index 3dbc75d5a..15c0316f3 100644 --- a/tests/test_locations.py +++ b/tests/test_locations.py @@ -13,6 +13,8 @@ LocSpannerLabels, LocStub, LocTitle, + LocGrandSummaryStub, + LocGrandSummary, resolve, resolve_cols_i, resolve_rows_i, @@ -295,3 +297,50 @@ def test_set_style_loc_title_from_column_error(snapshot): set_style(loc, gt_df, [style]) assert snapshot == exc_info.value.args[0] + + +@pytest.mark.parametrize( + "rows, res", + [ + (0, {0}), + ("min", {1}), + (["min"], {1}), + (["min", 0], {0, 1}), + (["min", -1], {1}), + ], +) +def test_resolve_loc_grand_summary_stub(rows, res): + df = pd.DataFrame({"x": [1, 2], "y": [3, 4]}) + gt = ( + GT(df) + .grand_summary_rows(fns={"min": lambda x: x.min()}, side="bottom") + 
@pytest.mark.parametrize(
    "cols, rows, resolved_subset, length",
    [
        (["x"], ["max"], CellPos(column=0, row=0, colname="x", rowname=None), 1),
        ([1], ["min"], CellPos(column=1, row=1, colname="y", rowname=None), 1),
        ([-1], [0, 1], CellPos(column=1, row=0, colname="y", rowname=None), 2),
        ([-1, "x"], ["max", 1], CellPos(column=0, row=0, colname="x", rowname=None), 4),
    ],
)
def test_resolve_loc_grand_summary(cols, rows, resolved_subset, length):
    """Resolving `LocGrandSummary` yields the expected list of cell positions.

    The table gets a "min" row on the bottom side and a "max" row on the top side; each case
    selects columns/rows by name or by (possibly negative) index and checks both the number of
    resolved cells and that one specific cell is among them.
    """
    frame = pd.DataFrame({"x": [1, 2], "y": [3, 4]})

    # Register the bottom row first and the top row second, as the expected row indices
    # in the parametrization depend on this setup.
    with_bottom_row = GT(frame).grand_summary_rows(fns={"min": lambda x: x.min()}, side="bottom")
    table = with_bottom_row.grand_summary_rows(fns={"max": lambda x: x.max()}, side="top")

    target = LocGrandSummary(columns=cols, rows=rows)
    resolved = resolve(target, table)

    assert isinstance(resolved, list)
    assert len(resolved) == length
    assert resolved_subset in resolved
test_with_id_preserves_other_options(): new_gt = gt.with_id("zzz") assert new_gt._options.table_id.value == "zzz" assert new_gt._options.container_width.value == "20px" + + +def test_grand_summary_rows_snap(snapshot): + for Frame in [pd.DataFrame, pl.DataFrame]: + df = Frame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + if isinstance(df, pd.DataFrame): + + def mean_expr(df): + return df.mean() + + def max_expr(df): + return df.max() + + if isinstance(df, pl.DataFrame): + mean_expr = pl.all().mean() + max_expr = pl.all().max() + + res = GT(df).grand_summary_rows(fns={"Average": mean_expr, "Maximum": max_expr}) + + assert_rendered_body(snapshot(name="pd_and_pl"), res) + + +def test_grand_summary_rows_with_rowname_snap(snapshot): + df = pd.DataFrame({"a": [1, 2], "b": [4, 5], "row": ["x", "y"]}) + + res = GT(df, rowname_col="row").grand_summary_rows(fns={"Average": mean_expr}) + + assert_rendered_body(snapshot, res) + + +def test_grand_summary_rows_with_group_as_col_snap(snapshot): + df = pd.DataFrame({"a": [1, 2], "b": [4, 5], "group": ["x", "y"]}) + + res = ( + GT(df, groupname_col="group") + .grand_summary_rows(fns={"Average": mean_expr}) + .tab_options(row_group_as_column=True) + ) + + assert_rendered_body(snapshot, res) + + +def test_grand_summary_rows_with_rowname_and_groupname(): + df = pd.DataFrame({"a": [1, 2], "group": ["x", "x"], "row": ["row1", "row2"]}) + + res = ( + GT(df, rowname_col="row", groupname_col="group") + .grand_summary_rows(fns={"Average": mean_expr}) + .tab_options(row_group_as_column=True) + ) + html = res.as_raw_html() + + assert 'rowspan="2">x' in html + assert ( + 'Average' + in html + ) + + +def test_grand_summary_rows_with_missing(): + df = pd.DataFrame({"a": [1, 2], "non_numeric": ["x", "y"]}) + + res = GT(df).grand_summary_rows( + fns={"Average": mean_expr}, + missing_text="missing_text", + ) + html = res.as_raw_html() + + assert "missing_text" in html + + +def test_grand_summary_rows_bottom_and_top(): + df = pd.DataFrame({"a": [1, 2]}) + + 
res = ( + GT(df) + .grand_summary_rows(fns={"Top": min_expr}, side="top") + .grand_summary_rows(fns={"Bottom": max_expr}, side="bottom") + ) + + html = render_only_body(res) + + assert ( + 'gt_first_grand_summary_row_bottom gt_row gt_left gt_stub gt_grand_summary_row">Bottom' + in html + ) + assert ( + 'gt_last_grand_summary_row_top gt_row gt_left gt_stub gt_grand_summary_row">Top' + in html + ) + + +def test_grand_summary_rows_overwritten_row_maintains_location(): + df = pd.DataFrame({"a": [1, 2], "row": ["x", "y"]}) + + res = ( + GT(df) + .grand_summary_rows(fns={"Overwritten": min_expr}, side="top") + .grand_summary_rows(fns={"Overwritten": max_expr}, side="bottom") + ) + html = render_only_body(res) + + assert '"gt_last_grand_summary_row_top' in html + assert '"gt_first_grand_summary_row_bottom' not in html + + assert 'gt_grand_summary_row">1' not in html + assert 'gt_grand_summary_row">2' in html + + +def test_grand_summary_rows_with_fmt(): + df = pd.DataFrame({"a": [1, 3], "row": ["x", "y"]}) + + res = GT(df).grand_summary_rows(fns={"Average": mean_expr}, fmt=vals.fmt_integer) + html = render_only_body(res) + + assert 'gt_grand_summary_row">2' in html + assert 'gt_grand_summary_row">2.0' not in html + + +def test_grand_summary_rows_raises_columns_not_implemented(): + df = pd.DataFrame({"a": [1, 2], "row": ["x", "y"]}) + + with pytest.raises(NotImplementedError) as exc_info: + GT(df).grand_summary_rows(fns={"Minimum": min_expr}, columns="b") + + assert ( + "Currently, grand_summary_rows() does not support column selection." 
+ in exc_info.value.args[0] + ) diff --git a/tests/test_tbl_data.py b/tests/test_tbl_data.py index 335c06fb3..d58307240 100644 --- a/tests/test_tbl_data.py +++ b/tests/test_tbl_data.py @@ -15,6 +15,7 @@ _validate_selector_list, cast_frame_to_string, create_empty_frame, + eval_aggregate, eval_select, get_column_names, group_splits, @@ -323,3 +324,111 @@ def test_copy_frame(df: DataFrameLike): copy_df = copy_frame(df) assert id(copy_df) != id(df) assert_frame_equal(copy_df, df) + + +def test_eval_aggregate_pandas(df: DataFrameLike): + def expr(df): + return pd.Series({"col1_sum": sum(df["col1"]), "col3_max": max(df["col3"])}) + + # Only pandas supports callable aggregation expressions + if isinstance(df, pl.DataFrame): + with pytest.raises(TypeError) as exc_info: + eval_aggregate(df, expr) + assert "cannot create expression literal for value of type function" in str( + exc_info.value.args[0] + ) + return + + if isinstance(df, pa.Table): + with pytest.raises(TypeError) as exc_info: + eval_aggregate(df, expr) + assert "unsupported operand type(s)" in str(exc_info.value.args[0]) + return + + result = eval_aggregate(df, expr) + assert result == {"col1_sum": 6, "col3_max": 6.0} + + +@pytest.mark.parametrize( + "expr,expected", + [ + (pl.col("col1").sum(), {"col1": 6}), + (pl.col("col2").first(), {"col2": "a"}), + (pl.col("col3").max(), {"col3": 6.0}), + ], +) +def test_eval_aggregate_polars(df: DataFrameLike, expr, expected): + # Only polars supports polars expression aggregations + if not isinstance(df, pl.DataFrame): + with pytest.raises(TypeError) as exc_info: + eval_aggregate(df, expr) + assert "'Expr' object is not callable" in str(exc_info.value.args[0]) + return + + result = eval_aggregate(df, expr) + assert result == expected + + +@pytest.mark.parametrize("Frame", [pd.DataFrame, pl.DataFrame, pa.table]) +def test_eval_aggregate_with_nulls(Frame): + df = Frame({"a": [1, None, 3]}) + + if isinstance(df, pd.DataFrame): + + def expr(df): + return pd.Series({"a": 
df["a"].sum()}) + + if isinstance(df, pl.DataFrame): + expr = pl.col("a").sum() + + if isinstance(df, pa.Table): + + def expr(tbl): + s = pa.compute.sum(tbl.column("a")) + return pa.table({"a": [s.as_py()]}) + + result = eval_aggregate(df, expr) + assert result == {"a": 4} + + +def test_eval_aggregate_pandas_raises(): + df = pd.DataFrame({"a": [1, 2, 3]}) + + def expr(df): + return {"a": df["a"].sum()} + + with pytest.raises(ValueError) as exc_info: + eval_aggregate(df, expr) + assert "Result must be a pandas Series" in str(exc_info.value) + + +def test_eval_aggregate_polars_raises(): + df = pl.DataFrame({"a": [1, 2, 3]}) + expr = pl.col("a") + + with pytest.raises(ValueError) as exc_info: + eval_aggregate(df, expr) + assert "Expression must produce exactly 1 row" in str(exc_info.value) + + +def test_eval_aggregate_pyarrow_raises1(): + df = pa.table({"a": [1, 2, 3]}) + + def expr(tbl): + s = pa.compute.sum(tbl.column("a")) + return {"a": [s.as_py()]} + + with pytest.raises(ValueError) as exc_info: + eval_aggregate(df, expr) + assert "Result must be a PyArrow Table" in str(exc_info.value) + + +def test_eval_aggregate_pyarrow_raises2(): + df = pa.table({"a": [1, 2, 3]}) + + def expr(tbl): + return pa.table({"a": tbl.column("a")}) + + with pytest.raises(ValueError) as exc_info: + eval_aggregate(df, expr) + assert "Expression must produce exactly 1 row (aggregation)" in str(exc_info.value)