
Commit 69dfea4

Add GCS Dag bundle (unversioned) (#55919)
1 parent 1cbed3d commit 69dfea4

File tree: 7 files changed, +625 −3 lines

airflow-core/docs/administration-and-deployment/dag-bundles.rst

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,9 @@ Airflow supports multiple types of Dag Bundles, each catering to specific use ca
 **airflow.providers.amazon.aws.bundles.s3.S3DagBundle**
     These bundles reference an S3 bucket containing Dag files. They do not support versioning of the bundle, meaning tasks always run using the latest code.
 
+**airflow.providers.google.cloud.bundles.gcs.GCSDagBundle**
+    These bundles reference a GCS bucket containing Dag files. They do not support versioning of the bundle, meaning tasks always run using the latest code.
+
 Configuring Dag bundles
 -----------------------
 
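
For illustration only, a minimal sketch of how such a bundle might be wired up through the dag bundle configuration described in the "Configuring Dag bundles" section that follows; the bundle name, bucket, prefix, and connection id are placeholders, and the config option/env var name is an assumption, not part of this commit:

import json
import os

# Placeholder bundle entry -- adjust names to your environment.
gcs_bundle = {
    "name": "gcs_dags",
    "classpath": "airflow.providers.google.cloud.bundles.gcs.GCSDagBundle",
    "kwargs": {
        "bucket_name": "my-dags-bucket",
        "prefix": "dags/",
        "gcp_conn_id": "google_cloud_default",
    },
}

# Assumed to be equivalent to setting [dag_processor] dag_bundle_config_list in airflow.cfg.
os.environ["AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST"] = json.dumps([gcs_bundle])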

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from pathlib import Path

import structlog
from google.api_core.exceptions import NotFound

from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook


class GCSDagBundle(BaseDagBundle):
    """
    GCS Dag bundle - exposes a directory in GCS as a Dag bundle.

    This allows Airflow to load Dags directly from a GCS bucket.

    :param gcp_conn_id: Airflow connection ID for GCS. Defaults to GoogleBaseHook.default_conn_name.
    :param bucket_name: The name of the GCS bucket containing the Dag files.
    :param prefix: Optional subdirectory within the GCS bucket where the Dags are stored.
        If None, Dags are assumed to be at the root of the bucket (Optional).
    """

    supports_versioning = False

    def __init__(
        self,
        *,
        gcp_conn_id: str = GoogleBaseHook.default_conn_name,
        bucket_name: str,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.gcp_conn_id = gcp_conn_id
        self.bucket_name = bucket_name
        self.prefix = prefix
        # Local path where GCS Dags are downloaded
        self.gcs_dags_dir: Path = self.base_dir

        log = structlog.get_logger(__name__)
        self._log = log.bind(
            bundle_name=self.name,
            version=self.version,
            bucket_name=self.bucket_name,
            prefix=self.prefix,
            gcp_conn_id=self.gcp_conn_id,
        )
        self._gcs_hook: GCSHook | None = None

    def _initialize(self):
        with self.lock():
            if not self.gcs_dags_dir.exists():
                self._log.info("Creating local Dags directory: %s", self.gcs_dags_dir)
                os.makedirs(self.gcs_dags_dir)

            if not self.gcs_dags_dir.is_dir():
                raise NotADirectoryError(f"Local Dags path: {self.gcs_dags_dir} is not a directory.")

            try:
                self.gcs_hook.get_bucket(bucket_name=self.bucket_name)
            except NotFound:
                raise ValueError(f"GCS bucket '{self.bucket_name}' does not exist.")

            if self.prefix:
                # don't check when prefix is ""
                if not self.gcs_hook.list(bucket_name=self.bucket_name, prefix=self.prefix):
                    raise ValueError(f"GCS prefix 'gs://{self.bucket_name}/{self.prefix}' does not exist.")
            self.refresh()

    def initialize(self) -> None:
        self._initialize()
        super().initialize()

    @property
    def gcs_hook(self):
        if self._gcs_hook is None:
            try:
                self._gcs_hook: GCSHook = GCSHook(gcp_conn_id=self.gcp_conn_id)  # Initialize GCS hook.
            except AirflowException as e:
                self._log.warning("Could not create GCSHook for connection %s: %s", self.gcp_conn_id, e)
        return self._gcs_hook

    def __repr__(self):
        return (
            f"<GCSDagBundle("
            f"name={self.name!r}, "
            f"bucket_name={self.bucket_name!r}, "
            f"prefix={self.prefix!r}, "
            f"version={self.version!r}"
            f")>"
        )

    def get_current_version(self) -> str | None:
        """Return the current version of the Dag bundle. Currently not supported."""
        return None

    @property
    def path(self) -> Path:
        """Return the local path to the Dag files."""
        return self.gcs_dags_dir  # Path where Dags are downloaded.

    def refresh(self) -> None:
        """Refresh the Dag bundle by re-downloading the Dags from GCS."""
        if self.version:
            raise ValueError("Refreshing a specific version is not supported")

        with self.lock():
            self._log.debug(
                "Downloading Dags from gs://%s/%s to %s", self.bucket_name, self.prefix, self.gcs_dags_dir
            )
            self.gcs_hook.sync_to_local_dir(
                bucket_name=self.bucket_name,
                prefix=self.prefix,
                local_dir=self.gcs_dags_dir,
                delete_stale=True,
            )

    def view_url(self, version: str | None = None) -> str | None:
        """
        Return a URL for viewing the Dags in GCS. Currently, versioning is not supported.

        This method is deprecated and will be removed when the minimum supported Airflow version is 3.1.
        Use `view_url_template` instead.
        """
        return self.view_url_template()

    def view_url_template(self) -> str | None:
        """Return a URL for viewing the Dags in GCS. Currently, versioning is not supported."""
        if self.version:
            raise ValueError("GCS url with version is not supported")
        if hasattr(self, "_view_url_template") and self._view_url_template:
            # Because we use this method in the view_url method, we need to handle
            # backward compatibility for Airflow versions that don't have the
            # _view_url_template attribute. Should be removed when we drop support for Airflow 3.0
            return self._view_url_template
        # https://console.cloud.google.com/storage/browser/<bucket-name>/<prefix>
        url = f"https://console.cloud.google.com/storage/browser/{self.bucket_name}"
        if self.prefix:
            url += f"/{self.prefix}"

        return url
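
For context, a rough sketch of the bundle lifecycle when the class above is driven by hand; in practice Airflow's dag processor constructs bundles from configuration, and the name, bucket, prefix, and connection id below are placeholders:

from airflow.providers.google.cloud.bundles.gcs import GCSDagBundle

bundle = GCSDagBundle(
    name="gcs_dags",                   # bundle name, passed through to BaseDagBundle
    bucket_name="my-dags-bucket",      # placeholder bucket
    prefix="dags/",                    # placeholder prefix within the bucket
    gcp_conn_id="google_cloud_default",
)

bundle.initialize()                    # validates the bucket/prefix and performs the first sync
print(bundle.path)                     # local directory containing the downloaded Dag files
print(bundle.view_url_template())      # console URL for browsing the bucket/prefix
bundle.refresh()                       # re-syncs, deleting stale local files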

providers/google/src/airflow/providers/google/cloud/hooks/gcs.py

Lines changed: 107 additions & 3 deletions
@@ -28,8 +28,10 @@
 import warnings
 from collections.abc import Callable, Generator, Sequence
 from contextlib import contextmanager
+from datetime import datetime
 from functools import partial
 from io import BytesIO
+from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import IO, TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
 from urllib.parse import urlsplit
@@ -50,12 +52,14 @@
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
-from airflow.utils import timezone
+
+try:
+    from airflow.sdk import timezone
+except ImportError:
+    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
 from airflow.version import version
 
 if TYPE_CHECKING:
-    from datetime import datetime
-
     from aiohttp import ClientSession
     from google.api_core.retry import Retry
     from google.cloud.storage.blob import Blob
@@ -1249,6 +1253,106 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec
 
         self.log.info("Completed successfully.")
 
+    def _sync_to_local_dir_delete_stale_local_files(self, current_gcs_objects: List[Path], local_dir: Path):
+        current_gcs_keys = {key.resolve() for key in current_gcs_objects}
+
+        for item in local_dir.rglob("*"):
+            if item.is_file():
+                if item.resolve() not in current_gcs_keys:
+                    self.log.debug("Deleting stale local file: %s", item)
+                    item.unlink()
+        # Clean up empty directories
+        for root, dirs, _ in os.walk(local_dir, topdown=False):
+            for d in dirs:
+                dir_path = os.path.join(root, d)
+                if not os.listdir(dir_path):
+                    self.log.debug("Deleting stale empty directory: %s", dir_path)
+                    os.rmdir(dir_path)
+
+    def _sync_to_local_dir_if_changed(self, blob: Blob, local_target_path: Path):
+        should_download = False
+        download_msg = ""
+        if not local_target_path.exists():
+            should_download = True
+            download_msg = f"Local file {local_target_path} does not exist."
+        else:
+            local_stats = local_target_path.stat()
+            # Reload blob to get fresh metadata, including size and updated time
+            blob.reload()
+
+            if blob.size != local_stats.st_size:
+                should_download = True
+                download_msg = (
+                    f"GCS object size ({blob.size}) and local file size ({local_stats.st_size}) differ."
+                )
+
+            gcs_last_modified = blob.updated
+            if (
+                not should_download
+                and gcs_last_modified
+                and local_stats.st_mtime < gcs_last_modified.timestamp()
+            ):
+                should_download = True
+                download_msg = f"GCS object last modified ({gcs_last_modified}) is newer than local file last modified ({datetime.fromtimestamp(local_stats.st_mtime, tz=timezone.utc)})."
+
+        if should_download:
+            self.log.debug("%s Downloading %s to %s", download_msg, blob.name, local_target_path.as_posix())
+            self.download(
+                bucket_name=blob.bucket.name, object_name=blob.name, filename=str(local_target_path)
+            )
+        else:
+            self.log.debug(
+                "Local file %s is up-to-date with GCS object %s. Skipping download.",
+                local_target_path.as_posix(),
+                blob.name,
+            )
+
+    def sync_to_local_dir(
+        self,
+        bucket_name: str,
+        local_dir: str | Path,
+        prefix: str | None = None,
+        delete_stale: bool = False,
+    ) -> None:
+        """
+        Download files from a GCS bucket to a local directory.
+
+        It will download all files from the given ``prefix`` and create the corresponding
+        directory structure in the ``local_dir``.
+
+        If ``delete_stale`` is ``True``, it will delete all local files that do not exist in the GCS bucket.
+
+        :param bucket_name: The name of the GCS bucket.
+        :param local_dir: The local directory to which the files will be downloaded.
+        :param prefix: The prefix of the files to be downloaded.
+        :param delete_stale: If ``True``, deletes local files that don't exist in the bucket.
+        """
+        prefix = prefix or ""
+        local_dir_path = Path(local_dir)
+        self.log.debug("Downloading data from gs://%s/%s to %s", bucket_name, prefix, local_dir_path)
+
+        gcs_bucket = self.get_bucket(bucket_name)
+        local_gcs_objects = []
+
+        for blob in gcs_bucket.list_blobs(prefix=prefix):
+            # GCS lists "directories" as objects ending with a slash. We should skip them.
+            if blob.name.endswith("/"):
+                continue
+
+            blob_path = Path(blob.name)
+            local_target_path = local_dir_path.joinpath(blob_path.relative_to(prefix))
+            if not local_target_path.parent.exists():
+                local_target_path.parent.mkdir(parents=True, exist_ok=True)
+                self.log.debug("Created local directory: %s", local_target_path.parent)
+
+            self._sync_to_local_dir_if_changed(blob=blob, local_target_path=local_target_path)
+            local_gcs_objects.append(local_target_path)
+
+        if delete_stale:
+            self._sync_to_local_dir_delete_stale_local_files(
+                current_gcs_objects=local_gcs_objects, local_dir=local_dir_path
+            )
+
     def sync(
         self,
         source_bucket: str,
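
As a usage sketch, the new hook method added above could also be called directly, outside the bundle; the bucket, prefix, connection id, and local path here are placeholders:

from pathlib import Path

from airflow.providers.google.cloud.hooks.gcs import GCSHook

hook = GCSHook(gcp_conn_id="google_cloud_default")
hook.sync_to_local_dir(
    bucket_name="my-dags-bucket",
    prefix="dags/",
    local_dir=Path("/tmp/gcs-dags"),
    delete_stale=True,  # also remove local files no longer present under the prefix
)
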
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

0 commit comments
