Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions grain/_src/core/traceback_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def _assert_exception_with_short_traceback(
f"Expected {expected_error_type} to be raised, but got"
f" {type(e)} instead."
)
print(f"traceback: {tb}")
self.assertLess(len(tb), 15)
self.assertLess(len(tb), 15, f"Traceback is too long: \n{tb}")


@traceback_util.run_with_traceback_filter
Expand Down
3 changes: 3 additions & 0 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ py_library(
"//grain/_src/core:exceptions",
"//grain/_src/core:monitoring",
"//grain/_src/core:sharding",
"//grain/_src/core:traceback_util",
"//grain/_src/core:transforms",
"//grain/_src/core:tree_lib",
"//grain/_src/python:checkpointing",
Expand All @@ -74,6 +75,7 @@ py_test(
":base",
":dataset",
":stats",
"//grain/_src/core:config",
"//grain/_src/core:pytest",
"//grain/_src/core:transforms",
"//grain/_src/python:options",
Expand All @@ -98,6 +100,7 @@ py_library(
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"//grain/_src/core:profiler",
"//grain/_src/core:traceback_util",
"//grain/_src/core:tree_lib",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/logging",
Expand Down
18 changes: 18 additions & 0 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

from etils import epath
from grain._src.core import monitoring as grain_monitoring
from grain._src.core import traceback_util
from grain._src.core import transforms
from grain._src.python import checkpointing
from grain._src.python import options as grain_options
Expand All @@ -65,6 +66,7 @@

from grain._src.core import monitoring

traceback_util.register_exclusion(__file__)

_api_usage_counter = monitoring.Counter(
"/grain/python/lazy_dataset/api",
Expand Down Expand Up @@ -1678,6 +1680,11 @@ def _element_spec(self) -> Any:
return get_element_spec(self._parent)


def traceback_filter_mode() -> str:
"""Returns the traceback filter mode."""
return grain_config.config.py_traceback_filtering


def is_thread_prefetch_injection_enabled() -> bool:
"""Returns whether thread prefetch injection experiment is enabled."""
return False
Expand All @@ -1700,6 +1707,17 @@ def __iter__(self) -> DatasetIterator[T]:
):
if not prefetch.is_prefetch_iterator(iterator):
iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1)

filter_mode = traceback_filter_mode()
if filter_mode != "off":
# Loaded lazily due to a circular dependency
# (dataset <-> traceback_filter).
# pylint: disable=g-import-not-at-top
from grain._src.python.dataset.transformations import traceback_filter
# pylint: enable=g-import-not-at-top
iterator = traceback_filter.TracebackFilterDatasetIterator(
iterator, traceback_filter_mode=filter_mode
)
return iterator


Expand Down
7 changes: 7 additions & 0 deletions grain/_src/python/dataset/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from grain._src.core import config as grain_config
from grain._src.core import monitoring as grain_monitoring
from grain._src.core import profiler
from grain._src.core import traceback_util
from grain._src.core import tree_lib
from grain._src.python.dataset import base
from grain._src.python.dataset import stats_utils
Expand All @@ -44,6 +45,9 @@
from grain._src.core import monitoring


traceback_util.register_exclusion(__file__)


# Registry of weak references to output dataset iterators for collecting
# execution stats.
_iter_weakref_registry = set()
Expand Down Expand Up @@ -339,6 +343,9 @@ def wrapper(iterator, *args, **kwargs):
IPL_CAT_UNKNOWN = "unknown"
# This stage is for prefetch overheads on main thread.
IPL_CAT_PREFETCH = "prefetch"
# Stage used for meta-pipeline operations unrelated to the data processing
# itself, e.g. traceback filtering.
IPL_CAT_META = "meta"


def trace_input_pipeline(stage_category: str = IPL_CAT_UNKNOWN, **trace_kwargs):
Expand Down
12 changes: 12 additions & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ filegroup(
"shuffle.py",
"slice.py",
"source.py",
"traceback_filter.py",
],
)

Expand Down Expand Up @@ -382,3 +383,14 @@ py_test(
"@pypi//numpy:pkg",
],
)

py_test(
name = "traceback_filter_test",
srcs = ["traceback_filter_test.py"],
srcs_version = "PY3",
deps = [
"//grain",
"//grain/_src/core:traceback_util",
"@abseil-py//absl/testing:absltest",
],
)
5 changes: 4 additions & 1 deletion grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,16 @@ def test_interleave_stats(self):
"PrefetchDatasetIterator",
"ThreadPrefetchDatasetIterator",
"InterleaveDatasetIterator",
"TracebackFilterDatasetIterator",
]
for expected_node in expected_nodes:
self.assertTrue(any(expected_node in name for name in node_names))
self.assertLen(node_names, len(expected_nodes))
print(summary)

@flagsaver.flagsaver(grain_py_debug_mode=True)
@flagsaver.flagsaver(
grain_py_debug_mode=True, grain_py_traceback_filtering="off"
)
def test_interleave_stats_with_mismatched_dataset_structures(self):
ds1 = dataset.MapDataset.range(10000).map(lambda x: x + 1)
ds1 = ds1.to_iter_dataset()
Expand Down
3 changes: 3 additions & 0 deletions grain/_src/python/dataset/transformations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
import threading
from typing import Any, Callable, Protocol, Sequence, TypeVar, runtime_checkable

from grain._src.core import traceback_util
from grain._src.core import transforms
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
import numpy as np


traceback_util.register_exclusion(__file__)

T = TypeVar("T") # pylint: disable=invalid-name


Expand Down
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from grain._src.core import config as grain_config
from grain._src.core import transforms
import multiprocessing as mp
from grain._src.python import options
Expand Down Expand Up @@ -76,6 +77,7 @@ class PrefetchIterDatasetTest(parameterized.TestCase):

def setUp(self):
super().setUp()
grain_config.config.update('py_traceback_filtering', 'off')
self.range_ds = dataset.MapDataset.range(20)
self.filtered_range_ds = self.range_ds.filter(
FilterKeepingOddElementsOnly()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ class MultiprocessingPrefetchTest(parameterized.TestCase):

def setUp(self):
super().setUp()
config.config.update('py_traceback_filtering', 'off')
ds = dataset.MapDataset.range(20)
self.iter_ds = ds.to_iter_dataset().filter(FilterKeepingOddElementsOnly())

Expand Down
59 changes: 59 additions & 0 deletions grain/_src/python/dataset/transformations/traceback_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Iterator that filters the stacktrace of thrown errors."""

from __future__ import annotations

from typing import Any, TypeVar

from grain._src.core import traceback_util
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats


traceback_util.register_exclusion(__file__)


T = TypeVar("T")


class TracebackFilterDatasetIterator(dataset.DatasetIterator[T]):
"""Filters internal stack frames from the stacktrace of thrown errors."""

_MUTATES_ELEMENT_SPEC = False

def __init__(
self, parent: dataset.DatasetIterator[T], traceback_filter_mode: str
):
super().__init__(parent)
self._traceback_filter_mode = traceback_filter_mode

@traceback_util.run_with_traceback_filter
@dataset_stats.record_next_duration_if_output
@dataset_stats.trace_input_pipeline_next(
stage_category=dataset_stats.IPL_CAT_META
)
def __next__(self) -> T:
element = next(self._parent)
with self._stats.record_self_time():
return self._stats.record_output_spec(element)

def get_state(self) -> dict[str, Any]:
return self._parent.get_state()

def set_state(self, state: dict[str, Any]):
self._parent.set_state(state)

def __str__(self) -> str:
return f"TracebackFilterDatasetIterator(traceback_filter_mode={self._traceback_filter_mode})"
99 changes: 99 additions & 0 deletions grain/_src/python/dataset/transformations/traceback_filter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import traceback
from typing import Callable

from grain._src.core import traceback_util
import grain.python as pygrain

from absl.testing import absltest


traceback_util.register_exclusion(__file__)


def _assert_exception_with_short_traceback(
self: absltest.TestCase,
fn: Callable[[], None],
expected_error_type: type[BaseException],
) -> None:
"""Asserts that a function raises a specific exception with a short traceback.

This function executes `fn` and asserts that it raises an exception of type
`expected_error_type`. Additionally, it checks that the length of the
traceback associated with the caught exception is less than 15 frames,
ensuring that the traceback has been shortened.

Args:
self: The absltest.TestCase instance.
fn: The function to execute, expected to raise an exception.
expected_error_type: The expected type of the exception to be raised.
"""
# Assert that an exception is raised and the length of the traceback is
# sufficiently short, i.e. less than 15 frames.
# We cannot use assertRaises because __traceback__ is cleared before we can
# inspect it.
try:
fn()
self.fail(f"Expected {expected_error_type} to be raised.")
except expected_error_type as e:
tb = traceback.extract_tb(e.__traceback__)
except Exception as e: # pylint: disable=broad-except
self.fail(
f"Expected {expected_error_type} to be raised, but got"
f" {type(e)} instead."
)
self.assertLess(len(tb), 15, f"Traceback is too long: \n{tb}")


class AddOneTransform(pygrain.MapTransform):

def map(self, x: int) -> int:
return x + 1


class RaiseErrorTransform(pygrain.MapTransform):

def map(self, x: int) -> int:
raise ValueError("Boom!")


class TracebackFilterTest(absltest.TestCase):

def test_datasource_multiple_transforms_filters_traceback(self):
range_ds = pygrain.RangeDataSource(0, 10, 1)
sampler = pygrain.IndexSampler(num_records=10, seed=42)
ops = [RaiseErrorTransform()]
for _ in range(100):
ops.append(AddOneTransform())
data_loader = pygrain.DataLoader(
data_source=range_ds, sampler=sampler, operations=ops
)
_assert_exception_with_short_traceback(
self, lambda: next(iter(data_loader)), ValueError
)

def test_dataset_multiple_transforms_filters_traceback(self):
range_ds = pygrain.MapDataset.range(0, 10)
range_ds = range_ds.map(RaiseErrorTransform())
for _ in range(100):
range_ds = range_ds.map(AddOneTransform())
_assert_exception_with_short_traceback(
self, lambda: next(iter(range_ds)), ValueError
)


if __name__ == "__main__":
absltest.main()
Loading