|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import os |
| 20 | +from pathlib import Path |
| 21 | + |
| 22 | +import structlog |
| 23 | +from google.api_core.exceptions import NotFound |
| 24 | + |
| 25 | +from airflow.dag_processing.bundles.base import BaseDagBundle |
| 26 | +from airflow.exceptions import AirflowException |
| 27 | +from airflow.providers.google.cloud.hooks.gcs import GCSHook |
| 28 | +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook |
| 29 | + |
| 30 | + |
| 31 | +class GCSDagBundle(BaseDagBundle): |
| 32 | + """ |
| 33 | + GCS Dag bundle - exposes a directory in GCS as a Dag bundle. |
| 34 | +
|
| 35 | + This allows Airflow to load Dags directly from a GCS bucket. |
| 36 | +
|
| 37 | + :param gcp_conn_id: Airflow connection ID for GCS. Defaults to GoogleBaseHook.default_conn_name. |
| 38 | + :param bucket_name: The name of the GCS bucket containing the Dag files. |
| 39 | + :param prefix: Optional subdirectory within the GCS bucket where the Dags are stored. |
| 40 | + If None, Dags are assumed to be at the root of the bucket (Optional). |
| 41 | + """ |
| 42 | + |
| 43 | + supports_versioning = False |
| 44 | + |
| 45 | + def __init__( |
| 46 | + self, |
| 47 | + *, |
| 48 | + gcp_conn_id: str = GoogleBaseHook.default_conn_name, |
| 49 | + bucket_name: str, |
| 50 | + prefix: str = "", |
| 51 | + **kwargs, |
| 52 | + ) -> None: |
| 53 | + super().__init__(**kwargs) |
| 54 | + self.gcp_conn_id = gcp_conn_id |
| 55 | + self.bucket_name = bucket_name |
| 56 | + self.prefix = prefix |
| 57 | + # Local path where GCS Dags are downloaded |
| 58 | + self.gcs_dags_dir: Path = self.base_dir |
| 59 | + |
| 60 | + log = structlog.get_logger(__name__) |
| 61 | + self._log = log.bind( |
| 62 | + bundle_name=self.name, |
| 63 | + version=self.version, |
| 64 | + bucket_name=self.bucket_name, |
| 65 | + prefix=self.prefix, |
| 66 | + gcp_conn_id=self.gcp_conn_id, |
| 67 | + ) |
| 68 | + self._gcs_hook: GCSHook | None = None |
| 69 | + |
| 70 | + def _initialize(self): |
| 71 | + with self.lock(): |
| 72 | + if not self.gcs_dags_dir.exists(): |
| 73 | + self._log.info("Creating local Dags directory: %s", self.gcs_dags_dir) |
| 74 | + os.makedirs(self.gcs_dags_dir) |
| 75 | + |
| 76 | + if not self.gcs_dags_dir.is_dir(): |
| 77 | + raise NotADirectoryError(f"Local Dags path: {self.gcs_dags_dir} is not a directory.") |
| 78 | + |
| 79 | + try: |
| 80 | + self.gcs_hook.get_bucket(bucket_name=self.bucket_name) |
| 81 | + except NotFound: |
| 82 | + raise ValueError(f"GCS bucket '{self.bucket_name}' does not exist.") |
| 83 | + |
| 84 | + if self.prefix: |
| 85 | + # don't check when prefix is "" |
| 86 | + if not self.gcs_hook.list(bucket_name=self.bucket_name, prefix=self.prefix): |
| 87 | + raise ValueError(f"GCS prefix 'gs://{self.bucket_name}/{self.prefix}' does not exist.") |
| 88 | + self.refresh() |
| 89 | + |
| 90 | + def initialize(self) -> None: |
| 91 | + self._initialize() |
| 92 | + super().initialize() |
| 93 | + |
| 94 | + @property |
| 95 | + def gcs_hook(self): |
| 96 | + if self._gcs_hook is None: |
| 97 | + try: |
| 98 | + self._gcs_hook: GCSHook = GCSHook(gcp_conn_id=self.gcp_conn_id) # Initialize GCS hook. |
| 99 | + except AirflowException as e: |
| 100 | + self._log.warning("Could not create GCSHook for connection %s: %s", self.gcp_conn_id, e) |
| 101 | + return self._gcs_hook |
| 102 | + |
| 103 | + def __repr__(self): |
| 104 | + return ( |
| 105 | + f"<GCSDagBundle(" |
| 106 | + f"name={self.name!r}, " |
| 107 | + f"bucket_name={self.bucket_name!r}, " |
| 108 | + f"prefix={self.prefix!r}, " |
| 109 | + f"version={self.version!r}" |
| 110 | + f")>" |
| 111 | + ) |
| 112 | + |
| 113 | + def get_current_version(self) -> str | None: |
| 114 | + """Return the current version of the Dag bundle. Currently not supported.""" |
| 115 | + return None |
| 116 | + |
| 117 | + @property |
| 118 | + def path(self) -> Path: |
| 119 | + """Return the local path to the Dag files.""" |
| 120 | + return self.gcs_dags_dir # Path where Dags are downloaded. |
| 121 | + |
| 122 | + def refresh(self) -> None: |
| 123 | + """Refresh the Dag bundle by re-downloading the Dags from GCS.""" |
| 124 | + if self.version: |
| 125 | + raise ValueError("Refreshing a specific version is not supported") |
| 126 | + |
| 127 | + with self.lock(): |
| 128 | + self._log.debug( |
| 129 | + "Downloading Dags from gs://%s/%s to %s", self.bucket_name, self.prefix, self.gcs_dags_dir |
| 130 | + ) |
| 131 | + self.gcs_hook.sync_to_local_dir( |
| 132 | + bucket_name=self.bucket_name, |
| 133 | + prefix=self.prefix, |
| 134 | + local_dir=self.gcs_dags_dir, |
| 135 | + delete_stale=True, |
| 136 | + ) |
| 137 | + |
| 138 | + def view_url(self, version: str | None = None) -> str | None: |
| 139 | + """ |
| 140 | + Return a URL for viewing the Dags in GCS. Currently, versioning is not supported. |
| 141 | +
|
| 142 | + This method is deprecated and will be removed when the minimum supported Airflow version is 3.1. |
| 143 | + Use `view_url_template` instead. |
| 144 | + """ |
| 145 | + return self.view_url_template() |
| 146 | + |
| 147 | + def view_url_template(self) -> str | None: |
| 148 | + """Return a URL for viewing the Dags in GCS. Currently, versioning is not supported.""" |
| 149 | + if self.version: |
| 150 | + raise ValueError("GCS url with version is not supported") |
| 151 | + if hasattr(self, "_view_url_template") and self._view_url_template: |
| 152 | + # Because we use this method in the view_url method, we need to handle |
| 153 | + # backward compatibility for Airflow versions that doesn't have the |
| 154 | + # _view_url_template attribute. Should be removed when we drop support for Airflow 3.0 |
| 155 | + return self._view_url_template |
| 156 | + # https://console.cloud.google.com/storage/browser/<bucket-name>/<prefix> |
| 157 | + url = f"https://console.cloud.google.com/storage/browser/{self.bucket_name}" |
| 158 | + if self.prefix: |
| 159 | + url += f"/{self.prefix}" |
| 160 | + |
| 161 | + return url |
0 commit comments