24 changes: 24 additions & 0 deletions .github/workflows/ci.yml
@@ -36,6 +36,30 @@ jobs:
         run: |
           uv run pytest -v --cov=data_designer --cov-report=term-missing --cov-report=xml --cov-fail-under=90

+  test-e2e:
+    name: End to end test (Python ${{ matrix.python-version }} on ${{ matrix.os }})
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+        python-version: ["3.10", "3.11", "3.12", "3.13"]
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          version: "latest"
+          python-version: ${{ matrix.python-version }}
+          enable-cache: true
+
+      - name: Run e2e tests
+        run: |
+          make test-e2e
+
   lint:
     name: Lint and Format Check
     runs-on: ubuntu-latest
2 changes: 2 additions & 0 deletions .gitignore
@@ -91,3 +91,5 @@ docs/notebooks/
 docs/notebook_source/*.ipynb
 docs/notebook_source/*.csv
 docs/**/artifacts/
+
+e2e_tests/uv.lock
14 changes: 10 additions & 4 deletions Makefile
@@ -63,28 +63,34 @@ check-all-fix: format lint-fix

 format:
 	@echo "📐 Formatting code with ruff..."
-	uv run ruff format src/ tests/ scripts/ --exclude '**/src/data_designer/_version.py'
+	uv run ruff format src/ tests/ scripts/ e2e_tests/ --exclude '**/src/data_designer/_version.py'
 	@echo "✅ Formatting complete!"

 format-check:
 	@echo "📐 Checking code formatting with ruff..."
-	uv run ruff format --check src/ tests/ scripts/ --exclude '**/src/data_designer/_version.py'
+	uv run ruff format --check src/ tests/ scripts/ e2e_tests/ --exclude '**/src/data_designer/_version.py'
 	@echo "✅ Formatting check complete! Run 'make format' to auto-fix issues."

 lint:
 	@echo "🔍 Linting code with ruff..."
-	uv run ruff check --output-format=full src/ tests/ scripts/ --exclude '**/src/data_designer/_version.py'
+	uv run ruff check --output-format=full src/ tests/ scripts/ e2e_tests/ --exclude '**/src/data_designer/_version.py'
 	@echo "✅ Linting complete! Run 'make lint-fix' to auto-fix issues."

 lint-fix:
 	@echo "🔍 Fixing linting issues with ruff..."
-	uv run ruff check --fix src/ tests/ scripts/ --exclude '**/src/data_designer/_version.py'
+	uv run ruff check --fix src/ tests/ scripts/ e2e_tests/ --exclude '**/src/data_designer/_version.py'
 	@echo "✅ Linting with autofix complete!"

 test:
 	@echo "🧪 Running unit tests..."
 	uv run --group dev pytest

+test-e2e:
+	@echo "🧹 Cleaning e2e test environment..."
+	rm -rf e2e_tests/uv.lock e2e_tests/.pycache e2e_tests/.venv
+	@echo "🧪 Running e2e tests..."
+	uv run --no-cache --refresh --directory e2e_tests pytest -s
Contributor:

Does this need --group dev?

Contributor Author:

I believe uv run automatically uses the dev group? It seems to be working fine as-is, but I'm not opposed to adding it explicitly.


 convert-execute-notebooks:
 	@echo "📓 Converting Python tutorials to notebooks and executing..."
 	@mkdir -p docs/notebooks
38 changes: 38 additions & 0 deletions e2e_tests/pyproject.toml
@@ -0,0 +1,38 @@
[project]
name = "data-designer-e2e-tests"
version = "0.0.1"
requires-python = ">=3.10"

dependencies = [
"data-designer",
]

[tool.uv.sources]
data-designer = { path = "../" }

[dependency-groups]
dev = [
"pytest>=8.3.3,<9",
]

[project.entry-points."data_designer.plugins"]
demo-column-generator = "data_designer_e2e_tests.plugins.column_generator.plugin:column_generator_plugin"
demo-seed-reader = "data_designer_e2e_tests.plugins.seed_reader.plugin:seed_reader_plugin"

[tool.pytest.ini_options]
testpaths = ["tests"]
env = [
# ensure plugins are enabled
"DISABLE_DATA_DESIGNER_PLUGINS=false",
]

[tool.uv]
package = true
required-version = ">=0.7.10"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/data_designer_e2e_tests"]
3 changes: 3 additions & 0 deletions e2e_tests/src/data_designer_e2e_tests/plugins/__init__.py
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

12 changes: 12 additions & 0 deletions e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/config.py
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Literal

from data_designer.config.column_configs import SingleColumnConfig


class DemoColumnGeneratorConfig(SingleColumnConfig):
    column_type: Literal["demo-column-generator"] = "demo-column-generator"

    text: str
26 changes: 26 additions & 0 deletions e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/impl.py
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import pandas as pd

from data_designer.engine.column_generators.generators.base import (
    ColumnGenerator,
    GenerationStrategy,
    GeneratorMetadata,
)
from data_designer_e2e_tests.plugins.column_generator.config import DemoColumnGeneratorConfig


class DemoColumnGeneratorImpl(ColumnGenerator[DemoColumnGeneratorConfig]):
    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="demo-column-generator",
            description="Shouts at you",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
        )

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        data[self.config.name] = self.config.text.upper()

        return data
10 changes: 10 additions & 0 deletions e2e_tests/src/data_designer_e2e_tests/plugins/column_generator/plugin.py
@@ -0,0 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from data_designer.plugins.plugin import Plugin, PluginType

column_generator_plugin = Plugin(
    config_qualified_name="data_designer_e2e_tests.plugins.column_generator.config.DemoColumnGeneratorConfig",
    impl_qualified_name="data_designer_e2e_tests.plugins.column_generator.impl.DemoColumnGeneratorImpl",
    plugin_type=PluginType.COLUMN_GENERATOR,
)
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

13 changes: 13 additions & 0 deletions e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/config.py
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Literal

from data_designer.config.seed_source import SeedSource


class DemoSeedSource(SeedSource):
    seed_type: Literal["demo-seed-reader"] = "demo-seed-reader"

    directory: str
    filename: str
15 changes: 15 additions & 0 deletions e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/impl.py
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import duckdb

from data_designer.engine.resources.seed_reader import SeedReader
from data_designer_e2e_tests.plugins.seed_reader.config import DemoSeedSource


class DemoSeedReader(SeedReader[DemoSeedSource]):
    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection:
        return duckdb.connect()

    def get_dataset_uri(self) -> str:
        return f"{self.source.directory}/{self.source.filename}"
10 changes: 10 additions & 0 deletions e2e_tests/src/data_designer_e2e_tests/plugins/seed_reader/plugin.py
@@ -0,0 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from data_designer.plugins.plugin import Plugin, PluginType

seed_reader_plugin = Plugin(
    config_qualified_name="data_designer_e2e_tests.plugins.seed_reader.config.DemoSeedSource",
    impl_qualified_name="data_designer_e2e_tests.plugins.seed_reader.impl.DemoSeedReader",
    plugin_type=PluginType.SEED_READER,
)
73 changes: 73 additions & 0 deletions e2e_tests/tests/test_e2e.py
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

from data_designer.essentials import (
    CategorySamplerParams,
    DataDesigner,
    DataDesignerConfigBuilder,
    ExpressionColumnConfig,
    SamplerColumnConfig,
    SamplerType,
)
from data_designer_e2e_tests.plugins.column_generator.config import DemoColumnGeneratorConfig
from data_designer_e2e_tests.plugins.seed_reader.config import DemoSeedSource


def test_column_generator_plugin():
    data_designer = DataDesigner()

    config_builder = DataDesignerConfigBuilder()
    # This sampler column is necessary as a temporary workaround to https://github.com/NVIDIA-NeMo/DataDesigner/issues/4
    config_builder.add_column(
        SamplerColumnConfig(
            name="irrelevant",
            sampler_type=SamplerType.CATEGORY,
            params=CategorySamplerParams(values=["irrelevant"]),
        )
    )
    config_builder.add_column(
        DemoColumnGeneratorConfig(
            name="upper",
            text="hello world",
        )
    )

    preview = data_designer.preview(config_builder)
    capitalized = set(preview.dataset["upper"].values)

    assert capitalized == {"HELLO WORLD"}


def test_seed_reader_plugin():
    current_dir = Path(__file__).parent

    data_designer = DataDesigner()

    config_builder = DataDesignerConfigBuilder()
    config_builder.with_seed_dataset(
        DemoSeedSource(
            directory=str(current_dir),
            filename="test_seed.csv",
        )
    )
    # This sampler column is necessary as a temporary workaround to https://github.com/NVIDIA-NeMo/DataDesigner/issues/4
    config_builder.add_column(
        SamplerColumnConfig(
            name="irrelevant",
            sampler_type=SamplerType.CATEGORY,
            params=CategorySamplerParams(values=["irrelevant"]),
        )
    )
    config_builder.add_column(
        ExpressionColumnConfig(
            name="full_name",
            expr="{{ first_name }} + {{ last_name }}",
        )
    )

    preview = data_designer.preview(config_builder)
    full_names = set(preview.dataset["full_name"].values)

    assert full_names == {"John + Coltrane", "Miles + Davis", "Bill + Evans"}
4 changes: 4 additions & 0 deletions e2e_tests/tests/test_seed.csv
@@ -0,0 +1,4 @@
first_name,last_name
John,Coltrane
Miles,Davis
Bill,Evans
Comment on lines +2 to +4
Contributor:

dream trio?

Contributor Author:

Missing bass and drums to fill out the quintet 😂 In all seriousness, I'm happy to change this to anything else if we want to have a "house style" for dummy data in tests like this.

Contributor:

"Astronomy and jazz legends" sounds like a sweet house style to me!
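For readers checking the assertion in test_seed_reader_plugin against this seed file, here is a small standalone sketch that derives the same expected set from the CSV rows. It deliberately uses `str.format` as a stand-in for the `{{ first_name }} + {{ last_name }}` expression, so it only mirrors the intended result and says nothing about how the expression column is rendered internally; `render_full_names` is a name made up for this illustration.

```python
import csv
import io

# Contents of e2e_tests/tests/test_seed.csv, inlined so the sketch is self-contained.
SEED_CSV = """first_name,last_name
John,Coltrane
Miles,Davis
Bill,Evans
"""

# Stand-in for the expression "{{ first_name }} + {{ last_name }}" (assumption:
# plain str.format is used here instead of a real template engine).
TEMPLATE = "{first_name} + {last_name}"


def render_full_names(csv_text: str, template: str = TEMPLATE) -> set[str]:
    """Apply the template to every CSV row and collect the distinct results."""
    rows = csv.DictReader(io.StringIO(csv_text))
    return {template.format(**row) for row in rows}


if __name__ == "__main__":
    print(sorted(render_full_names(SEED_CSV)))
```

The set comparison in the test means row order does not matter, which fits a preview that may sample the seed rows in any order.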

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -128,7 +128,7 @@ ignore = [
 ]

 [tool.ruff.lint.isort]
-known-first-party = ["data_designer"]
+known-first-party = ["data_designer", "data_designer_e2e_tests"]

 [tool.ruff.lint.flake8-tidy-imports]
 ban-relative-imports = "all"
2 changes: 1 addition & 1 deletion scripts/update_license_headers.py
@@ -168,7 +168,7 @@ def main(path: Path, check_only: bool = False) -> tuple[int, int, int, list[Path
     total_updated = 0
     total_skipped = 0

-    for folder in ["src", "tests", "scripts"]:
+    for folder in ["src", "tests", "scripts", "e2e_tests"]:
         folder_path = repo_path / folder
         if not folder_path.exists():
             continue
5 changes: 3 additions & 2 deletions src/data_designer/config/config_builder.py
@@ -39,7 +39,8 @@
     SamplingStrategy,
     SeedConfig,
 )
-from data_designer.config.seed_source import DataFrameSeedSource, SeedSource
+from data_designer.config.seed_source import DataFrameSeedSource
+from data_designer.config.seed_source_types import SeedSourceT
 from data_designer.config.utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE
 from data_designer.config.utils.info import ConfigBuilderInfo
 from data_designer.config.utils.io_helpers import serialize_data, smart_load_yaml
@@ -474,7 +475,7 @@ def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int:

     def with_seed_dataset(
         self,
-        seed_source: SeedSource,
+        seed_source: SeedSourceT,
         *,
         sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED,
         selection_strategy: IndexRange | PartitionBlock | None = None,
2 changes: 1 addition & 1 deletion src/data_designer/config/seed.py
@@ -7,7 +7,7 @@
 from typing_extensions import Self

 from data_designer.config.base import ConfigBase
-from data_designer.config.seed_source import SeedSourceT
+from data_designer.config.seed_source_types import SeedSourceT


 class SamplingStrategy(str, Enum):
8 changes: 1 addition & 7 deletions src/data_designer/config/seed_source.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from abc import ABC
-from typing import Annotated, Literal
+from typing import Literal

 import pandas as pd
 from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -76,9 +76,3 @@ class DataFrameSeedSource(SeedSource):
             "you must use `LocalFileSeedSource` instead, since DataFrame objects are not serializable."
         ),
     )
-
-
-SeedSourceT = Annotated[
-    LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource,
-    Field(discriminator="seed_type"),
-]
17 changes: 17 additions & 0 deletions src/data_designer/config/seed_source_types.py
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Annotated

from pydantic import Field
from typing_extensions import TypeAlias

from data_designer.config.seed_source import DataFrameSeedSource, HuggingFaceSeedSource, LocalFileSeedSource
from data_designer.plugin_manager import PluginManager

plugin_manager = PluginManager()

_SeedSourceT: TypeAlias = LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource
_SeedSourceT = plugin_manager.inject_into_seed_source_type_union(_SeedSourceT)

SeedSourceT = Annotated[_SeedSourceT, Field(discriminator="seed_type")]
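To illustrate the idea behind widening the built-in union with plugin-provided config classes, here is a minimal standard-library sketch. It is not the body of `inject_into_seed_source_type_union`, which this diff does not show; `extend_union` and the three stand-in classes are hypothetical names used only to show the shape of the operation.

```python
from typing import Union, get_args


# Stand-ins for the built-in seed source configs and one plugin config (hypothetical).
class LocalFileStandIn: ...


class HuggingFaceStandIn: ...


class DemoPluginStandIn: ...


def extend_union(base_union, plugin_config_types):
    """Return a Union of the base members plus any plugin types not already present."""
    # get_args() is empty for a plain class, so fall back to a single-member list.
    members = list(get_args(base_union)) or [base_union]
    for config_type in plugin_config_types:
        if config_type not in members:
            members.append(config_type)
    return Union[tuple(members)]


BuiltIn = Union[LocalFileStandIn, HuggingFaceStandIn]
Extended = extend_union(BuiltIn, [DemoPluginStandIn])

if __name__ == "__main__":
    print(get_args(Extended))
```

In the diff above, the widened union is then wrapped in `Annotated[..., Field(discriminator="seed_type")]`, so each plugin config needs its own distinct `seed_type` literal, as `DemoSeedSource` has.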
Comment on lines +12 to +17
Contributor Author:

This type alias needs to live in a separate module instead of config/seed_source.py because:

  • The config classes for seed dataset plugins are expected to inherit from SeedSource.
  • Since config code is always safe to import/load for plugins, we do so as part of pydantic validation of each Plugin instance, to validate that the config discriminator is set properly.
  • If the code highlighted here is defined in the same module as SeedSource, we run into a deadlock: the plugin registry is actively discovering/loading plugins and so has grabbed a lock; to instantiate and load a seed dataset plugin we need to load the module containing the base SeedSource class; and when doing so we hit the PluginManager() call, which initializes a (singleton) registry, which tries to grab the lock that is already taken.

We didn't see this before* for column generator plugins because the base classes to inherit from are already defined in a different module than the ColumnConfigT type alias union (config.column_configs and config.column_types, respectively).

*Or maybe we did see this before and these two modules were set up the way they are now for this very reason; I don't know the history.

Contributor:

> *or maybe we did see this before and these two modules were set up the way they are now for this very reason; I don't know the history

haha yes indeed that's why we split them up
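The re-entrancy problem described in this thread can be reproduced in isolation. The sketch below is a simplified model, not the project's registry: it assumes a non-reentrant `threading.Lock` guards discovery, and it uses a short acquire timeout so the script reports the failure instead of hanging. `simulate` and its inner helpers are hypothetical names.

```python
import threading


def simulate(import_triggers_registry: bool) -> bool:
    """Model plugin discovery; return True if every registry access succeeded."""
    lock = threading.Lock()  # assumption: non-reentrant lock held during discovery
    outcome = {"ok": True}

    def access_registry() -> bool:
        # A real deadlock would block forever; the timeout turns it into False.
        acquired = lock.acquire(timeout=0.1)
        if acquired:
            lock.release()
        return acquired

    def load_plugin_config() -> None:
        # Same-module layout: importing the base class also runs the module-level
        # code that touches the registry, while discovery still holds the lock.
        if import_triggers_registry:
            outcome["ok"] = access_registry()

    with lock:  # discovery in progress
        load_plugin_config()

    # Split-module layout: the registry is only touched after discovery finishes.
    return outcome["ok"] and access_registry()


if __name__ == "__main__":
    print("same module:", simulate(import_triggers_registry=True))
    print("separate module:", simulate(import_triggers_registry=False))
```

Moving the union into seed_source_types.py corresponds to the second case: loading a plugin's config class no longer executes the module-level `PluginManager()` call while the lock is held.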

4 changes: 4 additions & 0 deletions src/data_designer/interface/data_designer.py
@@ -60,6 +60,8 @@
 )
 from data_designer.interface.results import DatasetCreationResults
 from data_designer.logging import RandomEmoji
+from data_designer.plugins.plugin import PluginType
+from data_designer.plugins.registry import PluginRegistry

 DEFAULT_BUFFER_SIZE = 1000

@@ -70,6 +72,8 @@
     LocalFileSeedReader(),
     DataFrameSeedReader(),
 ]
+for plugin in PluginRegistry().get_plugins(PluginType.SEED_READER):
+    DEFAULT_SEED_READERS.append(plugin.impl_cls())

 logger = logging.getLogger(__name__)

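The two added lines in data_designer.py extend the default reader list with one instance per registered seed reader plugin. A self-contained sketch of that pattern, using stand-in types rather than the real `Plugin`, `PluginType`, and `PluginRegistry` classes (every name below is hypothetical):

```python
from dataclasses import dataclass
from enum import Enum


class PluginKind(Enum):
    COLUMN_GENERATOR = "column-generator"
    SEED_READER = "seed-reader"


@dataclass(frozen=True)
class PluginStandIn:
    kind: PluginKind
    impl_cls: type  # class to instantiate, mirroring plugin.impl_cls in the diff


class BuiltInReader: ...


class DemoReader: ...


class DemoGenerator: ...


def build_seed_readers(built_ins: list, plugins: list) -> list:
    """Return the built-in readers followed by one instance per seed reader plugin."""
    readers = list(built_ins)
    readers.extend(p.impl_cls() for p in plugins if p.kind is PluginKind.SEED_READER)
    return readers


if __name__ == "__main__":
    registered = [
        PluginStandIn(PluginKind.SEED_READER, DemoReader),
        PluginStandIn(PluginKind.COLUMN_GENERATOR, DemoGenerator),
    ]
    print([type(r).__name__ for r in build_seed_readers([BuiltInReader()], registered)])
```

One difference worth noting: in the diff the loop runs at module import time and mutates `DEFAULT_SEED_READERS` in place, whereas the sketch returns a new list to keep the example free of side effects.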