-
Notifications
You must be signed in to change notification settings - Fork 51
feat: Seed dataset plugins #191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cc145c7
0fb362f
870e571
831415f
107b558
10e0ed5
a578283
c77d11c
2f30f4c
9ec794b
10e8055
6e1ec65
87f2480
c03d1e5
8ac66fd
64d0171
0e9b456
f81bcfd
f2913cc
f96f60e
1f4caf7
09bf334
7f73385
e1c842c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,3 +91,5 @@ docs/notebooks/ | |
| docs/notebook_source/*.ipynb | ||
| docs/notebook_source/*.csv | ||
| docs/**/artifacts/ | ||
|
|
||
| e2e_tests/uv.lock | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| [project] | ||
| name = "data-designer-e2e-tests" | ||
| version = "0.0.1" | ||
| requires-python = ">=3.10" | ||
|
|
||
| dependencies = [ | ||
| "data-designer", | ||
| ] | ||
|
|
||
| [tool.uv.sources] | ||
| data-designer = { path = "../" } | ||
|
|
||
| [dependency-groups] | ||
| dev = [ | ||
| "pytest>=8.3.3,<9", | ||
| ] | ||
|
|
||
| # Registers the two demo plugins in the "data_designer.plugins" entry-point group. Each value is | ||
| # "module.path:attribute" and points at a module-level Plugin object in this test package. | ||
| [project.entry-points."data_designer.plugins"] | ||
| demo-column-generator = "data_designer_e2e_tests.plugins.column_generator.plugin:column_generator_plugin" | ||
| demo-seed-reader = "data_designer_e2e_tests.plugins.seed_reader.plugin:seed_reader_plugin" | ||
|
|
||
| [tool.pytest.ini_options] | ||
| testpaths = ["tests"] | ||
| # NOTE(review): `env` is not a built-in pytest ini option; it is presumably read by the pytest-env | ||
| # plugin, which the dev group above does not list. Confirm it is installed, or this is ignored. | ||
| env = [ | ||
| # ensure plugins are enabled | ||
| "DISABLE_DATA_DESIGNER_PLUGINS=false", | ||
| ] | ||
|
|
||
| [tool.uv] | ||
| package = true | ||
| required-version = ">=0.7.10" | ||
|
|
||
| [build-system] | ||
| requires = ["hatchling"] | ||
| build-backend = "hatchling.build" | ||
|
|
||
| [tool.hatch.build.targets.wheel] | ||
| packages = ["src/data_designer_e2e_tests"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from typing import Literal | ||
|
|
||
| from data_designer.config.column_configs import SingleColumnConfig | ||
|
|
||
|
|
||
| # Config model for the demo column-generator plugin used by the e2e tests. | ||
| class DemoColumnGeneratorConfig(SingleColumnConfig): | ||
| # Literal tag for this column type; the same string is used as the plugin's entry-point name. | ||
| column_type: Literal["demo-column-generator"] = "demo-column-generator" | ||
|
|
||
| # Required string; the demo generator upper-cases it to produce the column value. | ||
| text: str |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import pandas as pd | ||
|
|
||
| from data_designer.engine.column_generators.generators.base import ( | ||
| ColumnGenerator, | ||
| GenerationStrategy, | ||
| GeneratorMetadata, | ||
| ) | ||
| from data_designer_e2e_tests.plugins.column_generator.config import DemoColumnGeneratorConfig | ||
|
|
||
|
|
||
| # Demo generator: fills the configured column with the config's text, upper-cased. | ||
| class DemoColumnGeneratorImpl(ColumnGenerator[DemoColumnGeneratorConfig]): | ||
| @staticmethod | ||
| def metadata() -> GeneratorMetadata: | ||
| # Describes this generator; it declares the FULL_COLUMN generation strategy. | ||
| return GeneratorMetadata( | ||
| name="demo-column-generator", | ||
| description="Shouts at you", | ||
| generation_strategy=GenerationStrategy.FULL_COLUMN, | ||
| ) | ||
|
|
||
||
| def generate(self, data: pd.DataFrame) -> pd.DataFrame: | ||
| # Assigns a single scalar to the column named by the config, so every row gets the same | ||
| # upper-cased string. The input frame is modified in place and that same object is returned. | ||
| data[self.config.name] = self.config.text.upper() | ||
|
|
||
||
| return data |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from data_designer.plugins.plugin import Plugin, PluginType | ||
|
|
||
| # Plugin object targeted by the "demo-column-generator" entry point. It names the config and | ||
| # implementation classes by dotted path rather than importing them in this module. | ||
| column_generator_plugin = Plugin( | ||
| config_qualified_name="data_designer_e2e_tests.plugins.column_generator.config.DemoColumnGeneratorConfig", | ||
| impl_qualified_name="data_designer_e2e_tests.plugins.column_generator.impl.DemoColumnGeneratorImpl", | ||
| plugin_type=PluginType.COLUMN_GENERATOR, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from typing import Literal | ||
|
|
||
| from data_designer.config.seed_source import SeedSource | ||
|
|
||
|
|
||
| # Seed source config for the demo seed-reader plugin: a file identified by directory + filename. | ||
| class DemoSeedSource(SeedSource): | ||
| # Literal tag for this seed type; the same string is used as the plugin's entry-point name. | ||
| seed_type: Literal["demo-seed-reader"] = "demo-seed-reader" | ||
|
|
||
||
| # Both are required strings; the demo reader joins them as "<directory>/<filename>". | ||
| directory: str | ||
| filename: str |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import duckdb | ||
|
|
||
| from data_designer.engine.resources.seed_reader import SeedReader | ||
| from data_designer_e2e_tests.plugins.seed_reader.config import DemoSeedSource | ||
|
|
||
|
|
||
| # Demo seed reader: derives the dataset location from DemoSeedSource's directory and filename. | ||
| class DemoSeedReader(SeedReader[DemoSeedSource]): | ||
| def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: | ||
| # Opens a new connection with duckdb's default arguments on every call. | ||
| return duckdb.connect() | ||
|
|
||
||
| def get_dataset_uri(self) -> str: | ||
| # Plain string join with "/"; the parts are not normalized, so a trailing slash in | ||
| # `directory` ends up doubled in the result. | ||
| return f"{self.source.directory}/{self.source.filename}" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from data_designer.plugins.plugin import Plugin, PluginType | ||
|
|
||
| # Plugin object targeted by the "demo-seed-reader" entry point. It names the seed source config | ||
| # and reader implementation by dotted path rather than importing them in this module. | ||
| seed_reader_plugin = Plugin( | ||
| config_qualified_name="data_designer_e2e_tests.plugins.seed_reader.config.DemoSeedSource", | ||
| impl_qualified_name="data_designer_e2e_tests.plugins.seed_reader.impl.DemoSeedReader", | ||
| plugin_type=PluginType.SEED_READER, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from pathlib import Path | ||
|
|
||
| from data_designer.essentials import ( | ||
| CategorySamplerParams, | ||
| DataDesigner, | ||
| DataDesignerConfigBuilder, | ||
| ExpressionColumnConfig, | ||
| SamplerColumnConfig, | ||
| SamplerType, | ||
| ) | ||
| from data_designer_e2e_tests.plugins.column_generator.config import DemoColumnGeneratorConfig | ||
| from data_designer_e2e_tests.plugins.seed_reader.config import DemoSeedSource | ||
|
|
||
|
|
||
| # End-to-end check: a config that uses the demo column-generator plugin can be previewed, and the | ||
| # plugin's column holds the upper-cased text. | ||
| def test_column_generator_plugin(): | ||
| data_designer = DataDesigner() | ||
|
|
||
||
| config_builder = DataDesignerConfigBuilder() | ||
| # This sampler column is necessary as a temporary workaround to https://github.com/NVIDIA-NeMo/DataDesigner/issues/4 | ||
| config_builder.add_column( | ||
| SamplerColumnConfig( | ||
| name="irrelevant", | ||
| sampler_type=SamplerType.CATEGORY, | ||
| params=CategorySamplerParams(values=["irrelevant"]), | ||
| ) | ||
| ) | ||
| config_builder.add_column( | ||
| DemoColumnGeneratorConfig( | ||
| name="upper", | ||
| text="hello world", | ||
| ) | ||
| ) | ||
|
|
||
||
| preview = data_designer.preview(config_builder) | ||
| # Collapse to a set: the generator writes the same value to every row, so exactly one distinct | ||
| # value is expected regardless of how many rows the preview returns. | ||
| capitalized = set(preview.dataset["upper"].values) | ||
|
|
||
||
| assert capitalized == {"HELLO WORLD"} | ||
|
|
||
|
|
||
| # End-to-end check: a config seeded through the demo seed-reader plugin can be previewed, and the | ||
| # seed columns are usable from an expression column. | ||
| def test_seed_reader_plugin(): | ||
| # The seed CSV is looked up in the same directory as this test module. | ||
| current_dir = Path(__file__).parent | ||
|
|
||
||
| data_designer = DataDesigner() | ||
|
|
||
||
| config_builder = DataDesignerConfigBuilder() | ||
| config_builder.with_seed_dataset( | ||
| DemoSeedSource( | ||
| directory=str(current_dir), | ||
| filename="test_seed.csv", | ||
| ) | ||
| ) | ||
| # This sampler column is necessary as a temporary workaround to https://github.com/NVIDIA-NeMo/DataDesigner/issues/4 | ||
| config_builder.add_column( | ||
| SamplerColumnConfig( | ||
| name="irrelevant", | ||
| sampler_type=SamplerType.CATEGORY, | ||
| params=CategorySamplerParams(values=["irrelevant"]), | ||
| ) | ||
| ) | ||
| # The " + " between the two placeholders is literal template text, which is why the expected | ||
| # values below contain " + " rather than concatenated names. | ||
| config_builder.add_column( | ||
| ExpressionColumnConfig( | ||
| name="full_name", | ||
| expr="{{ first_name }} + {{ last_name }}", | ||
| ) | ||
| ) | ||
|
|
||
||
| preview = data_designer.preview(config_builder) | ||
| full_names = set(preview.dataset["full_name"].values) | ||
|
|
||
||
| # One expected value per row of the seed CSV (John Coltrane, Miles Davis, Bill Evans). | ||
| assert full_names == {"John + Coltrane", "Miles + Davis", "Bill + Evans"} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| first_name,last_name | ||
| John,Coltrane | ||
| Miles,Davis | ||
| Bill,Evans | ||
|
Comment on lines
+2
to
+4
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Dream trio?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing bass and drums to fill out the quintet 😂 In all seriousness, I'm happy to change this to anything else if we want to have a "house style" for dummy data in tests like this.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. "Astronomy and jazz legends" sounds like a sweet house style to me! |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from typing import Annotated | ||
|
|
||
| from pydantic import Field | ||
| from typing_extensions import TypeAlias | ||
|
|
||
| from data_designer.config.seed_source import DataFrameSeedSource, HuggingFaceSeedSource, LocalFileSeedSource | ||
| from data_designer.plugin_manager import PluginManager | ||
|
|
||
| # Module-level PluginManager instance, used below to extend the seed source union. | ||
| plugin_manager = PluginManager() | ||
|
|
||
||
| # Start from the three built-in seed sources, then rebind the name to whatever the plugin manager | ||
| # returns (presumably the same union widened with plugin-provided seed source types; confirm in | ||
| # PluginManager.inject_into_seed_source_type_union). | ||
| _SeedSourceT: TypeAlias = LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource | ||
| _SeedSourceT = plugin_manager.inject_into_seed_source_type_union(_SeedSourceT) | ||
|
|
||
||
| # Public alias: the union is annotated as a pydantic discriminated union keyed on `seed_type`. | ||
| SeedSourceT = Annotated[_SeedSourceT, Field(discriminator="seed_type")] | ||
|
Comment on lines
+12
to
+17
Contributor
Author
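There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This type alias needs to live in a separate module instead of [inline code lost in extraction; presumably `data_designer/config/seed_source.py`, which the new module imports from].
We didn't see this before* for column generator plugins because the base classes to inherit from are already defined in a different module than the [inline code lost in extraction; presumably the column config type union].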
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
haha yes indeed that's why we split them up |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this need `--group dev`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe `uv run` automatically uses the dev group? Seems to be working fine as-is, but I'm not opposed to adding it explicitly.