diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index c5d8bcd03fc..7f612e02bc0 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -1,7 +1,7 @@ import io import itertools from dataclasses import dataclass -from typing import Optional +from typing import Optional, List import pandas as pd import pyarrow as pa @@ -50,6 +50,7 @@ class JsonConfig(datasets.BuilderConfig): block_size: Optional[int] = None # deprecated chunksize: int = 10 << 20 # 10MB newlines_in_values: Optional[bool] = None + columns: Optional[List[str]] = None def __post_init__(self): super().__post_init__() @@ -121,6 +122,12 @@ def _generate_tables(self, files): if df.columns.tolist() == [0]: df.columns = list(self.config.features) if self.config.features else ["text"] pa_table = pa.Table.from_pandas(df, preserve_index=False) + if self.config.columns is not None: + missing_cols = [col for col in self.config.columns if col not in pa_table.column_names] + if missing_cols: + for col in missing_cols: + pa_table = pa_table.append_column(col, pa.array([None] * len(pa_table))) + pa_table = pa_table.select(self.config.columns) yield file_idx, self._cast_table(pa_table) # If the file has one json object per line @@ -186,7 +193,21 @@ def _generate_tables(self, files): raise ValueError( f"Failed to convert pandas DataFrame to Arrow Table from file {file}." ) from None + # Column filtering (pandas fallback) + if self.config.columns is not None: + missing_cols = [col for col in self.config.columns if col not in pa_table.column_names] + if missing_cols: + for col in missing_cols: + pa_table = pa_table.append_column(col, pa.array([None] * len(pa_table))) + pa_table = pa_table.select(self.config.columns) yield file_idx, self._cast_table(pa_table) break + # Column filtering (Arrow JSON path) + if self.config.columns is not None: + missing_cols = [col for col in self.config.columns if col not in pa_table.column_names] + if missing_cols: + for col in missing_cols: + pa_table = pa_table.append_column(col, pa.array([None] * len(pa_table))) + pa_table = pa_table.select(self.config.columns) yield (file_idx, batch_idx), self._cast_table(pa_table) batch_idx += 1 diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py index 18f066b5e68..15a1c8df90a 100644 --- a/tests/packaged_modules/test_json.py +++ b/tests/packaged_modules/test_json.py @@ -265,3 +265,30 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, r generator = builder._generate_tables([[request.getfixturevalue(file_fixture)]]) pa_table = pa.concat_tables([table for _, table in generator]) assert pa_table.column_names == ["ID", "Language", "Topic"] + + +def test_json_generate_tables_with_columns_subset(jsonl_file): + # Keep only col_1 + builder = Json(columns=["col_1"]) + generator = builder._generate_tables([[jsonl_file]]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa_table.column_names == ["col_1"] + assert pa_table.to_pydict() == {"col_1": [-1, 1, 10]} + + +def test_json_generate_tables_with_columns_and_missing(jsonl_file): + # Ask for col_1 and a non-existent column -> should be filled with None + builder = Json(columns=["col_1", "missing_col"]) + generator = builder._generate_tables([[jsonl_file]]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa_table.column_names == ["col_1", "missing_col"] + assert pa_table.to_pydict() == {"col_1": [-1, 1, 10], "missing_col": [None, None, None]} + + +def test_json_generate_tables_with_columns_on_list_of_strings(json_file_with_list_of_strings): + # list-of-strings becomes a single "text" column; ensure selection works + builder = Json(columns=["text"]) + generator = builder._generate_tables([[json_file_with_list_of_strings]]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa_table.column_names == ["text"] + assert pa_table.to_pydict() == {"text": ["First text.", "Second text.", "Third text."]}