diff --git a/python/ray/data/datasource/parquet_meta_provider.py b/python/ray/data/datasource/parquet_meta_provider.py
index c8484574da18..73a3c41ef6e2 100644
--- a/python/ray/data/datasource/parquet_meta_provider.py
+++ b/python/ray/data/datasource/parquet_meta_provider.py
@@ -41,10 +41,10 @@ def __init__(self, fragment_metadata: "pyarrow.parquet.FileMetaData"):
         self.num_row_groups = fragment_metadata.num_row_groups
         self.num_rows = fragment_metadata.num_rows
         self.serialized_size = fragment_metadata.serialized_size
-        # This is a pickled schema object, to be set later with
-        # `self.set_schema_pickled()`. To get the underlying schema, use
-        # `cloudpickle.loads(self.schema_pickled)`.
-        self.schema_pickled = None
+
+        # Serialize the schema directly in the constructor
+        schema_ser = cloudpickle.dumps(fragment_metadata.schema.to_arrow_schema())
+        self.schema_pickled = schema_ser
 
         # Calculate the total byte size of the file fragment using the original
         # object, as it is not possible to access row groups from this class.
@@ -150,10 +150,16 @@ def fetch_func(fragments):
                     **ray_remote_args,
                 )
             )
+
+            return _dedupe_schemas(raw_metadata)
+        else:
+            # We don't deduplicate schemas in this branch because they're already
+            # deduplicated in `_fetch_metadata`. See
+            # https://github.com/ray-project/ray/pull/54821/files#r2265140929 for
+            # related discussion.
             raw_metadata = _fetch_metadata(fragments)
-
-        return _dedupe_metadata(raw_metadata)
+            return raw_metadata
 
 
 def _fetch_metadata_serialization_wrapper(
@@ -161,7 +167,7 @@ def _fetch_metadata_serialization_wrapper(
     retry_match: Optional[List[str]],
     retry_max_attempts: int,
     retry_max_interval: int,
-) -> List["pyarrow.parquet.FileMetaData"]:
+) -> List["_ParquetFileFragmentMetaData"]:
     from ray.data._internal.datasource.parquet_datasource import (
         _deserialize_fragments_with_retry,
     )
@@ -209,39 +215,53 @@ def _fetch_metadata_serialization_wrapper(
 
 def _fetch_metadata(
     fragments: List["pyarrow.dataset.ParquetFileFragment"],
-) -> List["pyarrow.parquet.FileMetaData"]:
-    fragment_metadata = []
+) -> List[_ParquetFileFragmentMetaData]:
+    fragment_metadatas = []
     for f in fragments:
         try:
-            fragment_metadata.append(f.metadata)
+            # Convert directly to _ParquetFileFragmentMetaData
+            fragment_metadatas.append(_ParquetFileFragmentMetaData(f.metadata))
         except AttributeError:
             break
-    return fragment_metadata
+    # Deduplicate schemas to reduce memory usage
+    return _dedupe_schemas(fragment_metadatas)
 
 
-def _dedupe_metadata(
-    raw_metadatas: List["pyarrow.parquet.FileMetaData"],
+def _dedupe_schemas(
+    metadatas: List[_ParquetFileFragmentMetaData],
 ) -> List[_ParquetFileFragmentMetaData]:
-    """For datasets with a large number of columns, the FileMetaData
-    (in particular the schema) can be very large. We can reduce the
-    memory usage by only keeping unique schema objects across all
-    file fragments. This method deduplicates the schemas and returns
-    a list of `_ParquetFileFragmentMetaData` objects."""
-    schema_to_id = {}  # schema_id -> serialized_schema
-    id_to_schema = {}  # serialized_schema -> schema_id
-    stripped_metadatas = []
-    for fragment_metadata in raw_metadatas:
-        stripped_md = _ParquetFileFragmentMetaData(fragment_metadata)
+    """Deduplicates schema objects across existing _ParquetFileFragmentMetaData objects.
+
+    For datasets with a large number of columns, the pickled schema can be very large.
+    This function reduces memory usage by ensuring that identical schemas across multiple
+    fragment metadata objects reference the same underlying pickled schema object,
+    rather than each fragment maintaining its own copy.
+
+    Args:
+        metadatas: List of _ParquetFileFragmentMetaData objects that already have
+            pickled schemas set.
+
+    Returns:
+        The same list of _ParquetFileFragmentMetaData objects, but with duplicate
+        schemas deduplicated to reference the same object in memory.
+    """
+    schema_to_id = {}  # schema_ser -> schema_id
+    id_to_schema = {}  # schema_id -> schema_ser
+
+    for metadata in metadatas:
+        # Get the current schema serialization
+        schema_ser = metadata.schema_pickled
 
-        schema_ser = cloudpickle.dumps(fragment_metadata.schema.to_arrow_schema())
         if schema_ser not in schema_to_id:
+            # This is a new unique schema
            schema_id = len(schema_to_id)
             schema_to_id[schema_ser] = schema_id
             id_to_schema[schema_id] = schema_ser
-            stripped_md.set_schema_pickled(schema_ser)
+            # No need to set schema_pickled - it already has the correct value
         else:
-            schema_id = schema_to_id.get(schema_ser)
+            # This schema already exists, reuse the existing one
+            schema_id = schema_to_id[schema_ser]
             existing_schema_ser = id_to_schema[schema_id]
-            stripped_md.set_schema_pickled(existing_schema_ser)
-        stripped_metadatas.append(stripped_md)
-    return stripped_metadatas
+            metadata.set_schema_pickled(existing_schema_ser)
+
+    return metadatas
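For illustration, the snippet below is a minimal, self-contained sketch of the schema-interning idea that `_dedupe_schemas` implements: fragments whose schemas pickle to identical bytes end up sharing a single pickled-schema object instead of each holding its own copy. It is not Ray's implementation; `FragmentMeta` and `dedupe_schemas` are hypothetical stand-ins for `_ParquetFileFragmentMetaData` and `_dedupe_schemas`, and it uses the standard `pickle` module where Ray uses its vendored `cloudpickle`.

```python
# Minimal sketch of the schema-interning idea (not Ray's implementation).
# `FragmentMeta` and `dedupe_schemas` are hypothetical stand-ins for
# `_ParquetFileFragmentMetaData` and `_dedupe_schemas`.
import pickle
from typing import Dict, List

import pyarrow as pa


class FragmentMeta:
    def __init__(self, schema: pa.Schema):
        # Pickle the schema eagerly, mirroring the new constructor behavior
        # of `_ParquetFileFragmentMetaData.__init__` in this diff.
        self.schema_pickled: bytes = pickle.dumps(schema)

    def set_schema_pickled(self, schema_pickled: bytes) -> None:
        self.schema_pickled = schema_pickled


def dedupe_schemas(metas: List[FragmentMeta]) -> List[FragmentMeta]:
    # Keep one canonical bytes object per distinct serialized schema and point
    # every fragment with an identical schema at that shared object.
    canonical: Dict[bytes, bytes] = {}
    for meta in metas:
        ser = meta.schema_pickled
        if ser not in canonical:
            canonical[ser] = ser  # first occurrence becomes the canonical copy
        else:
            meta.set_schema_pickled(canonical[ser])  # reuse the existing bytes
    return metas


if __name__ == "__main__":
    schema = pa.schema([("a", pa.int64()), ("b", pa.string())])
    metas = dedupe_schemas([FragmentMeta(schema) for _ in range(3)])
    # After deduplication, every fragment references the same pickled schema.
    assert all(m.schema_pickled is metas[0].schema_pickled for m in metas)
```

Note that the deduplication keys on the serialized bytes, so it only collapses fragments whose schemas pickle to identical bytes; schemas that differ in any field keep their own pickled copy.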