diff --git a/src/textual/widgets/_text_area.py b/src/textual/widgets/_text_area.py index 7f01093598..be38de37ee 100644 --- a/src/textual/widgets/_text_area.py +++ b/src/textual/widgets/_text_area.py @@ -81,48 +81,60 @@ class HighlightMap: BLOCK_SIZE = 50 - def __init__(self, text_area_widget: widgets.TextArea): - self.text_area_widget: widgets.TextArea = text_area_widget - self.uncovered_lines: dict[int, range] = {} + def __init__(self, text_area: TextArea): + self.text_area: TextArea = text_area + """The text area associated with this highlight map.""" - # A mapping from line index to a list of Highlight instances. - self._highlights: LineToHighlightsMap = defaultdict(list) - self.reset() + self._highlighted_blocks: set[int] = set() + """The set of blocks that have been highlighted, identified by the start line index of the block. + (0 represents the first block, 50 the second, 100 the third, etc. - assuming a block size of 50) + """ + + self._highlights: dict[int, list[Highlight]] = defaultdict(list) + """A mapping from line index to a list of Highlight instances.""" def reset(self) -> None: """Reset so that future lookups rebuild the highlight map.""" self._highlights.clear() - line_count = self.document.line_count - uncovered_lines = self.uncovered_lines - uncovered_lines.clear() - i = end_range = 0 - for i in range(0, line_count, self.BLOCK_SIZE): - end_range = min(i + self.BLOCK_SIZE, line_count) - line_range = range(i, end_range) - uncovered_lines.update({j: line_range for j in line_range}) - if end_range < line_count: - line_range = range(i, line_count) - uncovered_lines.update({j: line_range for j in line_range}) + self._highlighted_blocks.clear() @property def document(self) -> DocumentBase: """The text document being highlighted.""" - return self.text_area_widget.document + return self.text_area.document + + def __getitem__(self, index: int) -> list[Highlight]: + start, end = self._get_block_boundaries(index, self.BLOCK_SIZE) + if start not in self._highlighted_blocks: + self._highlighted_blocks.add(start) + self._build_part_of_highlight_map(range(start, end)) + return self._highlights[index] - def __getitem__(self, idx: int) -> list[text_area.Highlight]: - if idx in self.uncovered_lines: - self._build_part_of_highlight_map(self.uncovered_lines[idx]) - return self._highlights[idx] + def _get_block_boundaries(self, index: int, block_size: int) -> tuple[int, int]: + """Get the start and end of the block for the given index. + + The start is inclusive and the end is exclusive. + + Args: + index: The line index to get we want to know the block range for.. + block_size: The size of the bucket. + + Returns: + A tuple containing the start and end of the block. + """ + block_index = index // block_size + start = block_index * block_size + end = (block_index + 1) * block_size + return (start, end) def _build_part_of_highlight_map(self, line_range: range) -> None: """Build part of the highlight map.""" highlights = self._highlights - for line_index in line_range: - self.uncovered_lines.pop(line_index) + start_point = (line_range[0], 0) end_point = (line_range[-1] + 1, 0) captures = self.document.query_syntax_tree( - self.text_area_widget._highlight_query, + self.text_area._highlight_query, start_point=start_point, end_point=end_point, ) @@ -140,8 +152,9 @@ def _build_part_of_highlight_map(self, line_range: range) -> None: ) # Add the middle lines - entire row of this node is highlighted + middle_highlight = (0, None, highlight_name) for node_row in range(node_start_row + 1, node_end_row): - highlights[node_row].append((0, None, highlight_name)) + highlights[node_row].append(middle_highlight) # Add the last line of the node range highlights[node_end_row].append( @@ -157,16 +170,16 @@ def _build_part_of_highlight_map(self, line_range: range) -> None: # to be sorted in ascending order of ``a``. When two highlights have the same # value of ``a`` then the one with the larger a--b range comes first, with ``None`` # being considered larger than any number. - def sort_key(hl) -> tuple[int, int, int]: - a, b, _ = hl - max_range_ind = 1 + def sort_key(highlight: Highlight) -> tuple[int, int, int]: + a, b, _ = highlight + max_range_index = 1 if b is None: - max_range_ind = 0 + max_range_index = 0 b = a - return a, max_range_ind, a - b + return a, max_range_index, a - b for line_index in line_range: - line_highlights = highlights.get(line_index, []).sort(key=sort_key) + highlights.get(line_index, []).sort(key=sort_key) @dataclass