Skip to content

Commit 7c0f397

Browse files
Alfonso CastañoThe TensorFlow Datasets Authors
authored andcommitted
Base HF dataset name on url of the croissant file
PiperOrigin-RevId: 640134142
1 parent 27e89c3 commit 7c0f397

File tree

4 files changed

+48
-6
lines changed

4 files changed

+48
-6
lines changed

tensorflow_datasets/core/dataset_builders/croissant_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ def __init__(
173173
if mapping is None:
174174
mapping = {}
175175
self.dataset = mlc.Dataset(jsonld, mapping=mapping)
176-
self.name = huggingface_utils.convert_hf_name(self.dataset.metadata.name)
176+
self.name = huggingface_utils.get_tfds_name_from_croissant_dataset(
177+
self.dataset
178+
)
177179
self.metadata = self.dataset.metadata
178180

179181
# In TFDS, version is a mandatory attribute, while in Croissant it is only a

tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,9 @@ def __init__(
203203
self._hf_repo_id = hf_repo_id
204204
self._hf_config = hf_config
205205
self.config_kwargs = config_kwargs
206-
tfds_config = huggingface_utils.convert_hf_name(hf_config)
206+
tfds_config = (
207+
huggingface_utils.convert_hf_name(hf_config) if hf_config else None
208+
)
207209
try:
208210
self._hf_builder = hf_datasets.load_dataset_builder(
209211
self._hf_repo_id, self._hf_config, **self.config_kwargs

tensorflow_datasets/core/utils/huggingface_utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515

1616
"""Utility functions for huggingface_dataset_builder."""
1717

18+
from __future__ import annotations
19+
1820
from collections.abc import Mapping, Sequence
1921
import datetime
22+
import typing
2023
from typing import Any, Type, TypeVar
2124

2225
from etils import epath
@@ -30,6 +33,10 @@
3033
from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets
3134
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
3235

36+
if typing.TYPE_CHECKING:
37+
# pylint: disable=g-bad-import-order
38+
import mlcroissant as mlc
39+
3340

3441
_HF_DTYPE_TO_NP_DTYPE = immutabledict.immutabledict({
3542
'bool': np.bool_,
@@ -242,7 +249,17 @@ def convert_hf_value(
242249
)
243250

244251

245-
def convert_hf_name(hf_name: _StrOrNone) -> _StrOrNone:
252+
def get_tfds_name_from_croissant_dataset(dataset: mlc.Dataset) -> str:
253+
"""Returns TFDS compatible dataset name of the given MLcroissant dataset."""
254+
if (url := dataset.metadata.url) and url.startswith(
255+
'https://huggingface.co/datasets/'
256+
):
257+
url_suffix = url.removeprefix('https://huggingface.co/datasets/')
258+
return convert_hf_name(url_suffix)
259+
return convert_hf_name(dataset.metadata.name)
260+
261+
262+
def convert_hf_name(hf_name: str) -> str:
246263
"""Converts Huggingface name to a TFDS compatible dataset name.
247264
248265
Huggingface names can contain characters that are not supported in
@@ -259,8 +276,6 @@ def convert_hf_name(hf_name: _StrOrNone) -> _StrOrNone:
259276
The TFDS compatible dataset name (dataset names, config names and split
260277
names).
261278
"""
262-
if hf_name is None:
263-
return hf_name
264279
hf_name = hf_name.lower().replace('/', '__')
265280
return py_utils.make_valid_name(hf_name)
266281

tensorflow_datasets/core/utils/huggingface_utils_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorflow_datasets.core import lazy_imports_lib
2323
from tensorflow_datasets.core import registered
2424
from tensorflow_datasets.core.utils import huggingface_utils
25+
from tensorflow_datasets.core.utils.lazy_imports_utils import mlcroissant as mlc
2526
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
2627

2728

@@ -216,6 +217,29 @@ def test_convert_value(hf_value, feature, expected_value):
216217
assert huggingface_utils.convert_hf_value(hf_value, feature) == expected_value
217218

218219

220+
@pytest.mark.parametrize(
221+
'croissant_name,croissant_url,tfds_name',
222+
[
223+
(
224+
'Name+1',
225+
'https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k',
226+
'huggingfaceh4__ultrachat_200k',
227+
),
228+
('Name+1', 'bad_url', 'name_1'),
229+
('Name+1', None, 'name_1'),
230+
],
231+
)
232+
def test_get_tfds_name_from_croissant_dataset(
233+
croissant_name, croissant_url, tfds_name
234+
):
235+
metadata = mlc.Metadata(name=croissant_name, url=croissant_url)
236+
dataset = mlc.Dataset.from_metadata(metadata)
237+
assert (
238+
huggingface_utils.get_tfds_name_from_croissant_dataset(dataset)
239+
== tfds_name
240+
)
241+
242+
219243
@pytest.mark.parametrize(
220244
'hf_name,tfds_name',
221245
[
@@ -228,7 +252,6 @@ def test_convert_value(hf_value, feature, expected_value):
228252
# Config and split names
229253
('x.y', 'x_y'),
230254
('x_v1.0', 'x_v1_0'),
231-
(None, None),
232255
],
233256
)
234257
def test_from_hf_to_tfds(hf_name, tfds_name):

0 commit comments

Comments
 (0)