Skip to content

Commit 2a822f2

Browse files
committed
Preemptive fixes for upcoming Beam PR 10717, which will enable Python 3 annotations support by default.
PiperOrigin-RevId: 299228970
1 parent 26a198b commit 2a822f2

File tree

4 files changed

+11
-12
lines changed

4 files changed

+11
-12
lines changed

tensorflow_data_validation/api/stats_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
import apache_beam as beam
4949
import pyarrow as pa
5050
from tensorflow_data_validation import constants
51-
from tensorflow_data_validation import types
5251
from tensorflow_data_validation.statistics import stats_impl
5352
from tensorflow_data_validation.statistics import stats_options
5453
from typing import Generator
@@ -122,8 +121,8 @@ def expand(self, dataset: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
122121
stats_impl.GenerateStatisticsImpl(self._options))
123122

124123

125-
def _sample_at_rate(example: types.Example, sample_rate: float
126-
) -> Generator[types.Example, None, None]:
124+
def _sample_at_rate(example: pa.Table, sample_rate: float
125+
) -> Generator[pa.Table, None, None]:
127126
"""Sample examples at input sampling rate."""
128127
# TODO(pachristopher): Revisit this to decide if we need to fix a seed
129128
# or add an optional seed argument.

tensorflow_data_validation/statistics/generators/lift_stats_generator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _get_example_value_presence(
185185
def _to_partial_copresence_counts(
186186
sliced_table: types.SlicedTable, y_path: types.FeaturePath,
187187
x_paths: Iterable[types.FeaturePath],
188-
y_boundaries: Optional[Iterable[float]], weight_column_name: Optional[Text]
188+
y_boundaries: Optional[np.ndarray], weight_column_name: Optional[Text]
189189
) -> Iterator[Tuple[_SlicedXYKey, _CountType]]:
190190
"""Yields per-(slice, path_x, x, y) counts of examples with x and y.
191191
@@ -234,7 +234,7 @@ def _to_partial_copresence_counts(
234234

235235
def _to_partial_counts(
236236
sliced_table: types.SlicedTable, path: types.FeaturePath,
237-
boundaries: Optional[Iterable[float]], weight_column_name: Optional[Text]
237+
boundaries: Optional[np.ndarray], weight_column_name: Optional[Text]
238238
) -> Iterator[Tuple[Tuple[types.SliceKey, Union[_XType, _YType]], _CountType]]:
239239
"""Yields per-(slice, value) counts of the examples with value in path."""
240240
slice_key, table = sliced_table
@@ -274,9 +274,9 @@ def _get_unicode_value(value: Union[Text, bytes], path: types.FeaturePath):
274274

275275

276276
def _make_dataset_feature_stats_proto(
277-
lifts: Tuple[_SlicedFeatureKey, _LiftSeries], y_path: types.FeaturePath,
278-
y_boundaries: Optional[np.ndarray], weighted_examples: bool,
279-
output_custom_stats: bool
277+
lifts: Tuple[_SlicedFeatureKey, Iterable[_LiftSeries]],
278+
y_path: types.FeaturePath, y_boundaries: Optional[np.ndarray],
279+
weighted_examples: bool, output_custom_stats: bool
280280
) -> Tuple[types.SliceKey, statistics_pb2.DatasetFeatureStatistics]:
281281
"""Generates DatasetFeatureStatistics proto for a given x_path, y_path pair.
282282

tensorflow_data_validation/statistics/generators/top_k_uniques_stats_generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def _make_dataset_feature_stats_proto_with_topk_for_single_feature(
155155
FeaturePathTuple],
156156
List[FeatureValueCount]],
157157
categorical_features: Set[types.FeaturePath], is_weighted_stats: bool,
158-
num_top_values: int, frequency_threshold: float,
158+
num_top_values: int, frequency_threshold: Union[int, float],
159159
num_rank_histogram_buckets: int
160160
) -> Tuple[types.SliceKey, statistics_pb2.DatasetFeatureStatistics]:
161161
"""Makes a DatasetFeatureStatistics proto with top-k stats for a feature."""
@@ -198,12 +198,12 @@ def _weighted_unique(values: np.ndarray, weights: np.ndarray
198198

199199

200200
def _to_topk_tuples(
201-
sliced_table: Tuple[Text, pa.Table],
201+
sliced_table: Tuple[types.SliceKey, pa.Table],
202202
bytes_features: FrozenSet[types.FeaturePath],
203203
categorical_features: FrozenSet[types.FeaturePath],
204204
weight_feature: Optional[Text]
205205
) -> Iterable[
206-
Tuple[Tuple[Text, FeaturePathTuple, Any],
206+
Tuple[Tuple[types.SliceKey, FeaturePathTuple, Any],
207207
Union[int, Tuple[int, Union[int, float]]]]]:
208208
"""Generates tuples for computing top-k and uniques from input tables."""
209209
slice_key, table = sliced_table

tensorflow_data_validation/statistics/stats_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(
118118
"""
119119
self._options = options
120120
self._is_slicing_enabled = (
121-
is_slicing_enabled or self._options.slice_functions)
121+
is_slicing_enabled or bool(self._options.slice_functions))
122122

123123
def expand(self, dataset: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
124124
# Handles generators by their type:

0 commit comments

Comments
 (0)