Skip to content

Commit f9cb824

Browse files
committed
feature_flag
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent e0ca049 commit f9cb824

File tree

3 files changed

+208
-7
lines changed

3 files changed

+208
-7
lines changed

src/databricks/sql/client.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,12 +245,6 @@ def read(self) -> Optional[OAuthToken]:
245245
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
246246
self._cursors = [] # type: List[Cursor]
247247

248-
self.server_telemetry_enabled = True
249-
self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
250-
self.telemetry_enabled = (
251-
self.client_telemetry_enabled and self.server_telemetry_enabled
252-
)
253-
254248
self.session = Session(
255249
server_hostname,
256250
http_path,
@@ -268,6 +262,17 @@ def read(self) -> Optional[OAuthToken]:
268262
)
269263
self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)
270264

265+
self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
266+
if self.client_telemetry_enabled:
267+
self.server_telemetry_enabled = TelemetryHelper.is_server_telemetry_enabled(
268+
self
269+
)
270+
self.telemetry_enabled = (
271+
self.client_telemetry_enabled and self.server_telemetry_enabled
272+
)
273+
else:
274+
self.telemetry_enabled = False
275+
271276
TelemetryClientFactory.initialize_telemetry_client(
272277
telemetry_enabled=self.telemetry_enabled,
273278
session_id_hex=self.get_session_id_hex(),
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# databricks/sql/common/feature_flag.py
2+
3+
import threading
4+
import time
5+
import requests
6+
from dataclasses import dataclass, field
7+
from concurrent.futures import ThreadPoolExecutor
8+
from typing import Dict, Optional, List, Any, TYPE_CHECKING
9+
10+
if TYPE_CHECKING:
11+
from databricks.sql.client import Connection
12+
13+
14+
@dataclass
class FeatureFlagEntry:
    """One named feature flag, exactly as listed in the server response.

    Both fields are plain strings; no interpretation of the value (for
    example as a boolean) happens in this class.
    """

    # Identifier of the flag.
    name: str
    # Raw flag value as a string.
    value: str
20+
21+
22+
@dataclass
class FeatureFlagsResponse:
    """Parsed form of the JSON body returned by the feature flag endpoint."""

    # Flag entries in the order they appear in the response.
    flags: List[FeatureFlagEntry] = field(default_factory=list)
    # Value of "ttl_seconds" from the response; None when the key is absent.
    ttl_seconds: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FeatureFlagsResponse":
        """Factory method to create an instance from a dictionary (parsed JSON).

        Args:
            data: Mapping decoded from the response body. A missing or null
                "flags" key is treated as an empty flag list.

        Returns:
            A FeatureFlagsResponse with one FeatureFlagEntry per item in
            "flags" and the "ttl_seconds" value (or None).

        Raises:
            KeyError: If a flag item lacks "name" or "value".
        """
        # `or []` also covers an explicit JSON null for "flags".
        raw_flags = data.get("flags") or []
        # Read the two known fields explicitly so that any additional keys in
        # a flag item are ignored instead of making construction fail.
        entries = [
            FeatureFlagEntry(name=raw["name"], value=raw["value"])
            for raw in raw_flags
        ]
        return cls(flags=entries, ttl_seconds=data.get("ttl_seconds"))
35+
36+
37+
# --- Constants ---
# Path template of the feature flag endpoint; "{}" is a str.format placeholder.
FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT = (
    "/api/2.0/connector-service/feature-flags/PYTHON/{}"
)
# Cache lifetime used until a response supplies its own TTL: 15 minutes.
DEFAULT_TTL_SECONDS = 15 * 60
# Start the proactive refresh this many seconds before the cache expires.
REFRESH_BEFORE_EXPIRY_SECONDS = 10
43+
44+
45+
class FeatureFlagsContext:
    """
    Manages fetching and caching of server-side feature flags for a connection.

    1. The very first check for any flag is a synchronous, BLOCKING operation.
    2. Subsequent refreshes (triggered near TTL expiry) are done asynchronously
       in the background, returning stale data until the refresh completes.
    3. A failed fetch never raises to the caller: the existing cache (or an
       empty one) is kept and the refresh timestamp is advanced, so the next
       attempt is only made once the regular refresh threshold is reached again.
    """

    def __init__(self, connection: "Connection", executor: ThreadPoolExecutor):
        """
        Args:
            connection: Connection whose session supplies the host, the auth
                provider and the user-agent header for flag requests.
            executor: Executor that runs the background refreshes.
        """
        from databricks.sql import __version__

        self._connection = connection
        self._executor = executor  # Used for ASYNCHRONOUS refreshes
        self._lock = threading.RLock()

        # Cache state: `None` indicates the cache has never been loaded.
        self._flags: Optional[Dict[str, str]] = None
        self._ttl_seconds: int = DEFAULT_TTL_SECONDS
        # time.monotonic() reading of the last fetch attempt (0 = none yet).
        self._last_refresh_time: float = 0

        endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
        self._feature_flag_endpoint = (
            f"https://{self._connection.session.host}{endpoint_suffix}"
        )

    def _is_refresh_needed(self) -> bool:
        """Checks if the cache is due for a proactive background refresh."""
        if self._flags is None:
            return False  # Not eligible for refresh until loaded once.

        refresh_threshold = self._last_refresh_time + (
            self._ttl_seconds - REFRESH_BEFORE_EXPIRY_SECONDS
        )
        return time.monotonic() > refresh_threshold

    def is_feature_enabled(self, name: str, default_value: bool) -> bool:
        """
        Checks if a feature is enabled.
        - BLOCKS on the first call until flags are fetched.
        - Returns cached values on subsequent calls, triggering non-blocking refreshes if needed.

        Args:
            name: Name of the flag to look up.
            default_value: Result to return when the flag is not in the cache.

        Returns:
            True when the cached value equals "true" (case-insensitive),
            False for any other cached value, `default_value` when absent.
        """
        with self._lock:
            # If cache has never been loaded, perform a synchronous, blocking fetch.
            if self._flags is None:
                self._refresh_flags()

            # If a proactive background refresh is needed, start one. This is non-blocking.
            elif self._is_refresh_needed():
                # We don't check for an in-flight refresh; the executor queues the task, which is safe.
                self._executor.submit(self._refresh_flags)

            # Now, return the value from the populated cache.
            flag_value = self._flags.get(name)
            if flag_value is None:
                return default_value
            # str() keeps the lookup from raising if the server sent a
            # non-string value (e.g. a JSON boolean) for the flag.
            return str(flag_value).lower() == "true"

    def _refresh_flags(self):
        """
        Performs a synchronous network request to fetch and update flags.

        Best-effort: any error or non-200 response is logged at debug level
        and recorded as a failed attempt instead of being raised.
        """
        import logging  # local import: the module has no file-level logger

        log = logging.getLogger(__name__)
        headers: Dict[str, str] = {}
        try:
            # Authenticate the request
            self._connection.session.auth_provider.add_headers(headers)
            headers["User-Agent"] = self._connection.session.useragent_header

            response = requests.get(
                self._feature_flag_endpoint, headers=headers, timeout=30
            )

            if response.status_code == 200:
                ff_response = FeatureFlagsResponse.from_dict(response.json())
                self._update_cache_from_response(ff_response)
                return

            log.debug(
                "Feature flag request returned HTTP status %s", response.status_code
            )
        except Exception as e:
            log.debug("Feature flag request failed: %s", e)

        self._record_failed_refresh()

    def _record_failed_refresh(self):
        """Keeps the cache usable and postpones the next attempt after a failed fetch."""
        with self._lock:
            # Initialize with an empty dictionary to prevent re-blocking.
            if self._flags is None:
                self._flags = {}
            # Without this stamp the refresh threshold stays in the past and
            # every later lookup would submit yet another request.
            self._last_refresh_time = time.monotonic()

    def _update_cache_from_response(self, ff_response: "FeatureFlagsResponse"):
        """Atomically updates the internal cache state from a successful server response."""
        with self._lock:
            self._flags = {flag.name: flag.value for flag in ff_response.flags}
            # Keep the current TTL when the response carries none or a non-positive one.
            if ff_response.ttl_seconds is not None and ff_response.ttl_seconds > 0:
                self._ttl_seconds = ff_response.ttl_seconds
            self._last_refresh_time = time.monotonic()
135+
136+
137+
class FeatureFlagsContextFactory:
    """
    Registry that hands out one FeatureFlagsContext per connection session.

    Every context shares a single ThreadPoolExecutor for background refreshes;
    the pool is created on demand and shut down once no context is registered.
    """

    # Contexts keyed by the connection's session id (hex string).
    _context_map: Dict[str, FeatureFlagsContext] = {}
    # Shared pool for background refreshes; None while no pool exists.
    _executor: Optional[ThreadPoolExecutor] = None
    # Guards both class attributes above.
    _lock = threading.Lock()

    @classmethod
    def _initialize(cls):
        """Creates the shared refresh executor unless one already exists."""
        if cls._executor is not None:
            return
        cls._executor = ThreadPoolExecutor(
            max_workers=3, thread_name_prefix="feature-flag-refresher"
        )

    @classmethod
    def get_instance(cls, connection: "Connection") -> FeatureFlagsContext:
        """Returns the context registered for the connection, creating it on first use."""
        with cls._lock:
            cls._initialize()
            # The unique session ID identifies the connection in the registry.
            session_key = connection.get_session_id_hex()
            context = cls._context_map.get(session_key)
            if context is None:
                context = FeatureFlagsContext(connection, cls._executor)
                cls._context_map[session_key] = context
            return context

    @classmethod
    def remove_instance(cls, connection: "Connection"):
        """Drops the connection's context and shuts the executor down when the registry is empty."""
        with cls._lock:
            cls._context_map.pop(connection.get_session_id_hex(), None)

            # Other contexts still registered, or no pool to clean up.
            if cls._context_map or cls._executor is None:
                return

            # That was the last active context: release the thread pool.
            cls._executor.shutdown(wait=False)
            cls._executor = None

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import requests
44
import logging
55
from concurrent.futures import ThreadPoolExecutor
6-
from typing import Dict, Optional
6+
from typing import Dict, Optional, TYPE_CHECKING
77
from databricks.sql.telemetry.models.event import (
88
TelemetryEvent,
99
DriverSystemConfiguration,
@@ -30,6 +30,10 @@
3030
import uuid
3131
import locale
3232
from databricks.sql.telemetry.utils import BaseTelemetryClient
33+
from databricks.sql.common.feature_flag import FeatureFlagsContextFactory
34+
35+
if TYPE_CHECKING:
36+
from databricks.sql.client import Connection
3337

3438
logger = logging.getLogger(__name__)
3539

@@ -38,6 +42,9 @@ class TelemetryHelper:
3842
"""Helper class for getting telemetry related information."""
3943

4044
_DRIVER_SYSTEM_CONFIGURATION = None
45+
TELEMETRY_FEATURE_FLAG_NAME = (
46+
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetry"
47+
)
4148

4249
@classmethod
4350
def get_driver_system_configuration(cls) -> DriverSystemConfiguration:
@@ -92,6 +99,18 @@ def get_auth_flow(auth_provider):
9299
else:
93100
return None
94101

102+
@staticmethod
def is_server_telemetry_enabled(connection: "Connection") -> bool:
    """
    Reports whether the server-side telemetry feature flag is on.

    The flag is looked up through the connection's feature flag context and
    treated as disabled when it is absent. The first check per connection
    is a BLOCKING call.
    """
    flags_context = FeatureFlagsContextFactory.get_instance(connection)
    flag_name = TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME
    return flags_context.is_feature_enabled(flag_name, default_value=False)
113+
95114

96115
class NoopTelemetryClient(BaseTelemetryClient):
97116
"""

0 commit comments

Comments
 (0)