diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e13d5031..5bce02ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,6 +36,30 @@ jobs: run: | uv run pytest -v --cov=data_designer --cov-report=term-missing --cov-report=xml --cov-fail-under=90 + test-e2e: + name: End to end test (Python ${{ matrix.python-version }} on ${{ matrix.os }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + version: "latest" + python-version: ${{ matrix.python-version }} + enable-cache: true + + - name: Run e2e tests + run: | + make test-e2e + lint: name: Lint and Format Check runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 6f89c39d..7712099a 100644 --- a/.gitignore +++ b/.gitignore @@ -91,3 +91,5 @@ docs/notebooks/ docs/notebook_source/*.ipynb docs/notebook_source/*.csv docs/**/artifacts/ + +e2e_tests/uv.lock diff --git a/Makefile b/Makefile index 6085da6f..1ca7727a 100644 --- a/Makefile +++ b/Makefile @@ -63,28 +63,34 @@ check-all-fix: format lint-fix format: @echo "๐Ÿ“ Formatting code with ruff..." - uv run ruff format src/ tests/ scripts/ --exclude '**/src/data_designer/_version.py' + uv run ruff format src/ tests/ scripts/ e2e_tests/ --exclude '**/src/data_designer/_version.py' @echo "โœ… Formatting complete!" format-check: @echo "๐Ÿ“ Checking code formatting with ruff..." - uv run ruff format --check src/ tests/ scripts/ --exclude '**/src/data_designer/_version.py' + uv run ruff format --check src/ tests/ scripts/ e2e_tests/ --exclude '**/src/data_designer/_version.py' @echo "โœ… Formatting check complete! Run 'make format' to auto-fix issues." lint: @echo "๐Ÿ” Linting code with ruff..." - uv run ruff check --output-format=full src/ tests/ scripts/ --exclude '**/src/data_designer/_version.py' + uv run ruff check --output-format=full src/ tests/ scripts/ e2e_tests/ --exclude '**/src/data_designer/_version.py' @echo "โœ… Linting complete! Run 'make lint-fix' to auto-fix issues." lint-fix: @echo "๐Ÿ” Fixing linting issues with ruff..." - uv run ruff check --fix src/ tests/ scripts/ --exclude '**/src/data_designer/_version.py' + uv run ruff check --fix src/ tests/ scripts/ e2e_tests/ --exclude '**/src/data_designer/_version.py' @echo "โœ… Linting with autofix complete!" test: @echo "๐Ÿงช Running unit tests..." uv run --group dev pytest +test-e2e: + @echo "๐Ÿงน Cleaning e2e test environment..." + rm -rf e2e_tests/uv.lock e2e_tests/.pycache e2e_tests/.venv + @echo "๐Ÿงช Running e2e tests..." + uv run --no-cache --refresh --directory e2e_tests pytest -s + convert-execute-notebooks: @echo "๐Ÿ““ Converting Python tutorials to notebooks and executing..." @mkdir -p docs/notebooks diff --git a/e2e_tests/pyproject.toml b/e2e_tests/pyproject.toml new file mode 100644 index 00000000..fb00ed82 --- /dev/null +++ b/e2e_tests/pyproject.toml @@ -0,0 +1,38 @@ +[project] +name = "data-designer-e2e-tests" +version = "0.0.1" +requires-python = ">=3.10" + +dependencies = [ + "data-designer", +] + +[tool.uv.sources] +data-designer = { path = "../" } + +[dependency-groups] +dev = [ + "pytest>=8.3.3,<9", +] + +[project.entry-points."data_designer.plugins"] +demo-column-generator = "data_designer_e2e_tests.plugins.column_generator.plugin:column_generator_plugin" +demo-seed-reader = "data_designer_e2e_tests.plugins.seed_reader.plugin:seed_reader_plugin" + +[tool.pytest.ini_options] +testpaths = ["tests"] +env = [ + # ensure plugins are enabled + "DISABLE_DATA_DESIGNER_PLUGINS=false", +] + +[tool.uv] +package = true +required-version = ">=0.7.10" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/data_designer_e2e_tests"] diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/__init__.py b/e2e_tests/src/data_designer_e2e_tests/plugins/__init__.py new file mode 100644 index 00000000..a4d63622 --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/__init__.py b/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/__init__.py new file mode 100644 index 00000000..a4d63622 --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/config.py b/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/config.py new file mode 100644 index 00000000..85817a9e --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/config.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal + +from data_designer.config.column_configs import SingleColumnConfig + + +class DemoColumnGeneratorConfig(SingleColumnConfig): + column_type: Literal["demo-column-generator"] = "demo-column-generator" + + text: str diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/impl.py b/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/impl.py new file mode 100644 index 00000000..7993d637 --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/impl.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pandas as pd + +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + GeneratorMetadata, +) +from data_designer_e2e_tests.plugins.column_generator.config import DemoColumnGeneratorConfig + + +class DemoColumnGeneratorImpl(ColumnGenerator[DemoColumnGeneratorConfig]): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="demo-column-generator", + description="Shouts at you", + generation_strategy=GenerationStrategy.FULL_COLUMN, + ) + + def generate(self, data: pd.DataFrame) -> pd.DataFrame: + data[self.config.name] = self.config.text.upper() + + return data diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/plugin.py b/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/plugin.py new file mode 100644 index 00000000..f1eec06a --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/plugin.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.plugins.plugin import Plugin, PluginType + +column_generator_plugin = Plugin( + config_qualified_name="data_designer_e2e_tests.plugins.column_generator.config.DemoColumnGeneratorConfig", + impl_qualified_name="data_designer_e2e_tests.plugins.column_generator.impl.DemoColumnGeneratorImpl", + plugin_type=PluginType.COLUMN_GENERATOR, +) diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/__init__.py b/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/__init__.py new file mode 100644 index 00000000..a4d63622 --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/config.py b/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/config.py new file mode 100644 index 00000000..42a220bf --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/config.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal + +from data_designer.config.seed_source import SeedSource + + +class DemoSeedSource(SeedSource): + seed_type: Literal["demo-seed-reader"] = "demo-seed-reader" + + directory: str + filename: str diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/impl.py b/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/impl.py new file mode 100644 index 00000000..7006300c --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/impl.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import duckdb + +from data_designer.engine.resources.seed_reader import SeedReader +from data_designer_e2e_tests.plugins.seed_reader.config import DemoSeedSource + + +class DemoSeedReader(SeedReader[DemoSeedSource]): + def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: + return duckdb.connect() + + def get_dataset_uri(self) -> str: + return f"{self.source.directory}/{self.source.filename}" diff --git a/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/plugin.py b/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/plugin.py new file mode 100644 index 00000000..d3dcf16d --- /dev/null +++ b/e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/plugin.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.plugins.plugin import Plugin, PluginType + +seed_reader_plugin = Plugin( + config_qualified_name="data_designer_e2e_tests.plugins.seed_reader.config.DemoSeedSource", + impl_qualified_name="data_designer_e2e_tests.plugins.seed_reader.impl.DemoSeedReader", + plugin_type=PluginType.SEED_READER, +) diff --git a/e2e_tests/tests/test_e2e.py b/e2e_tests/tests/test_e2e.py new file mode 100644 index 00000000..a969034a --- /dev/null +++ b/e2e_tests/tests/test_e2e.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +from data_designer.essentials import ( + CategorySamplerParams, + DataDesigner, + DataDesignerConfigBuilder, + ExpressionColumnConfig, + SamplerColumnConfig, + SamplerType, +) +from data_designer_e2e_tests.plugins.column_generator.config import DemoColumnGeneratorConfig +from data_designer_e2e_tests.plugins.seed_reader.config import DemoSeedSource + + +def test_column_generator_plugin(): + data_designer = DataDesigner() + + config_builder = DataDesignerConfigBuilder() + # This sampler column is necessary as a temporary workaround to https://github.com/NVIDIA-NeMo/DataDesigner/issues/4 + config_builder.add_column( + SamplerColumnConfig( + name="irrelevant", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["irrelevant"]), + ) + ) + config_builder.add_column( + DemoColumnGeneratorConfig( + name="upper", + text="hello world", + ) + ) + + preview = data_designer.preview(config_builder) + capitalized = set(preview.dataset["upper"].values) + + assert capitalized == {"HELLO WORLD"} + + +def test_seed_reader_plugin(): + current_dir = Path(__file__).parent + + data_designer = DataDesigner() + + config_builder = DataDesignerConfigBuilder() + config_builder.with_seed_dataset( + DemoSeedSource( + directory=str(current_dir), + filename="test_seed.csv", + ) + ) + # This sampler column is necessary as a temporary workaround to https://github.com/NVIDIA-NeMo/DataDesigner/issues/4 + config_builder.add_column( + SamplerColumnConfig( + name="irrelevant", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["irrelevant"]), + ) + ) + config_builder.add_column( + ExpressionColumnConfig( + name="full_name", + expr="{{ first_name }} + {{ last_name }}", + ) + ) + + preview = data_designer.preview(config_builder) + full_names = set(preview.dataset["full_name"].values) + + assert full_names == {"John + Coltrane", "Miles + Davis", "Bill + Evans"} diff --git a/e2e_tests/tests/test_seed.csv b/e2e_tests/tests/test_seed.csv new file mode 100644 index 00000000..749941b8 --- /dev/null +++ b/e2e_tests/tests/test_seed.csv @@ -0,0 +1,4 @@ +first_name,last_name +John,Coltrane +Miles,Davis +Bill,Evans diff --git a/pyproject.toml b/pyproject.toml index 37a985f1..bb5c3506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,7 +128,7 @@ ignore = [ ] [tool.ruff.lint.isort] -known-first-party = ["data_designer"] +known-first-party = ["data_designer", "data_designer_e2e_tests"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" diff --git a/scripts/update_license_headers.py b/scripts/update_license_headers.py index e863da99..9613980a 100644 --- a/scripts/update_license_headers.py +++ b/scripts/update_license_headers.py @@ -168,7 +168,7 @@ def main(path: Path, check_only: bool = False) -> tuple[int, int, int, list[Path total_updated = 0 total_skipped = 0 - for folder in ["src", "tests", "scripts"]: + for folder in ["src", "tests", "scripts", "e2e_tests"]: folder_path = repo_path / folder if not folder_path.exists(): continue diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 52bdb582..bf0d6f30 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -39,7 +39,8 @@ SamplingStrategy, SeedConfig, ) -from data_designer.config.seed_source import DataFrameSeedSource, SeedSource +from data_designer.config.seed_source import DataFrameSeedSource +from data_designer.config.seed_source_types import SeedSourceT from data_designer.config.utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE from data_designer.config.utils.info import ConfigBuilderInfo from data_designer.config.utils.io_helpers import serialize_data, smart_load_yaml @@ -474,7 +475,7 @@ def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int: def with_seed_dataset( self, - seed_source: SeedSource, + seed_source: SeedSourceT, *, sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, selection_strategy: IndexRange | PartitionBlock | None = None, diff --git a/src/data_designer/config/seed.py b/src/data_designer/config/seed.py index 86070ff7..73f6ca67 100644 --- a/src/data_designer/config/seed.py +++ b/src/data_designer/config/seed.py @@ -7,7 +7,7 @@ from typing_extensions import Self from data_designer.config.base import ConfigBase -from data_designer.config.seed_source import SeedSourceT +from data_designer.config.seed_source_types import SeedSourceT class SamplingStrategy(str, Enum): diff --git a/src/data_designer/config/seed_source.py b/src/data_designer/config/seed_source.py index d95f3425..68660241 100644 --- a/src/data_designer/config/seed_source.py +++ b/src/data_designer/config/seed_source.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC -from typing import Annotated, Literal +from typing import Literal import pandas as pd from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -76,9 +76,3 @@ class DataFrameSeedSource(SeedSource): "you must use `LocalFileSeedSource` instead, since DataFrame objects are not serializable." ), ) - - -SeedSourceT = Annotated[ - LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource, - Field(discriminator="seed_type"), -] diff --git a/src/data_designer/config/seed_source_types.py b/src/data_designer/config/seed_source_types.py new file mode 100644 index 00000000..ac5382fd --- /dev/null +++ b/src/data_designer/config/seed_source_types.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Annotated + +from pydantic import Field +from typing_extensions import TypeAlias + +from data_designer.config.seed_source import DataFrameSeedSource, HuggingFaceSeedSource, LocalFileSeedSource +from data_designer.plugin_manager import PluginManager + +plugin_manager = PluginManager() + +_SeedSourceT: TypeAlias = LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource +_SeedSourceT = plugin_manager.inject_into_seed_source_type_union(_SeedSourceT) + +SeedSourceT = Annotated[_SeedSourceT, Field(discriminator="seed_type")] diff --git a/src/data_designer/interface/data_designer.py b/src/data_designer/interface/data_designer.py index 5e9d14c4..6e8f0659 100644 --- a/src/data_designer/interface/data_designer.py +++ b/src/data_designer/interface/data_designer.py @@ -60,6 +60,8 @@ ) from data_designer.interface.results import DatasetCreationResults from data_designer.logging import RandomEmoji +from data_designer.plugins.plugin import PluginType +from data_designer.plugins.registry import PluginRegistry DEFAULT_BUFFER_SIZE = 1000 @@ -70,6 +72,8 @@ LocalFileSeedReader(), DataFrameSeedReader(), ] +for plugin in PluginRegistry().get_plugins(PluginType.SEED_READER): + DEFAULT_SEED_READERS.append(plugin.impl_cls()) logger = logging.getLogger(__name__) diff --git a/src/data_designer/plugin_manager.py b/src/data_designer/plugin_manager.py index 19bca2ea..39aa4208 100644 --- a/src/data_designer/plugin_manager.py +++ b/src/data_designer/plugin_manager.py @@ -64,3 +64,15 @@ def inject_into_column_config_type_union(self, column_config_type: type[TypeAlia column_config_type, PluginType.COLUMN_GENERATOR ) return column_config_type + + def inject_into_seed_source_type_union(self, seed_source_type: type[TypeAlias]) -> type[TypeAlias]: + """Inject plugins into the seed source type. + + Args: + seed_source_type: The seed source type to inject plugins into. + + Returns: + The seed source type with plugins injected. + """ + seed_source_type = self._plugin_registry.add_plugin_types_to_union(seed_source_type, PluginType.SEED_READER) + return seed_source_type diff --git a/src/data_designer/plugins/plugin.py b/src/data_designer/plugins/plugin.py index c918bdb2..9a69ffed 100644 --- a/src/data_designer/plugins/plugin.py +++ b/src/data_designer/plugins/plugin.py @@ -19,11 +19,14 @@ class PluginType(str, Enum): COLUMN_GENERATOR = "column-generator" + SEED_READER = "seed-reader" @property def discriminator_field(self) -> str: if self == PluginType.COLUMN_GENERATOR: return "column_type" + elif self == PluginType.SEED_READER: + return "seed_type" else: raise ValueError(f"Invalid plugin type: {self.value}") diff --git a/src/data_designer/plugins/testing/utils.py b/src/data_designer/plugins/testing/utils.py index 36ba5679..462bc8b7 100644 --- a/src/data_designer/plugins/testing/utils.py +++ b/src/data_designer/plugins/testing/utils.py @@ -3,9 +3,16 @@ from data_designer.config.base import ConfigBase from data_designer.engine.configurable_task import ConfigurableTask -from data_designer.plugins.plugin import Plugin +from data_designer.engine.resources.seed_reader import SeedReader +from data_designer.plugins.plugin import Plugin, PluginType def assert_valid_plugin(plugin: Plugin) -> None: assert issubclass(plugin.config_cls, ConfigBase), "Plugin config class is not a subclass of ConfigBase" - assert issubclass(plugin.impl_cls, ConfigurableTask), "Plugin impl class is not a subclass of ConfigurableTask" + + if plugin.plugin_type == PluginType.COLUMN_GENERATOR: + assert issubclass(plugin.impl_cls, ConfigurableTask), ( + "Column generator plugin impl class must be a subclass of ConfigurableTask" + ) + elif plugin.plugin_type == PluginType.SEED_READER: + assert issubclass(plugin.impl_cls, SeedReader), "Seed reader plugin impl class must be a subclass of SeedReader"