Skip to content

Commit 250bae6

Browse files
committed
Enable retries for dynamic pipeline function execution
1 parent d2071ec commit 250bae6

File tree

7 files changed

+367
-62
lines changed

7 files changed

+367
-62
lines changed

src/zenml/enums.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def is_finished(self) -> bool:
104104
ExecutionStatus.COMPLETED,
105105
ExecutionStatus.CACHED,
106106
ExecutionStatus.RETRIED,
107+
ExecutionStatus.RETRYING,
107108
ExecutionStatus.STOPPED,
108109
}
109110

@@ -125,6 +126,20 @@ def is_failed(self) -> bool:
125126
"""
126127
return self in {ExecutionStatus.FAILED}
127128

129+
@property
130+
def is_in_progress(self) -> bool:
131+
"""Whether the execution status refers to an in progress execution.
132+
133+
Returns:
134+
Whether the execution status refers to an in progress execution.
135+
"""
136+
return self in {
137+
ExecutionStatus.INITIALIZING,
138+
ExecutionStatus.PROVISIONING,
139+
ExecutionStatus.RUNNING,
140+
ExecutionStatus.STOPPING,
141+
}
142+
128143

129144
class LoggingLevels(Enum):
130145
"""Enum for logging levels."""

src/zenml/execution/pipeline/dynamic/outputs.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Dynamic pipeline execution outputs."""
1515

16+
from abc import abstractmethod
1617
from concurrent.futures import Future
1718
from typing import Any, Iterator, List, Optional, Tuple, Union, overload
1819

@@ -34,7 +35,27 @@ class OutputArtifact(ArtifactVersionResponse):
3435
StepRunOutputs = Union[None, OutputArtifact, Tuple[OutputArtifact, ...]]
3536

3637

37-
class _BaseStepRunFuture:
38+
class BaseFuture:
39+
"""Base future."""
40+
41+
@abstractmethod
42+
def running(self) -> bool:
43+
"""Check if the future is running.
44+
45+
Returns:
46+
True if the future is running, False otherwise.
47+
"""
48+
49+
@abstractmethod
50+
def result(self) -> Any:
51+
"""Get the result of the future.
52+
53+
Returns:
54+
The result of the future.
55+
"""
56+
57+
58+
class BaseStepRunFuture(BaseFuture):
3859
"""Base step run future."""
3960

4061
def __init__(
@@ -62,12 +83,16 @@ def invocation_id(self) -> str:
6283
"""
6384
return self._invocation_id
6485

65-
def _wait(self) -> None:
66-
"""Wait for the step run future to complete."""
67-
self._wrapped.result()
86+
def running(self) -> bool:
87+
"""Check if the step run future is running.
88+
89+
Returns:
90+
True if the step run future is running, False otherwise.
91+
"""
92+
return self._wrapped.running()
6893

6994

70-
class ArtifactFuture(_BaseStepRunFuture):
95+
class ArtifactFuture(BaseStepRunFuture):
7196
"""Future for a step run output artifact."""
7297

7398
def __init__(
@@ -115,7 +140,7 @@ def load(self, disable_cache: bool = False) -> Any:
115140
return self.result().load(disable_cache=disable_cache)
116141

117142

118-
class StepRunOutputsFuture(_BaseStepRunFuture):
143+
class StepRunOutputsFuture(BaseStepRunFuture):
119144
"""Future for a step run output."""
120145

121146
def __init__(
@@ -270,7 +295,7 @@ def __len__(self) -> int:
270295
return len(self._output_keys)
271296

272297

273-
class MapResultsFuture:
298+
class MapResultsFuture(BaseFuture):
274299
"""Future that represents the results of a `step.map/product(...)` call."""
275300

276301
def __init__(self, futures: List[StepRunOutputsFuture]) -> None:
@@ -281,6 +306,14 @@ def __init__(self, futures: List[StepRunOutputsFuture]) -> None:
281306
"""
282307
self.futures = futures
283308

309+
def running(self) -> bool:
310+
"""Check if the map results future is running.
311+
312+
Returns:
313+
True if the map results future is running, False otherwise.
314+
"""
315+
return any(future.running() for future in self.futures)
316+
284317
def result(self) -> List[StepRunOutputs]:
285318
"""Get the step run outputs this future represents.
286319
@@ -289,6 +322,19 @@ def result(self) -> List[StepRunOutputs]:
289322
"""
290323
return [future.result() for future in self.futures]
291324

325+
def load(self, disable_cache: bool = False) -> List[Any]:
326+
"""Load the step run output artifacts.
327+
328+
Args:
329+
disable_cache: Whether to disable the artifact cache.
330+
331+
Returns:
332+
The step run output artifacts.
333+
"""
334+
return [
335+
future.load(disable_cache=disable_cache) for future in self.futures
336+
]
337+
292338
def unpack(self) -> Tuple[List[ArtifactFuture], ...]:
293339
"""Unpack the map results future.
294340
@@ -358,4 +404,6 @@ def __len__(self) -> int:
358404
return len(self.futures)
359405

360406

361-
StepRunFuture = Union[ArtifactFuture, StepRunOutputsFuture, MapResultsFuture]
407+
AnyStepRunFuture = Union[
408+
ArtifactFuture, StepRunOutputsFuture, MapResultsFuture
409+
]

0 commit comments

Comments
 (0)