Skip to content

Commit 1b9acba

Browse files
paulgc17 authored and tf-data-validation-team committed
Avoid using multiprocessing by default when generating statistics over a
dataframe. PiperOrigin-RevId: 236177889
1 parent 21da233 commit 1b9acba

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

RELEASE.md

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,11 +9,12 @@
99

1010
## Bug Fixes and Other Changes
1111
* Expand unit test coverage.
12-
1312
* Modify validation logic to raise `SCHEMA_MISSING_COLUMN` anomaly when
1413
observing a feature with no stats.
1514
* Add utility functions `write_stats_text` and `load_stats_text` to write and
1615
load DatasetFeatureStatisticsList protos.
16+
* Avoid using multiprocessing by default when generating statistics over a
17+
dataframe.
1718
* Depends on `joblib>=0.12,<1`.
1819
* Requires pre-installed `tensorflow>=1.13,<2`.
1920

tensorflow_data_validation/utils/stats_gen_lib.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -173,7 +173,7 @@ def generate_statistics_from_csv(
173173
def generate_statistics_from_dataframe(
174174
dataframe,
175175
stats_options = options.StatsOptions(),
176-
n_jobs = multiprocessing.cpu_count()
176+
n_jobs = 1
177177
):
178178
"""Compute data statistics for the input pandas DataFrame.
179179
@@ -183,7 +183,8 @@ def generate_statistics_from_dataframe(
183183
Args:
184184
dataframe: Input pandas DataFrame.
185185
stats_options: Options for generating data statistics.
186-
n_jobs: Number of processes to run.
186+
n_jobs: Number of processes to run (defaults to 1). If -1 is provided,
187+
uses the same number of processes as the number of CPU cores.
187188
188189
Returns:
189190
A DatasetFeatureStatisticsList proto.
@@ -193,7 +194,14 @@ def generate_statistics_from_dataframe(
193194
'pandas DataFrame.'.format(type(dataframe).__name__))
194195

195196
stats_generators = stats_impl.get_generators(stats_options, in_memory=True)
197+
if n_jobs < -1 or n_jobs == 0:
198+
raise ValueError('Invalid n_jobs parameter {}. Should be either '
199+
' -1 or >= 1.'.format(n_jobs))
200+
201+
if n_jobs == -1:
202+
n_jobs = multiprocessing.cpu_count()
196203
n_jobs = max(min(n_jobs, multiprocessing.cpu_count()), 1)
204+
197205
if n_jobs == 1:
198206
merged_partial_stats = _generate_partial_statistics_from_df(
199207
dataframe, stats_options, stats_generators)

tensorflow_data_validation/utils/stats_gen_lib_test.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -565,6 +565,28 @@ def test_stats_gen_with_dataframe(self):
565565
test_util.assert_dataset_feature_stats_proto_equal(
566566
self, result.datasets[0], expected_result.datasets[0])
567567

568+
def test_stats_gen_with_dataframe_invalid_njobs_zero(self):
569+
records, _, _ = self._get_csv_test(delimiter=',', with_header=True)
570+
input_data_path = self._write_records_to_csv(records, self._get_temp_dir(),
571+
'input_data.csv')
572+
dataframe = pd.read_csv(input_data_path)
573+
with self.assertRaisesRegexp(
574+
ValueError, 'Invalid n_jobs parameter.*'):
575+
_ = stats_gen_lib.generate_statistics_from_dataframe(
576+
dataframe=dataframe,
577+
stats_options=self._default_stats_options, n_jobs=0)
578+
579+
def test_stats_gen_with_dataframe_invalid_njobs_negative(self):
580+
records, _, _ = self._get_csv_test(delimiter=',', with_header=True)
581+
input_data_path = self._write_records_to_csv(records, self._get_temp_dir(),
582+
'input_data.csv')
583+
dataframe = pd.read_csv(input_data_path)
584+
with self.assertRaisesRegexp(
585+
ValueError, 'Invalid n_jobs parameter.*'):
586+
_ = stats_gen_lib.generate_statistics_from_dataframe(
587+
dataframe=dataframe,
588+
stats_options=self._default_stats_options, n_jobs=-2)
589+
568590

569591
if __name__ == '__main__':
570592
absltest.main()

0 commit comments

Comments (0)