diff --git a/.github/workflows/ci-docs.yaml b/.github/workflows/ci-docs.yaml index 4b13d9ef9..a3a1d6e9e 100644 --- a/.github/workflows/ci-docs.yaml +++ b/.github/workflows/ci-docs.yaml @@ -21,6 +21,8 @@ jobs: run: | python -m pip install ".[all]" - uses: quarto-dev/quarto-actions/setup@v2 + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tinytex: true - name: Build docs diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 98bfd44d3..4307f9a5b 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -168,6 +168,13 @@ quartodoc: - GT.cols_move_to_end - GT.cols_hide - GT.cols_unhide + - title: Adding rows + desc: > + The [`grand_summary_rows()`](`great_tables.GT.grand_summary_rows`) function will add rows to + summarize data in your table, such as totals or averages. + contents: + - GT.grand_summary_rows + - title: Location Targeting and Styling Classes desc: > Location targeting is a powerful feature of **Great Tables**. It allows for the precise @@ -182,8 +189,10 @@ quartodoc: - loc.column_header - loc.spanner_labels - loc.column_labels + - loc.grand_summary_stub - loc.stub - loc.row_groups + - loc.grand_summary - loc.body - loc.footer - loc.source_notes diff --git a/docs/get-started/loc-selection.qmd b/docs/get-started/loc-selection.qmd index 6222b54db..e1ead79c2 100644 --- a/docs/get-started/loc-selection.qmd +++ b/docs/get-started/loc-selection.qmd @@ -23,7 +23,11 @@ data = [ ["", "loc.column_labels()", "columns"], ["row stub", "loc.stub()", "rows"], ["", "loc.row_groups()", "rows"], + # ["", "loc.summary_stub()", "rows"], + ["", "loc.grand_summary_stub()", "rows"], ["table body", "loc.body()", "columns and rows"], + # ["", "loc.summary()", "columns and rows"], + ["", "loc.grand_summary()", "columns and rows"], ["footer", "loc.footer()", "composite"], ["", "loc.source_notes()", ""], ] diff --git a/docs/get-started/table-theme-options.qmd b/docs/get-started/table-theme-options.qmd index ab7474444..e377546a2 100644 --- 
a/docs/get-started/table-theme-options.qmd +++ b/docs/get-started/table-theme-options.qmd @@ -25,6 +25,7 @@ gt_ex = ( .tab_header("THE HEADING", "(a subtitle)") .tab_stubhead("THE STUBHEAD") .tab_source_note("THE SOURCE NOTE") + .grand_summary_rows(fns={"GRAND SUMMARY ROW": lambda df: df.sum(numeric_only=True)}) ) gt_ex @@ -48,6 +49,7 @@ The code below illustrates the table parts `~~GT.tab_options()` can target, by s row_group_background_color="lightyellow", stub_background_color="lightgreen", source_notes_background_color="#f1e2af", + grand_summary_row_background_color="lightpink", ) ) ``` diff --git a/docs/get-started/targeted-styles.qmd b/docs/get-started/targeted-styles.qmd index eb3cf6239..a594beba0 100644 --- a/docs/get-started/targeted-styles.qmd +++ b/docs/get-started/targeted-styles.qmd @@ -18,7 +18,7 @@ Below is a big example that shows all possible `loc` specifiers being used. ```{python} from great_tables import GT, exibble, loc, style -# https://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12 +# https://colorbrewer2.org/#type=qualitative&scheme=Paired&n=12 and grey brewer_colors = [ "#a6cee3", "#1f78b4", @@ -32,6 +32,7 @@ brewer_colors = [ "#6a3d9a", "#ffff99", "#b15928", + "#808080", ] c = iter(brewer_colors) @@ -43,6 +44,7 @@ gt = ( .tab_source_note("yo") .tab_spanner("spanner", ["char", "fctr"]) .tab_stubhead("stubhead") + .grand_summary_rows(fns={"Total": lambda x: x.sum(numeric_only=True)}) ) ( @@ -64,6 +66,9 @@ gt = ( .tab_style(style.borders(weight="3px"), loc.stub(rows=1)) .tab_style(style.fill(next(c)), loc.stub()) .tab_style(style.fill(next(c)), loc.stubhead()) + # Summary Rows -------------- + .tab_style(style.fill(next(c)), loc.grand_summary()) + .tab_style(style.fill(next(c)), loc.grand_summary_stub()) ) ``` @@ -129,3 +134,17 @@ gt.tab_style(style.fill("yellow"), loc.body()) ```{python} gt.tab_style(style.fill("yellow"), loc.stubhead()) ``` + +## Grand Summary Rows + +```{python} +( + gt.tab_style( + style.fill("yellow"), + 
loc.grand_summary_stub(), + ).tab_style( + style.fill("lightblue"), + loc.grand_summary(), + ) +) +``` diff --git a/docs/get-started/technical-notes.qmd b/docs/get-started/technical-notes.qmd new file mode 100644 index 000000000..163c16fac --- /dev/null +++ b/docs/get-started/technical-notes.qmd @@ -0,0 +1,71 @@ +--- +title: Technical notes +jupyter: python3 +--- + +This document holds technical notes on how the dataclasses behind GT interact. + +```{python} +import polars as pl +from great_tables import GT, exibble + +``` + +## Boxhead, Stub, and Summary Rows + +```{python} + +lil_ex = pl.from_pandas(exibble).select("num", "row", "group").head(6) + +gt = GT(lil_ex).grand_summary_rows(fns={"Total": pl.col("num").sum()}) +``` + +:::{.grid} + +:::{.g-col-4} + +### Placeholder stub + +```{python} +(gt + + +) +``` + +::: + +:::{.g-col-4} + +### Rowname stub + +```{python} +(gt + .tab_stub(rowname_col="row") + +) +``` + +::: + +:::{.g-col-4} + +### Group stub + +```{python} +(gt + .tab_stub(groupname_col="group") + .tab_options(row_group_as_column=True) +) +``` + +::: + +::: + +One more case, not shown above: if both row names and group names are used in the stub, then the summary +spans both of those columns: + +```{python} +(gt.tab_stub(rowname_col="row", groupname_col="group").tab_options(row_group_as_column=True)) +``` diff --git a/great_tables/_gt_data.py b/great_tables/_gt_data.py index 1456d9454..5e55dbe19 100644 --- a/great_tables/_gt_data.py +++ b/great_tables/_gt_data.py @@ -2,7 +2,7 @@ import copy import re -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field, replace from enum import Enum, auto from itertools import chain, product @@ -75,6 +75,8 @@ class GTData: _spanners: Spanners _heading: Heading _stubhead: Stubhead + _summary_rows: SummaryRows + _summary_rows_grand: SummaryRows _source_notes: SourceNotes _footnotes: Footnotes _styles: Styles @@ -122,6 +124,8 @@ def from_data( 
@property
def is_stub(self) -> bool:
    """Return whether this column's type is one of the stub column types.

    The stub types are `stub`, `row_group`, and `summary_placeholder` (the placeholder
    column created so summary rows have somewhere to put their labels).
    """
    stub_types = (
        ColInfoTypeEnum.stub,
        ColInfoTypeEnum.row_group,
        ColInfoTypeEnum.summary_placeholder,
    )
    return any(self.type == stub_type for stub_type in stub_types)
# Summary Rows ---

# This can't conflict with actual group ids since we have a
# separate data structure for grand summary row infos


@dataclass(frozen=True)
class SummaryRowInfo:
    """Information about a single summary row."""

    id: str
    label: str  # For now, label and id are identical
    # The motivation for values as a dict is to ensure cols_* functions don't have to consider
    # the implications on existing SummaryRowInfo objects
    values: dict[str, Any]  # TODO: consider datatype, series?
    side: Literal["top", "bottom"]  # TODO: consider enum


# TODO: refactor into a collection/dataclass wrapping the list part
# put most of the methods for filtering, adding, replacing there.
# Make immutable to avoid potential bugs.
class SummaryRows(Mapping[str, list[SummaryRowInfo]]):
    """A mapping from group id to the list of summary rows for that group.

    The following structures should always be true about summary rows:
    - The id is also the label (often the same as the function name)
    - There is at most 1 row for each group and id pairing
    - If a summary row is added and no row exists for that group and id, add it
    - If a summary row is added and a row exists for that group and id pairing,
      then replace every cell (in values) that is present in the new version

    When `_is_grand_summary` is True, all rows live under the single `GRAND_SUMMARY_KEY`
    group and any group id passed to a lookup or to `add_summary_row()` is ignored.
    """

    _d: dict[str, list[SummaryRowInfo]]
    _is_grand_summary: bool

    LIST_CLS = list
    GRAND_SUMMARY_KEY = "grand"

    def __init__(
        self,
        entries: dict[str, list[SummaryRowInfo]] | None = None,
        _is_grand_summary: bool = False,
    ):
        # Copy the mapping so later mutation of the caller's dict cannot leak into this object.
        self._d = {} if entries is None else dict(entries)
        self._is_grand_summary = _is_grand_summary

    def __bool__(self) -> bool:
        """Return True if there are any summary rows, False otherwise."""
        return bool(self._d)

    def __getitem__(self, key: str | None) -> list[SummaryRowInfo]:
        """Return a copy of the rows for `key` (always the grand group for grand summaries).

        Raises
        ------
        KeyError
            If the key is empty/None for group summary rows, or the group has no rows.
        """
        if self._is_grand_summary:
            key = SummaryRows.GRAND_SUMMARY_KEY

        if not key:
            raise KeyError("Summary row group key must not be None for group summary rows.")

        if key not in self._d:
            raise KeyError(f"Group '{key}' not found in summary rows.")

        return self.LIST_CLS(self._d[key])

    def __iter__(self):
        """Iterate over the group ids that have summary rows."""
        return iter(self._d)

    def __len__(self) -> int:
        """Return the number of groups that have summary rows."""
        return len(self._d)

    # `self` is positional-only so that a group id may be any string (including "self").
    def define(self, /, **kwargs: list[SummaryRowInfo]) -> Self:
        """Define multiple summary row groups at once, replacing any existing groups."""

        new_d = {**self._d, **kwargs}
        return self.__class__(new_d, _is_grand_summary=self._is_grand_summary)

    def add_summary_row(self, summary_row: SummaryRowInfo, group_id: str | None = None) -> Self:
        """Add a summary row following the merging rules in the class docstring.

        Returns a new `SummaryRows`; this object is left unchanged.

        Raises
        ------
        TypeError
            If `group_id` is None and this is not a grand summary container.
        """

        # TODO: group_id can be None for grand summary configuration, but can't be none
        # for regular summary configuration (a bit double barrelled).
        if self._is_grand_summary:
            # Lookups always use the grand key, so rows must always be stored under it too;
            # otherwise a row added with an explicit group_id could never be retrieved.
            group_id = SummaryRows.GRAND_SUMMARY_KEY
        elif group_id is None:
            raise TypeError("group_id must be provided for group summary rows.")

        existing_group = self.get(group_id)

        if not existing_group:
            return self.define(**{group_id: [summary_row]})

        existing_index = next(
            (ii for ii, crnt_row in enumerate(existing_group) if crnt_row.id == summary_row.id),
            None,
        )

        new_rows = self.LIST_CLS(existing_group)

        if existing_index is None:
            # No existing row for this group and id, add it
            new_rows.append(summary_row)
        else:
            existing_row = new_rows[existing_index]

            # Start from the existing values and let every value in the new row win.
            merged_values = {**existing_row.values, **summary_row.values}

            new_rows[existing_index] = SummaryRowInfo(
                id=summary_row.id,
                label=summary_row.label,
                values=merged_values,
                # Setting this to existing row instead of summary_row means original
                # side is fixed by whatever side is first assigned to this row
                side=existing_row.side,
            )

        return self.define(**{group_id: new_rows})

    def get_summary_rows(
        self, group_id: str | None = None, side: str | None = None
    ) -> list[SummaryRowInfo]:
        """Get list of summary rows for that group. If side is None, do not filter by side.

        Sorts result with 'top' side first, then 'bottom'. Returns an empty list when the
        group has no summary rows.

        Raises
        ------
        TypeError
            If `group_id` is None and this is not a grand summary container.
        """

        if self._is_grand_summary:
            group_id = SummaryRows.GRAND_SUMMARY_KEY
        elif group_id is None:
            raise TypeError("group_id must be provided for group summary rows.")

        summary_row_group = self.get(group_id) or []

        result = [row for row in summary_row_group if side is None or row.side == side]

        # Sort: 'top' first, then 'bottom' (stable, so insertion order is kept within a side)
        result.sort(key=lambda r: 0 if r.side == "top" else 1)  # TODO: modify if enum for side
        return result
# @dataclass
# class LocSummaryStub(Loc):
#     rows: RowSelectExpr = None


@dataclass
class LocGrandSummaryStub(Loc):
    """Target the grand summary stub.

    With `loc.grand_summary_stub()` we can target the cells containing the grand summary row labels,
    which reside in the table stub. This is useful for applying custom styling with the
    [`tab_style()`](`great_tables.GT.tab_style`) method. That method has a `locations=` argument and
    this class should be used there to perform the targeting.

    Parameters
    ----------
    rows
        The rows to target within the grand summary stub. Can either be a single row name or a
        series of row names provided in a list. If no rows are specified, all grand summary rows
        are targeted. Note that if rows are targeted by index, top and bottom grand summary rows
        are indexed as one combined list starting with the top rows.

    Returns
    -------
    LocGrandSummaryStub
        A LocGrandSummaryStub object, which is used for a `locations=` argument if specifying the
        table's grand summary rows' labels.

    Examples
    --------
    Let's use a subset of the `gtcars` dataset in a new table. We will style the entire table grand
    summary stub (the row labels) by using `locations=loc.grand_summary_stub()` within
    [`tab_style()`](`great_tables.GT.tab_style`).

    ```{python}
    from great_tables import GT, style, loc, vals
    from great_tables.data import gtcars

    (
        GT(
            gtcars[["mfr", "model", "hp", "trq", "mpg_c"]].head(6),
            rowname_col="model",
        )
        .fmt_integer(columns=["hp", "trq", "mpg_c"])
        .grand_summary_rows(
            fns={
                "Min": lambda df: df.min(numeric_only=True),
                "Max": lambda x: x.max(numeric_only=True),
            },
            side="top",
            fmt=vals.fmt_integer,
        )
        .tab_style(
            style=[style.text(color="crimson", weight="bold"), style.fill(color="lightgray")],
            locations=loc.grand_summary_stub(),
        )
    )
    ```
    """

    rows: RowSelectExpr = None


# @dataclass
# class LocSummary(Loc):
#     # TODO: these can be tidyselectors
#     columns: SelectExpr = None
#     rows: RowSelectExpr = None
#     mask: PlExpr | None = None


@dataclass
class LocGrandSummary(Loc):
    """Target the data cells in grand summary rows.

    With `loc.grand_summary()` we can target the cells containing the grand summary data.
    This is useful for applying custom styling with the [`tab_style()`](`great_tables.GT.tab_style`)
    method. That method has a `locations=` argument and this class should be used there to perform
    the targeting.

    Parameters
    ----------
    columns
        The columns to target. Can either be a single column name or a series of column names
        provided in a list.
    rows
        The rows to target. Can either be a single row name or a series of row names provided in a
        list. Note that if rows are targeted by index, top and bottom grand summary rows are indexed
        as one combined list starting with the top rows.
    mask
        A Polars expression for selecting cells. It cannot be combined with `columns=` or `rows=`,
        and masked selection is not yet implemented for grand summary rows (resolving a location
        with `mask=` set currently raises `NotImplementedError`).

    Returns
    -------
    LocGrandSummary
        A LocGrandSummary object, which is used for a `locations=` argument if specifying the
        table's grand summary rows.

    Examples
    --------
    Let's use a subset of the `gtcars` dataset in a new table. We will style all of the grand
    summary cells by using `locations=loc.grand_summary()` within
    [`tab_style()`](`great_tables.GT.tab_style`).

    ```{python}
    from great_tables import GT, style, loc, vals
    from great_tables.data import gtcars

    (
        GT(
            gtcars[["mfr", "model", "hp", "trq", "mpg_c"]].head(6),
            rowname_col="model",
        )
        .fmt_integer(columns=["hp", "trq", "mpg_c"])
        .grand_summary_rows(
            fns={
                "Min": lambda df: df.min(numeric_only=True),
                "Max": lambda x: x.max(numeric_only=True),
            },
            side="top",
            fmt=vals.fmt_integer,
        )
        .tab_style(
            style=[style.text(color="crimson", weight="bold"), style.fill(color="lightgray")],
            locations=loc.grand_summary(),
        )
    )
    ```
    """

    # TODO: these can be tidyselectors
    columns: SelectExpr = None
    rows: RowSelectExpr = None
    mask: PlExpr | None = None


@resolve.register
def _(loc: LocGrandSummaryStub, data: GTData) -> set[int]:
    """Resolve `loc.grand_summary_stub()` to the positions of the targeted grand summary rows.

    Positions index into the combined list of grand summary rows (top rows first).
    """
    # Select just grand summary rows
    grand_summary_rows = data._summary_rows_grand.get_summary_rows()
    grand_summary_rows_ids = [row.id for row in grand_summary_rows]

    rows = resolve_rows_i(data=grand_summary_rows_ids, expr=loc.rows)

    # Each resolved entry is indexed at [1] for its position, as in the other resolvers.
    return {row[1] for row in rows}


# @resolve.register(LocSummaryStub)
# Also target by groupname in styleinfo
@resolve.register
def _(loc: LocGrandSummary, data: GTData) -> list[CellPos]:
    """Resolve `loc.grand_summary()` to the grand summary cells it targets.

    Row positions index into the combined list of grand summary rows (top rows first);
    one `CellPos` is produced for every selected column/row pair.

    Raises
    ------
    ValueError
        If `mask` is given together with `columns` or `rows`.
    NotImplementedError
        If `mask` is given; masked selection is not supported for grand summary rows yet.
    """
    if (loc.columns is not None or loc.rows is not None) and loc.mask is not None:
        raise ValueError(
            "Cannot specify the `mask` argument along with `columns` or `rows` in "
            "`loc.grand_summary()`."
        )

    if loc.mask is not None:
        # I am not sure how to approach this, since GTData._summary_rows is not a frame
        # We could convert to a frame, but I don't think that's a simple step
        raise NotImplementedError("Masked selection is not yet implemented for Grand Summary Rows")

    grand_summary_rows = data._summary_rows_grand.get_summary_rows()
    grand_summary_rows_ids = [row.id for row in grand_summary_rows]

    rows = resolve_rows_i(data=grand_summary_rows_ids, expr=loc.rows)
    cols = resolve_cols_i(data=data, expr=loc.columns)
    # TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
    # thing multiple times
    return [CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)]


# @resolve.register(LocSummary)
# Also target by groupname in styleinfo
def grand_summary_rows(
    self: GTSelf,
    *,
    fns: dict[str, PlExpr] | dict[str, Callable[[TblData], Any]],
    fmt: FormatFn | None = None,
    columns: SelectExpr = None,
    side: Literal["bottom", "top"] = "bottom",
    missing_text: str = "---",
) -> GTSelf:
    """Add grand summary rows to the table.

    Add grand summary rows by using the table data and any suitable aggregation functions. With
    grand summary rows, all of the available data in the gt table is incorporated (regardless of
    whether some of the data are part of row groups). Multiple grand summary rows can be added via
    expressions given to fns. You can selectively format the values in the resulting grand summary
    cells by use of formatting expressions from the `vals.fmt_*` class of functions.

    Note that currently all arguments are keyword-only, since the final positions may change.

    Parameters
    ----------
    fns
        A dictionary mapping row labels to aggregation expressions. Can be either Polars
        expressions or callable functions that take the entire DataFrame and return aggregated
        results. Each key becomes the label for a grand summary row.
    fmt
        A formatting function from the `vals.fmt_*` family (e.g., `vals.fmt_number`,
        `vals.fmt_currency`) to apply to the summary row values. If `None`, no formatting
        is applied.
    columns
        Currently, this function does not support selection by columns. If you would like to choose
        which columns to summarize, you can select columns within the functions given to `fns=`.
        See examples below for more explicit cases.
    side
        Should the grand summary rows be placed at the `"bottom"` (the default) or the `"top"` of
        the table? Any other value raises a `ValueError`.
    missing_text
        The text to be used in summary cells with no data outputs.

    Returns
    -------
    GT
        The GT object is returned. This is the same object that the method is called on so that we
        can facilitate method chaining.

    Examples
    --------
    Let's use a subset of the `sp500` dataset to create a table with grand summary rows. We'll
    calculate min, max, and mean values for the numeric columns. Notice the different
    approaches to selecting columns to apply the aggregations to: we can use polars selectors
    or select the columns directly.

    ```{python}
    import polars as pl
    import polars.selectors as cs
    from great_tables import GT, vals, style, loc
    from great_tables.data import sp500

    sp500_mini = (
        pl.from_pandas(sp500)
        .slice(0, 7)
        .drop(["volume", "adj_close"])
    )

    (
        GT(sp500_mini, rowname_col="date")
        .grand_summary_rows(
            fns={
                "Minimum": pl.min("open", "high", "low", "close"),
                "Maximum": pl.col("open", "high", "low", "close").max(),
                "Average": cs.numeric().mean(),
            },
            fmt=vals.fmt_currency,
        )
        .tab_style(
            style=[
                style.text(color="crimson"),
                style.fill(color="lightgray"),
            ],
            locations=loc.grand_summary(),
        )
    )
    ```

    We can also use custom callable functions to create more complex summary calculations.
    Notice here that grand summary rows can be placed at the top of the table and formatted
    with currency notation, by passing a formatter from the `vals.fmt_*` class of functions.

    ```{python}
    from great_tables import GT, style, loc, vals
    from great_tables.data import gtcars

    def pd_median(df):
        return df.median(numeric_only=True)


    (
        GT(
            gtcars[["mfr", "model", "hp", "trq", "mpg_c"]].head(6),
            rowname_col="model",
        )
        .fmt_integer(columns=["hp", "trq", "mpg_c"])
        .grand_summary_rows(
            fns={
                "Min": lambda df: df.min(numeric_only=True),
                "Max": lambda df: df.max(numeric_only=True),
                "Median": pd_median,
            },
            side="top",
            fmt=vals.fmt_integer,
        )
        .tab_style(
            style=[style.text(color="crimson", weight="bold"), style.fill(color="lightgray")],
            locations=loc.grand_summary_stub(),
        )
    )
    ```

    """
    if columns is not None:
        raise NotImplementedError(
            "Currently, grand_summary_rows() does not support column selection."
        )

    # Summary rows are filtered and ordered by comparing `side` against these two values, so
    # reject anything else up front instead of storing a row with an unrecognized side.
    if side not in ("bottom", "top"):
        raise ValueError(f"`side` must be either 'bottom' or 'top', got {side!r}.")

    # summary_col_names = resolve_cols_c(data=self, expr=columns)

    new_summary = self._summary_rows_grand
    for label, fn in fns.items():
        row_values_dict = _calculate_summary_row(self, fn, fmt, missing_text)

        # The label doubles as the row id, so re-adding a label merges into the existing row.
        summary_row_info = SummaryRowInfo(
            id=label,
            label=label,
            values=row_values_dict,
            side=side,
        )

        new_summary = new_summary.add_summary_row(summary_row_info)

    return self._replace(_summary_rows_grand=new_summary)


def _calculate_summary_row(
    data: GTData,
    fn: PlExpr | Callable[[TblData], Any],
    fmt: FormatFn | None,
    # summary_col_names: list[str],
    missing_text: str,
) -> dict[str, Any]:
    """Calculate a single summary row by aggregating the table data with `eval_aggregate()`.

    Parameters
    ----------
    data
        The table whose `_tbl_data` is aggregated and whose boxhead supplies the column names.
    fn
        The Polars expression or callable handed to `eval_aggregate()`.
    fmt
        Optional formatter; it is called with a one-element list for each aggregated value and
        its first result is stored.
    missing_text
        The value stored for columns that the aggregation produced no result for.

    Returns
    -------
    dict[str, Any]
        A mapping from every boxhead column name to its (optionally formatted) summary value.
    """
    # Use eval_aggregate to apply the function/expression to the data; the result maps
    # column names to aggregated values.
    aggregated = eval_aggregate(data._tbl_data, fn)

    summary_row: dict[str, Any] = {}

    # Fill every column of the table, not only the ones the aggregation returned
    for col in data._boxhead._get_columns():
        if col not in aggregated:
            summary_row[col] = missing_text
            continue

        res = aggregated[col]

        if fmt is not None:
            res = fmt([res])[0]

        summary_row[col] = res

    return summary_row


# TODO: delegate to group by agg instead (group_by for summary row case)
# TODO: validates after
grand_summary_row_background_color: str | None = None, + grand_summary_row_text_transform: str | None = None, + grand_summary_row_padding: str | None = None, + grand_summary_row_padding_horizontal: str | None = None, + grand_summary_row_border_style: str | None = None, + grand_summary_row_border_width: str | None = None, + grand_summary_row_border_color: str | None = None, # footnotes_background_color: str | None = None, # footnotes_font_size: str | None = None, # footnotes_padding: str | None = None, @@ -1386,10 +1386,8 @@ def opt_stylize( # Omit keys that are not needed for the `tab_options()` method # TODO: the omitted keys are for future use when: # (1) summary rows are implemented - # (2) grand summary rows are implemented omit_keys = { "summary_row_background_color", - "grand_summary_row_background_color", } def dict_omit_keys(dict: dict[str, str], omit_keys: set[str]) -> dict[str, str]: @@ -1440,6 +1438,8 @@ class StyleMapper: data_vlines_style: str data_vlines_color: str row_striping_background_color: str + grand_summary_row_background_color: str + # summary_row_background_color: str mappings: ClassVar[dict[str, list[str]]] = { "table_hlines_color": ["table_border_top_color", "table_border_bottom_color"], @@ -1461,6 +1461,8 @@ class StyleMapper: "data_vlines_style": ["table_body_vlines_style"], "data_vlines_color": ["table_body_vlines_color"], "row_striping_background_color": ["row_striping_background_color"], + "grand_summary_row_background_color": ["grand_summary_row_background_color"], + # "summary_row_background_color": ["summary_row_background_color"], } def map_entry(self, name: str) -> dict[str, list[str]]: diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 43798e364..0d3418209 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -874,3 +874,65 @@ def _(ser: PyArrowChunkedArray, name: Optional[str] = None) -> PyArrowTable: import pyarrow as pa return pa.table({name: ser}) + + +# eval_aggregate ---- + + 
@singledispatch
def eval_aggregate(df, expr) -> dict[str, Any]:
    """Aggregate a table down to a single row and return that row as a dictionary.

    This is the generic entry point. The actual work is done by the implementations registered
    for each supported table type; this fallback only reports that the type is unsupported.

    Parameters
    ----------
    df
        The table to aggregate. Dispatch happens on the type of this argument.
    expr
        The aggregation to apply. Depending on the registered implementation this is either an
        expression evaluated against `df` or a callable that receives `df`.

    Returns
    -------
    dict[str, Any]
        A dictionary mapping column names to their aggregated values.

    Raises
    ------
    NotImplementedError
        If no implementation is registered for the type of `df`.
    """
    raise NotImplementedError(f"eval_aggregate not implemented for type: {type(df)}")


@eval_aggregate.register
def _(df: PdDataFrame, expr: Callable[[PdDataFrame], PdSeries]) -> dict[str, Any]:
    # The callable receives the frame and has to hand back a Series
    aggregated = expr(df)

    if isinstance(aggregated, PdSeries):
        return aggregated.to_dict()

    raise ValueError(f"Result must be a pandas Series. Received {type(aggregated)}")


@eval_aggregate.register
def _(df: PlDataFrame, expr: PlExpr) -> dict[str, Any]:
    selected = df.select(expr)
    n_rows = len(selected)

    # Anything other than a one-row result is rejected
    if n_rows != 1:
        raise ValueError(
            f"Expression must produce exactly 1 row (aggregation). Got {n_rows} rows."
        )

    return selected.to_dicts()[0]


@eval_aggregate.register
def _(df: PyArrowTable, expr: Callable[[PyArrowTable], PyArrowTable]) -> dict[str, Any]:
    table = expr(df)

    if not isinstance(table, PyArrowTable):
        raise ValueError(f"Result must be a PyArrow Table. Received {type(table)}")

    n_rows = table.num_rows
    if n_rows != 1:
        raise ValueError(
            f"Expression must produce exactly 1 row (aggregation). Got {n_rows} rows."
        )

    # Take the first (and only) element of every column and convert it to a Python object
    row: dict[str, Any] = {}
    for name in table.column_names:
        row[name] = table.column(name)[0].as_py()

    return row
data._stub._get_stub_layout( + has_summary_rows=has_summary_rows, options=data._options + ) # Determine the finalized number of spanner rows spanner_row_count = _get_spanners_matrix_height(data=data, omit_columns_row=True) @@ -426,27 +444,43 @@ def create_body_component_h(data: GTData) -> str: # Filter list of StyleInfo to only those that apply to the stub styles_row_group_label = [x for x in data._styles if _is_loc(x.locname, loc.LocRowGroups)] styles_row_label = [x for x in data._styles if _is_loc(x.locname, loc.LocStub)] - # styles_summary_label = [x for x in data._styles if _is_loc(x.locname, loc.LocSummaryLabel)] + # styles_summary_label = [x for x in data._styles if _is_loc(x.locname, loc.LocSummaryStub)] + styles_grand_summary_label = [ + x for x in data._styles if _is_loc(x.locname, loc.LocGrandSummaryStub) + ] # Filter list of StyleInfo to only those that apply to the body styles_cells = [x for x in data._styles if _is_loc(x.locname, loc.LocBody)] # styles_body = [x for x in data._styles if _is_loc(x.locname, loc.LocBody2)] # styles_summary = [x for x in data._styles if _is_loc(x.locname, loc.LocSummary)] + styles_grand_summary = [x for x in data._styles if _is_loc(x.locname, loc.LocGrandSummary)] # Get the default column vars column_vars = data._boxhead._get_default_columns() row_stub_var = data._boxhead._get_stub_column() - stub_layout = data._stub._get_stub_layout(options=data._options) + has_summary_rows = bool(data._summary_rows or data._summary_rows_grand) + stub_layout = data._stub._get_stub_layout( + has_summary_rows=has_summary_rows, options=data._options + ) has_row_stub_column = "rowname" in stub_layout has_group_stub_column = "group_label" in stub_layout has_groups = data._stub.group_ids is not None and len(data._stub.group_ids) > 0 # If there is a stub, then prepend that to the `column_vars` list - if row_stub_var is not None: - column_vars = [row_stub_var] + column_vars + if has_row_stub_column: + # There is already a column assigned to the 
rownames + if row_stub_var: + column_vars = [row_stub_var] + column_vars + # Else we have summary rows but no stub yet + else: + # TODO: this naming is not ideal + summary_row_stub_var = ColInfo( + "__do_not_use__", ColInfoTypeEnum.summary_placeholder, column_align="left" + ) + column_vars = [summary_row_stub_var] + column_vars # Is the stub to be striped? table_stub_striped = data._options.row_striping_include_stub.value @@ -456,6 +490,24 @@ def create_body_component_h(data: GTData) -> str: body_rows: list[str] = [] + # Add grand summary rows at top + top_g_summary_rows = data._summary_rows_grand.get_summary_rows(side="top") + for i, summary_row in enumerate(top_g_summary_rows): + row_html = _create_row_component_h( + column_vars=column_vars, + row_stub_var=row_stub_var, # Should probably include group stub? + has_row_stub_column=has_row_stub_column, # Should probably include group stub? + has_group_stub_column=has_group_stub_column, # Add this parameter + apply_stub_striping=False, # No striping for summary rows + apply_body_striping=False, # No striping for summary rows + styles_cells=styles_grand_summary, + styles_labels=styles_grand_summary_label, + row_index=i, + summary_row=summary_row, + css_class="gt_last_grand_summary_row_top" if i == len(top_g_summary_rows) - 1 else None, + ) + body_rows.append(row_html) + # iterate over rows (ordered by groupings) prev_group_info = None @@ -468,7 +520,7 @@ def create_body_component_h(data: GTData) -> str: odd_j_row = j % 2 == 1 - body_cells: list[str] = [] + leading_cell = None # Create table row or label in the stub specifically for group (if applicable) if has_groups: @@ -487,15 +539,13 @@ def create_body_component_h(data: GTData) -> str: if has_group_stub_column: rowspan_value = len(group_info.indices) - body_cells.append( - f"""