From dbb13235ae3e4eeb6f3a1b40dc3e5760860127f8 Mon Sep 17 00:00:00 2001 From: Luis Carbonell Girones <3169189+lucargir@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:12:01 +0200 Subject: [PATCH] fix: move get_asset_key_str to translator and define translator from config --- README.md | 32 ++++++- dagster_sqlmesh/asset.py | 9 +- dagster_sqlmesh/config.py | 57 ++++++++++++- dagster_sqlmesh/controller/dagster.py | 5 +- dagster_sqlmesh/resource.py | 20 +++-- dagster_sqlmesh/test_asset.py | 5 +- dagster_sqlmesh/test_config.py | 65 +++++++++++++++ dagster_sqlmesh/testing/context.py | 2 + dagster_sqlmesh/translator.py | 115 +++++++++++++++++++++++--- dagster_sqlmesh/utils.py | 10 --- sample/dagster_project/definitions.py | 7 +- 11 files changed, 278 insertions(+), 49 deletions(-) create mode 100644 dagster_sqlmesh/test_config.py diff --git a/README.md b/README.md index 62942f1..f1a4420 100644 --- a/README.md +++ b/README.md @@ -24,11 +24,11 @@ from dagster import ( AssetExecutionContext, Definitions, ) -from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator +from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway") -@sqlmesh_assets(environment="dev", config=sqlmesh_config, translator=SQLMeshDagsterTranslator()) +@sqlmesh_assets(environment="dev", config=sqlmesh_config) def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource): yield from sqlmesh.run(context) @@ -40,6 +40,34 @@ defs = Definitions( ) ``` +## Advanced Usage + +### Custom Translator + +The translator is centrally configured and ensures consistency across all components. You can customize the translator by specifying a custom class in the config: + +```python +from dagster_sqlmesh import SQLMeshDagsterTranslator + +class CustomSQLMeshTranslator(SQLMeshDagsterTranslator): + def get_asset_key_str(self, fqn: str) -> str: + # Custom asset key generation logic + return f"custom_prefix__{super().get_asset_key_str(fqn)}" + +# Configure with custom translator +sqlmesh_config = SQLMeshContextConfig( + path="/home/foo/sqlmesh_project", + gateway="name-of-your-gateway", + translator_class_name="your_module.CustomSQLMeshTranslator" +) + +@sqlmesh_assets(environment="dev", config=sqlmesh_config) +def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource): + yield from sqlmesh.run(context) +``` + +This approach ensures that both the `SQLMeshResource` and the `@sqlmesh_assets` decorator use the same translator instance, preventing inconsistencies. The translator is created using `config.get_translator()` and passed to all components that need it, including the `DagsterSQLMeshEventHandler`. + ## Contributing diff --git a/dagster_sqlmesh/asset.py b/dagster_sqlmesh/asset.py index cee12e9..da362ec 100644 --- a/dagster_sqlmesh/asset.py +++ b/dagster_sqlmesh/asset.py @@ -10,7 +10,6 @@ ContextFactory, DagsterSQLMeshController, ) -from dagster_sqlmesh.translator import SQLMeshDagsterTranslator from dagster_sqlmesh.types import SQLMeshMultiAssetOptions logger = logging.getLogger(__name__) @@ -20,7 +19,6 @@ def sqlmesh_to_multi_asset_options( environment: str, config: SQLMeshContextConfig, context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs), - dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None, ) -> SQLMeshMultiAssetOptions: """Converts sqlmesh project into a SQLMeshMultiAssetOptions object which is an intermediate representation of the SQLMesh project that can be used to @@ -28,12 +26,11 @@ def sqlmesh_to_multi_asset_options( controller = DagsterSQLMeshController.setup_with_config( config=config, context_factory=context_factory ) - if not dagster_sqlmesh_translator: - dagster_sqlmesh_translator = SQLMeshDagsterTranslator() + translator = config.get_translator() conversion = controller.to_asset_outs( environment, - translator=dagster_sqlmesh_translator, + translator=translator, ) return conversion @@ -74,7 +71,6 @@ def sqlmesh_assets( config: SQLMeshContextConfig, context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs), name: str | None = None, - dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None, compute_kind: str = "sqlmesh", op_tags: t.Mapping[str, t.Any] | None = None, required_resource_keys: set[str] | None = None, @@ -86,7 +82,6 @@ def sqlmesh_assets( environment=environment, config=config, context_factory=context_factory, - dagster_sqlmesh_translator=dagster_sqlmesh_translator, ) return sqlmesh_asset_from_multi_asset_options( diff --git a/dagster_sqlmesh/config.py b/dagster_sqlmesh/config.py index 7d60b2d..27b5620 100644 --- a/dagster_sqlmesh/config.py +++ b/dagster_sqlmesh/config.py @@ -1,18 +1,22 @@ +import inspect +import typing as t from dataclasses import dataclass from pathlib import Path -from typing import Any from dagster import Config from pydantic import Field from sqlmesh.core.config import Config as MeshConfig from sqlmesh.core.config.loader import load_configs +if t.TYPE_CHECKING: + from dagster_sqlmesh.translator import SQLMeshDagsterTranslator + @dataclass class ConfigOverride: - config_as_dict: dict[str, Any] + config_as_dict: dict[str, t.Any] - def dict(self) -> dict[str, Any]: + def dict(self) -> dict[str, t.Any]: return self.config_as_dict @@ -22,11 +26,56 @@ class SQLMeshContextConfig(Config): sqlmesh project define all the configuration in it's own directory which also ensures that configuration is consistent if running sqlmesh locally vs running via dagster. + + The config also manages the translator class used for converting SQLMesh + models to Dagster assets. You can specify a custom translator by setting + the translator_class_name field to the fully qualified class name. """ path: str gateway: str - config_override: dict[str, Any] | None = Field(default_factory=lambda: None) + config_override: dict[str, t.Any] | None = Field(default_factory=lambda: None) + translator_class_name: str = Field( + default="dagster_sqlmesh.translator.SQLMeshDagsterTranslator", + description="Fully qualified class name of the SQLMesh Dagster translator to use" + ) + + def get_translator(self) -> "SQLMeshDagsterTranslator": + """Get a translator instance using the configured class name. + + Imports and validates the translator class, then creates a new instance. + The class must inherit from SQLMeshDagsterTranslator. + + Returns: + SQLMeshDagsterTranslator: A new instance of the configured translator class + + Raises: + ValueError: If the imported object is not a class or does not inherit + from SQLMeshDagsterTranslator + """ + from importlib import import_module + + from dagster_sqlmesh.translator import SQLMeshDagsterTranslator + + module_name, class_name = self.translator_class_name.rsplit(".", 1) + module = import_module(module_name) + translator_class = getattr(module, class_name) + + # Validate that the imported class inherits from SQLMeshDagsterTranslator + if not inspect.isclass(translator_class): + raise ValueError( + f"'{self.translator_class_name}' is not a class. " + f"Expected a class that inherits from SQLMeshDagsterTranslator." + ) + + if not issubclass(translator_class, SQLMeshDagsterTranslator): + raise ValueError( + f"Translator class '{self.translator_class_name}' must inherit from " + f"SQLMeshDagsterTranslator. Found class that inherits from: " + f"{[base.__name__ for base in translator_class.__bases__]}" + ) + + return translator_class() @property def sqlmesh_config(self) -> MeshConfig: diff --git a/dagster_sqlmesh/controller/dagster.py b/dagster_sqlmesh/controller/dagster.py index dffff4d..c0e0343 100644 --- a/dagster_sqlmesh/controller/dagster.py +++ b/dagster_sqlmesh/controller/dagster.py @@ -12,7 +12,6 @@ SQLMeshModelDep, SQLMeshMultiAssetOptions, ) -from dagster_sqlmesh.utils import get_asset_key_str logger = logging.getLogger(__name__) @@ -53,7 +52,7 @@ def to_asset_outs( internal_asset_deps.add(dep_asset_key_str) else: - table = get_asset_key_str(dep.fqn) + table = translator.get_asset_key_str(dep.fqn) key = translator.get_asset_key( context, dep.fqn ).to_user_string() @@ -62,7 +61,7 @@ def to_asset_outs( # create an external dep deps_map[table] = translator.create_asset_dep(key=key) - model_key = get_asset_key_str(model.fqn) + model_key = translator.get_asset_key_str(model.fqn) asset_outs[model_key] = translator.create_asset_out( model_key=model_key, asset_key=asset_key_str, diff --git a/dagster_sqlmesh/resource.py b/dagster_sqlmesh/resource.py index bc5f037..4fe0297 100644 --- a/dagster_sqlmesh/resource.py +++ b/dagster_sqlmesh/resource.py @@ -25,7 +25,9 @@ ContextFactory, ) from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController -from dagster_sqlmesh.utils import get_asset_key_str + +if t.TYPE_CHECKING: + from dagster_sqlmesh.translator import SQLMeshDagsterTranslator logger = logging.getLogger(__name__) @@ -329,6 +331,7 @@ def __init__( models_map: dict[str, Model], dag: DAG[t.Any], prefix: str, + translator: "SQLMeshDagsterTranslator", is_testing: bool = False, materializations_enabled: bool = True, ) -> None: @@ -341,6 +344,7 @@ def __init__( models_map: A mapping of model names to their SQLMesh model instances. dag: The directed acyclic graph representing the SQLMesh models. prefix: A prefix to use for all asset keys generated by this handler. + translator: The SQLMesh Dagster translator instance. is_testing: Whether the handler is being used in a testing context. materializations_enabled: Whether the handler is to generate materializations, this should be disabled if you with to run a @@ -351,6 +355,7 @@ def __init__( self._prefix = prefix self._context = context self._logger = context.log + self._translator = translator self._tracker = MaterializationTracker( sorted_dag=dag.sorted[:], logger=self._logger ) @@ -382,7 +387,7 @@ def notify_success( # If the model is not in models_map, we can skip any notification if model: # Passing model.fqn to get internal unique asset key - output_key = get_asset_key_str(model.fqn) + output_key = self._translator.get_asset_key_str(model.fqn) if self._is_testing: asset_key = dg.AssetKey(["testing", output_key]) self._logger.warning( @@ -491,7 +496,7 @@ def report_event(self, event: console.ConsoleEvent) -> None: log_context.info( "Snapshot progress complete", { - "asset_key": get_asset_key_str(snapshot.model.name), + "asset_key": self._translator.get_asset_key_str(snapshot.model.name), }, ) self._tracker.update_run(snapshot) @@ -499,7 +504,7 @@ def report_event(self, event: console.ConsoleEvent) -> None: log_context.info( "Snapshot progress update", { - "asset_key": get_asset_key_str(snapshot.model.name), + "asset_key": self._translator.get_asset_key_str(snapshot.model.name), "progress": f"{done}/{expected}", "duration_ms": duration_ms, }, @@ -687,11 +692,13 @@ def create_event_handler( is_testing: bool, materializations_enabled: bool, ) -> DagsterSQLMeshEventHandler: + translator = self.config.get_translator() return DagsterSQLMeshEventHandler( context=context, dag=dag, models_map=models_map, prefix=prefix, + translator=translator, is_testing=is_testing, materializations_enabled=materializations_enabled, ) @@ -701,7 +708,7 @@ def _get_selected_models_from_context( ) -> tuple[set[str], dict[str, Model], list[str] | None]: models_map = models.copy() try: - selected_output_names = set(context.selected_output_names) + selected_output_names = set(context.op_execution_context.selected_output_names) except (DagsterInvalidPropertyError, AttributeError) as e: # Special case for direct execution context when testing. This is related to: # https://github.com/dagster-io/dagster/issues/23633 @@ -711,10 +718,11 @@ def _get_selected_models_from_context( else: raise e + translator = self.config.get_translator() select_models: list[str] = [] models_map = {} for key, model in models.items(): - if get_asset_key_str(model.fqn) in selected_output_names: + if translator.get_asset_key_str(model.fqn) in selected_output_names: models_map[key] = model select_models.append(model.name) return ( diff --git a/dagster_sqlmesh/test_asset.py b/dagster_sqlmesh/test_asset.py index a41f657..b45eae8 100644 --- a/dagster_sqlmesh/test_asset.py +++ b/dagster_sqlmesh/test_asset.py @@ -1,10 +1,9 @@ -from dagster_sqlmesh.asset import SQLMeshDagsterTranslator from dagster_sqlmesh.conftest import SQLMeshTestContext def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestContext): controller = sample_sqlmesh_test_context.create_controller() - translator = SQLMeshDagsterTranslator() - outs = controller.to_asset_outs("dev", translator) + translator = sample_sqlmesh_test_context.context_config.get_translator() + outs = controller.to_asset_outs("dev", translator=translator) assert len(list(outs.deps)) == 1 assert len(outs.outs) == 10 diff --git a/dagster_sqlmesh/test_config.py b/dagster_sqlmesh/test_config.py new file mode 100644 index 0000000..d031b8b --- /dev/null +++ b/dagster_sqlmesh/test_config.py @@ -0,0 +1,65 @@ +import pytest + +from dagster_sqlmesh.config import SQLMeshContextConfig +from dagster_sqlmesh.translator import SQLMeshDagsterTranslator + + +def test_get_translator_with_valid_class(): + """Test that get_translator works with the default translator class.""" + config = SQLMeshContextConfig(path="/tmp/test", gateway="local") + translator = config.get_translator() + assert isinstance(translator, SQLMeshDagsterTranslator) + + +def test_get_translator_with_non_class(): + """Test that get_translator raises ValueError when pointing to a non-class.""" + config = SQLMeshContextConfig( + path="/tmp/test", + gateway="local", + translator_class_name="sys.version" + ) + + with pytest.raises(ValueError, match="is not a class"): + config.get_translator() + + +def test_get_translator_with_invalid_inheritance(): + """Test that get_translator raises ValueError when class doesn't inherit from SQLMeshDagsterTranslator.""" + config = SQLMeshContextConfig( + path="/tmp/test", + gateway="local", + translator_class_name="builtins.dict" + ) + + with pytest.raises(ValueError, match="must inherit from SQLMeshDagsterTranslator"): + config.get_translator() + + +def test_get_translator_with_nonexistent_class(): + """Test that get_translator raises AttributeError when class doesn't exist.""" + config = SQLMeshContextConfig( + path="/tmp/test", + gateway="local", + translator_class_name="dagster_sqlmesh.translator.NonexistentClass" + ) + + with pytest.raises(AttributeError): + config.get_translator() + + +class MockValidTranslator(SQLMeshDagsterTranslator): + """A mock translator for testing custom inheritance.""" + pass + + +def test_get_translator_with_valid_custom_class(): + """Test that get_translator works with custom classes that inherit from SQLMeshDagsterTranslator.""" + config = SQLMeshContextConfig( + path="/tmp/test", + gateway="local", + translator_class_name=f"{__name__}.MockValidTranslator" + ) + + translator = config.get_translator() + assert isinstance(translator, SQLMeshDagsterTranslator) + assert isinstance(translator, MockValidTranslator) diff --git a/dagster_sqlmesh/testing/context.py b/dagster_sqlmesh/testing/context.py index dce47fa..2e07dcd 100644 --- a/dagster_sqlmesh/testing/context.py +++ b/dagster_sqlmesh/testing/context.py @@ -81,6 +81,8 @@ def create_event_handler(self, *args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshE Returns: DagsterSQLMeshEventHandler: The created event handler. """ + # Ensure translator is passed to the event handler factory + kwargs['translator'] = self.config.get_translator() return self._event_handler_factory(*args, **kwargs) diff --git a/dagster_sqlmesh/translator.py b/dagster_sqlmesh/translator.py index 648dbbf..7b3d879 100644 --- a/dagster_sqlmesh/translator.py +++ b/dagster_sqlmesh/translator.py @@ -45,41 +45,97 @@ def to_asset_dep(self) -> AssetDep: class SQLMeshDagsterTranslator: - """Translates sqlmesh objects for dagster""" + """Translates SQLMesh objects for Dagster. + + This class provides methods to convert SQLMesh models and metadata into + Dagster-compatible formats. It can be subclassed to customize the translation + behavior, such as changing asset key generation or grouping logic. + + The translator is used throughout the dagster-sqlmesh integration, including + in the DagsterSQLMeshEventHandler and asset generation process. + """ def get_asset_key(self, context: Context, fqn: str) -> AssetKey: - """Given the sqlmesh context and a model return the asset key""" + """Get the Dagster AssetKey for a SQLMesh model. + + Args: + context: The SQLMesh context (unused in default implementation) + fqn: Fully qualified name of the SQLMesh model + + Returns: + AssetKey: The Dagster asset key for this model + """ path = self.get_asset_key_name(fqn) return AssetKey(path) def get_asset_key_name(self, fqn: str) -> Sequence[str]: + """Parse a fully qualified name into asset key components. + + Args: + fqn: Fully qualified name of the SQLMesh model (e.g., "catalog.schema.table") + + Returns: + Sequence[str]: Asset key components [catalog, schema, table] + """ table = exp.to_table(fqn) asset_key_name = [table.catalog, table.db, table.name] return asset_key_name def get_group_name(self, context: Context, model: Model) -> str: + """Get the Dagster asset group name for a SQLMesh model. + + Args: + context: The SQLMesh context (unused in default implementation) + model: The SQLMesh model + + Returns: + str: The asset group name (defaults to the schema/database name) + """ path = self.get_asset_key_name(model.fqn) return path[-2] def get_context_dialect(self, context: Context) -> str: + """Get the SQL dialect used by the SQLMesh context. + + Args: + context: The SQLMesh context + + Returns: + str: The SQL dialect name (e.g., "duckdb", "postgres", etc.) + """ return context.engine_adapter.dialect def create_asset_dep(self, *, key: str, **kwargs: t.Any) -> ConvertibleToAssetDep: - """Create an object that resolves to an AssetDep - - Most users of this library will not need to use this method, it is - primarily the way we enable cacheable assets from dagster-sqlmesh. + """Create an object that resolves to an AssetDep. + + This creates an intermediate representation that can be converted to a + Dagster AssetDep. Most users will not need to use this method directly. + + Args: + key: The asset key string for the dependency + **kwargs: Additional arguments to pass to the AssetDep + + Returns: + ConvertibleToAssetDep: An object that can be converted to an AssetDep """ return IntermediateAssetDep(key=key, kwargs=kwargs) def create_asset_out( self, *, model_key: str, asset_key: str, **kwargs: t.Any ) -> ConvertibleToAssetOut: - """Create an object that resolves to an AssetOut - - Most users of this library will not need to use this method, it is - primarily the way we enable cacheable assets from dagster-sqlmesh. + """Create an object that resolves to an AssetOut. + + This creates an intermediate representation that can be converted to a + Dagster AssetOut. Most users will not need to use this method directly. + + Args: + model_key: Internal key for the SQLMesh model + asset_key: The asset key string for the output + **kwargs: Additional arguments including tags, group_name, kinds, etc. + + Returns: + ConvertibleToAssetOut: An object that can be converted to an AssetOut """ return IntermediateAssetOut( model_key=model_key, @@ -91,6 +147,41 @@ def create_asset_out( kwargs=kwargs, ) + def get_asset_key_str(self, fqn: str) -> str: + """Get asset key string with sqlmesh prefix for internal mapping. + + This creates an internal identifier used to map outputs and dependencies + within the dagster-sqlmesh integration. It will not affect the actual + AssetKeys that users see. The result contains only alphanumeric characters + and underscores, making it safe for internal usage. + + Args: + fqn: Fully qualified name of the SQLMesh model + + Returns: + str: Internal asset key string with "sqlmesh__" prefix + """ + table = exp.to_table(fqn) + asset_key_name = [table.catalog, table.db, table.name] + + return "sqlmesh__" + "_".join(asset_key_name) + def get_tags(self, context: Context, model: Model) -> dict[str, str]: - """Given the sqlmesh context and a model return the tags for that model""" - return {k: "true" for k in model.tags} + """Get Dagster asset tags for a SQLMesh model. + + Args: + context: The SQLMesh context (unused in default implementation) + model: The SQLMesh model + + Returns: + dict[str, str]: Dictionary of tags to apply to the Dagster asset. + Default implementation converts SQLMesh model tags to + empty string values, which causes the Dagster UI to + render them as labels rather than key-value pairs. + + Note: + Tags must contain only strings as keys and values. The Dagster UI + will render tags with empty string values as "labels" rather than + key-value pairs. + """ + return {k: "" for k in model.tags} diff --git a/dagster_sqlmesh/utils.py b/dagster_sqlmesh/utils.py index 96b2dba..5d2cf8c 100644 --- a/dagster_sqlmesh/utils.py +++ b/dagster_sqlmesh/utils.py @@ -1,16 +1,6 @@ -from sqlglot import exp from sqlmesh.core.snapshot import SnapshotId -def get_asset_key_str(fqn: str) -> str: - # This is an internal identifier used to map outputs and dependencies - # it will not affect the existing AssetKeys - # Only alphanumeric characters and underscores - table = exp.to_table(fqn) - asset_key_name = [table.catalog, table.db, table.name] - - return "sqlmesh__" + "_".join(asset_key_name) - def snapshot_id_to_model_name(snapshot_id: SnapshotId) -> str: """Convert a SnapshotId to its model name. diff --git a/sample/dagster_project/definitions.py b/sample/dagster_project/definitions.py index 8510830..8980a78 100644 --- a/sample/dagster_project/definitions.py +++ b/sample/dagster_project/definitions.py @@ -27,7 +27,11 @@ SQLMESH_CACHE_PATH = os.path.join(SQLMESH_PROJECT_PATH, ".cache") DUCKDB_PATH = os.path.join(CURR_DIR, "../../db.db") -sqlmesh_config = SQLMeshContextConfig(path=SQLMESH_PROJECT_PATH, gateway="local") +sqlmesh_config = SQLMeshContextConfig( + path=SQLMESH_PROJECT_PATH, + gateway="local", + translator_class_name="definitions.RewrittenSQLMeshTranslator" +) class RewrittenSQLMeshTranslator(SQLMeshDagsterTranslator): @@ -101,7 +105,6 @@ def post_full_model() -> pl.DataFrame: environment="dev", config=sqlmesh_config, enabled_subsetting=True, - dagster_sqlmesh_translator=RewrittenSQLMeshTranslator(), ) def sqlmesh_project( context: AssetExecutionContext, sqlmesh: SQLMeshResource