Skip to content

Commit 7d7fbfe

Browse files
committed
Updates from_functions and from_modules to use the same base "compile"
function. This reduces the number of code paths -- we'll eventually use compile for everything. It also allows us to push functions up to the user level, enabling `with_functions`.
1 parent a288e84 commit 7d7fbfe

File tree

4 files changed

+69
-21
lines changed

4 files changed

+69
-21
lines changed

hamilton/async_driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import typing
77
import uuid
8-
from types import ModuleType
8+
from types import FunctionType, ModuleType
99
from typing import Any, Dict, Optional, Tuple
1010

1111
import hamilton.lifecycle.base as lifecycle_base
@@ -199,6 +199,7 @@ def __init__(
199199
result_builder: Optional[base.ResultMixin] = None,
200200
adapters: typing.List[lifecycle.LifecycleAdapter] = None,
201201
allow_module_overrides: bool = False,
202+
functions: typing.List[FunctionType],
202203
):
203204
"""Instantiates an asynchronous driver.
204205
@@ -249,6 +250,7 @@ def __init__(
249250
*async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later
250251
],
251252
allow_module_overrides=allow_module_overrides,
253+
functions=functions,
252254
)
253255
self.initialized = False
254256

hamilton/driver.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import typing
1414
import uuid
1515
from datetime import datetime
16-
from types import ModuleType
16+
from types import FunctionType, ModuleType
1717
from typing import (
1818
Any,
1919
Callable,
@@ -402,6 +402,7 @@ def __init__(
402402
self,
403403
config: Dict[str, Any],
404404
*modules: ModuleType,
405+
functions: List[FunctionType] = None,
405406
adapter: Optional[
406407
Union[lifecycle_base.LifecycleAdapter, List[lifecycle_base.LifecycleAdapter]]
407408
] = None,
@@ -435,13 +436,15 @@ def __init__(
435436
if adapter.does_hook("pre_do_anything", is_async=False):
436437
adapter.call_all_lifecycle_hooks_sync("pre_do_anything")
437438
error = None
439+
self.graph_functions = functions if functions is not None else []
438440
self.graph_modules = modules
439441
try:
440-
self.graph = graph.FunctionGraph.from_modules(
441-
*modules,
442+
self.graph = graph.FunctionGraph.compile(
443+
modules=list(modules),
444+
functions=functions,
442445
config=config,
443446
adapter=adapter,
444-
allow_module_overrides=allow_module_overrides,
447+
allow_node_overrides=allow_module_overrides,
445448
)
446449
if _materializers:
447450
materializer_factories, extractor_factories = self._process_materializers(
@@ -1866,6 +1869,7 @@ def __init__(self):
18661869
# common fields
18671870
self.config = {}
18681871
self.modules = []
1872+
self.functions = []
18691873
self.materializers = []
18701874

18711875
# Allow later modules to override nodes of the same name
@@ -1927,6 +1931,17 @@ def with_modules(self, *modules: ModuleType) -> "Builder":
19271931
self.modules.extend(modules)
19281932
return self
19291933

1934+
def with_functions(self, *functions: FunctionType) -> "Builder":
1935+
"""Adds the specified functions to the list.
1936+
This can be called multiple times. If you have allow_module_overrides
1937+
set, this will enable overwriting modules or previously added functions.
1938+
1939+
:param functions:
1940+
:return: self
1941+
"""
1942+
self.functions.extend(functions)
1943+
return self
1944+
19301945
def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder":
19311946
"""Sets the adapter to use.
19321947
@@ -2168,6 +2183,7 @@ def build(self) -> Driver:
21682183
_graph_executor=graph_executor,
21692184
_use_legacy_adapter=False,
21702185
allow_module_overrides=self._allow_module_overrides,
2186+
functions=self.functions,
21712187
)
21722188

21732189
def copy(self) -> "Builder":

hamilton/graph.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pathlib
1414
import uuid
1515
from enum import Enum
16-
from types import ModuleType
16+
from types import FunctionType, ModuleType
1717
from typing import Any, Callable, Collection, Dict, FrozenSet, List, Optional, Set, Tuple, Type
1818

1919
import hamilton.lifecycle.base as lifecycle_base
@@ -142,17 +142,18 @@ def update_dependencies(
142142
return nodes
143143

144144

145-
def create_function_graph(
145+
def compile_to_nodes(
146146
*functions: List[Tuple[str, Callable]],
147147
config: Dict[str, Any],
148148
adapter: lifecycle_base.LifecycleAdapterSet = None,
149149
fg: Optional["FunctionGraph"] = None,
150-
allow_module_overrides: bool = False,
150+
allow_node_leveL_overrides: bool = False,
151151
) -> Dict[str, node.Node]:
152152
"""Creates a graph of all available functions & their dependencies.
153153
:param modules: A set of modules over which one wants to compute the function graph
154154
:param config: Dictionary that we will inspect to get values from in building the function graph.
155155
:param adapter: The adapter that adapts our node type checking based on the context.
156+
:param allow_node_leveL_overrides: Whether or not to allow node names to override each other
156157
:return: list of nodes in the graph.
157158
If it needs to be more complicated, we'll return an actual networkx graph and get all the rest of the logic for free
158159
"""
@@ -170,7 +171,7 @@ def create_function_graph(
170171
for n in fm_base.resolve_nodes(f, config):
171172
if n.name in config:
172173
continue # This makes sure we overwrite things if they're in the config...
173-
if n.name in nodes and not allow_module_overrides:
174+
if n.name in nodes and not allow_node_leveL_overrides:
174175
raise ValueError(
175176
f"Cannot define function {n.name} more than once."
176177
f" Already defined by function {f}"
@@ -713,13 +714,43 @@ def __init__(
713714
self.nodes = nodes
714715
self.adapter = adapter
715716

717+
@staticmethod
718+
def compile(
719+
modules: List[ModuleType],
720+
functions: List[FunctionType],
721+
config: Dict[str, Any],
722+
adapter: lifecycle_base.LifecycleAdapterSet = None,
723+
allow_node_overrides: bool = False,
724+
) -> "FunctionGraph":
725+
"""Base level static function for compiling a function graph. Note
726+
that this can both use functions (e.g. passing them directly) and modules
727+
(passing them in and crawling).
728+
729+
:param modules: Modules to use
730+
:param functions: Functions to use
731+
:param config: Config to use for setting up the DAG
732+
:param adapter: Adapter to use for node resolution
733+
:param allow_node_overrides: Whether or not to allow node level overrides.
734+
:return: The compiled function graph
735+
"""
736+
module_functions = sum([find_functions(module) for module in modules], [])
737+
nodes = compile_to_nodes(
738+
*module_functions,
739+
*functions,
740+
config=config,
741+
adapter=adapter,
742+
allow_node_leveL_overrides=allow_node_overrides,
743+
)
744+
return FunctionGraph(nodes, config, adapter)
745+
716746
@staticmethod
717747
def from_modules(
718748
*modules: ModuleType,
719749
config: Dict[str, Any],
720750
adapter: lifecycle_base.LifecycleAdapterSet = None,
721751
allow_module_overrides: bool = False,
722-
):
752+
additional_functions: List[FunctionType],
753+
) -> "FunctionGraph":
723754
"""Initializes a function graph from the specified modules. Note that this was the old
724755
way we constructed FunctionGraph -- this is not a public-facing API, so we replaced it
725756
with a constructor that takes in nodes directly. If you hacked in something using
@@ -732,28 +763,28 @@ def from_modules(
732763
:return: a function graph.
733764
"""
734765

735-
functions = sum([find_functions(module) for module in modules], [])
736-
return FunctionGraph.from_functions(
737-
*functions,
766+
return FunctionGraph.compile(
767+
modules=modules,
768+
functions=[],
738769
config=config,
739770
adapter=adapter,
740-
allow_module_overrides=allow_module_overrides,
771+
allow_node_overrides=allow_module_overrides,
741772
)
742773

743774
@staticmethod
744775
def from_functions(
745-
*functions,
776+
*functions: FunctionType,
746777
config: Dict[str, Any],
747778
adapter: lifecycle_base.LifecycleAdapterSet = None,
748779
allow_module_overrides: bool = False,
749780
) -> "FunctionGraph":
750-
nodes = create_function_graph(
751-
*functions,
781+
return FunctionGraph.compile(
782+
modules=[],
783+
functions=functions,
752784
config=config,
753785
adapter=adapter,
754-
allow_module_overrides=allow_module_overrides,
786+
allow_node_overrides=allow_module_overrides,
755787
)
756-
return FunctionGraph(nodes, config, adapter)
757788

758789
def with_nodes(self, nodes: Dict[str, Node]) -> "FunctionGraph":
759790
"""Creates a new function graph with the additional specified nodes.

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ docs = [
5757
"diskcache",
5858
# required for all the plugins
5959
"dlt",
60-
# furo -- install from main for now until the next release is out:
61-
"furo @ git+https://github.com/pradyunsg/furo@main",
60+
"furo",
6261
"gitpython", # Required for parsing git info for generation of data-adapter docs
6362
"grpcio-status",
6463
"lightgbm",

0 commit comments

Comments
 (0)