diff --git a/mypy-baseline.txt b/mypy-baseline.txt index 232ce8be..ccd2bf0b 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -23,13 +23,10 @@ posthog/client.py:0: note: Hint: "python3 -m pip install types-six" posthog/client.py:0: error: Name "queue" already defined (by an import) [no-redef] posthog/client.py:0: error: Need type annotation for "queue" [var-annotated] posthog/client.py:0: error: Incompatible types in assignment (expression has type "Any | list[Any]", variable has type "None") [assignment] -posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Any, Any]", variable has type "None") [assignment] posthog/client.py:0: error: "None" has no attribute "__iter__" (not iterable) [attr-defined] posthog/client.py:0: error: Statement is unreachable [unreachable] posthog/client.py:0: error: Right operand of "and" is never evaluated [unreachable] posthog/client.py:0: error: Incompatible types in assignment (expression has type "Poller", variable has type "None") [assignment] posthog/client.py:0: error: "None" has no attribute "start" [attr-defined] -posthog/client.py:0: error: Statement is unreachable [unreachable] -posthog/client.py:0: error: Statement is unreachable [unreachable] posthog/client.py:0: error: Name "urlparse" already defined (possibly by an import) [no-redef] posthog/client.py:0: error: Name "parse_qs" already defined (possibly by an import) [no-redef] diff --git a/posthog/client.py b/posthog/client.py index 426bb24c..51d5fbfa 100644 --- a/posthog/client.py +++ b/posthog/client.py @@ -2,6 +2,7 @@ import logging import os import sys +import threading from datetime import datetime, timedelta from typing import Any, Dict, Optional, Union from typing_extensions import Unpack @@ -192,6 +193,8 @@ def __init__( capture_exception_code_variables=False, code_variables_mask_patterns=None, code_variables_ignore_patterns=None, + realtime_flags=False, + on_feature_flags_update=None, ): """ Initialize a new PostHog client instance. @@ -226,7 +229,7 @@ def __init__( self.gzip = gzip self.timeout = timeout self._feature_flags = None # private variable to store flags - self.feature_flags_by_key = None + self.feature_flags_by_key: Optional[dict[str, dict]] = None self.group_type_mapping: Optional[dict[str, str]] = None self.cohorts: Optional[dict[str, Any]] = None self.poll_interval = poll_interval @@ -248,6 +251,15 @@ def __init__( self.exception_capture = None self.privacy_mode = privacy_mode self.enable_local_evaluation = enable_local_evaluation + self.realtime_flags = realtime_flags + self.on_feature_flags_update = on_feature_flags_update + self.sse_connection = None # type: Optional[Any] + self.sse_response = None # type: Optional[Any] + self.sse_connected = False + self._sse_lock = threading.Lock() + self._flags_lock = ( + threading.Lock() + ) # Protects feature_flags and feature_flags_by_key self.capture_exception_code_variables = capture_exception_code_variables self.code_variables_mask_patterns = ( @@ -1190,6 +1202,10 @@ def join(self): except Exception as e: self.log.error(f"[FEATURE FLAGS] Cache provider shutdown error: {e}") + # Close SSE connection + if self.sse_connection: + self._close_sse_connection() + def shutdown(self): """ Flush all messages and cleanly shutdown the client. Call this before the process ends in serverless environments to avoid data loss. @@ -1209,19 +1225,20 @@ def _update_flag_state( self, data: FlagDefinitionCacheData, old_flags_by_key: Optional[dict] = None ) -> None: """Update internal flag state from cache data and invalidate evaluation cache if changed.""" - self.feature_flags = data["flags"] - self.group_type_mapping = data["group_type_mapping"] - self.cohorts = data["cohorts"] - - # Invalidate evaluation cache if flag definitions changed - if ( - self.flag_cache - and old_flags_by_key is not None - and old_flags_by_key != (self.feature_flags_by_key or {}) - ): - old_version = self.flag_definition_version - self.flag_definition_version += 1 - self.flag_cache.invalidate_version(old_version) + with self._flags_lock: + self.feature_flags = data["flags"] + self.group_type_mapping = data["group_type_mapping"] + self.cohorts = data["cohorts"] + + # Invalidate evaluation cache if flag definitions changed + if ( + self.flag_cache + and old_flags_by_key is not None + and old_flags_by_key != (self.feature_flags_by_key or {}) + ): + old_version = self.flag_definition_version + self.flag_definition_version += 1 + self.flag_cache.invalidate_version(old_version) def _load_feature_flags(self): should_fetch = True @@ -1247,8 +1264,12 @@ def _load_feature_flags(self): self.log.debug( "[FEATURE FLAGS] Using cached flag definitions from external cache" ) + with self._flags_lock: + old_flags_copy: dict[str, dict] = ( + self.feature_flags_by_key or {} + ) self._update_flag_state( - cached_data, old_flags_by_key=self.feature_flags_by_key or {} + cached_data, old_flags_by_key=old_flags_copy ) self._last_feature_flag_poll = datetime.now(tz=tzutc()) return @@ -1272,7 +1293,8 @@ def _fetch_feature_flags_from_api(self): """Fetch feature flags from the PostHog API.""" try: # Store old flags to detect changes - old_flags_by_key: dict[str, dict] = self.feature_flags_by_key or {} + with self._flags_lock: + old_flags_by_key: dict[str, dict] = self.feature_flags_by_key or {} response = get( self.personal_api_key, @@ -1315,6 +1337,10 @@ def _fetch_feature_flags_from_api(self): self.log.error(f"[FEATURE FLAGS] Cache provider store error: {e}") # Flags are already in memory, so continue normally + # Setup SSE connection if realtime_flags is enabled + if self.realtime_flags and not self.sse_connected: + self._setup_sse_connection() + except APIError as e: if e.status == 401: self.log.error( @@ -1727,30 +1753,33 @@ def _locally_evaluate_flag( self.load_feature_flags() response = None - if self.feature_flags: - assert self.feature_flags_by_key is not None, ( - "feature_flags_by_key should be initialized when feature_flags is set" - ) - # Local evaluation - flag = self.feature_flags_by_key.get(key) - if flag: - try: - response = self._compute_flag_locally( - flag, - distinct_id, - groups=groups, - person_properties=person_properties, - group_properties=group_properties, - ) - self.log.debug( - f"Successfully computed flag locally: {key} -> {response}" - ) - except (RequiresServerEvaluation, InconclusiveMatchError) as e: - self.log.debug(f"Failed to compute flag {key} locally: {e}") - except Exception as e: - self.log.exception( - f"[FEATURE FLAGS] Error while computing variant locally: {e}" - ) + flag = None + with self._flags_lock: + if self.feature_flags: + assert self.feature_flags_by_key is not None, ( + "feature_flags_by_key should be initialized when feature_flags is set" + ) + # Local evaluation - copy flag to avoid holding lock during computation + flag = self.feature_flags_by_key.get(key) + + if flag: + try: + response = self._compute_flag_locally( + flag, + distinct_id, + groups=groups, + person_properties=person_properties, + group_properties=group_properties, + ) + self.log.debug( + f"Successfully computed flag locally: {key} -> {response}" + ) + except (RequiresServerEvaluation, InconclusiveMatchError) as e: + self.log.debug(f"Failed to compute flag {key} locally: {e}") + except Exception as e: + self.log.exception( + f"[FEATURE FLAGS] Error while computing variant locally: {e}" + ) return response def get_feature_flag_payload( @@ -1916,13 +1945,14 @@ def get_remote_config_payload(self, key: str): def _compute_payload_locally( self, key: str, match_value: FlagValue ) -> Optional[str]: - payload = None + with self._flags_lock: + if self.feature_flags_by_key is None: + return None - if self.feature_flags_by_key is None: - return payload + flag_definition = self.feature_flags_by_key.get(key) + if not flag_definition: + return None - flag_definition = self.feature_flags_by_key.get(key) - if flag_definition: flag_filters = flag_definition.get("filters") or {} flag_payloads = flag_filters.get("payloads") or {} # For boolean flags, convert True to "true" @@ -1932,8 +1962,7 @@ def _compute_payload_locally( if isinstance(match_value, bool) and match_value else str(match_value) ) - payload = flag_payloads.get(lookup_value, None) - return payload + return flag_payloads.get(lookup_value, None) def get_all_flags( self, @@ -2220,6 +2249,212 @@ def _add_local_person_and_group_properties( return all_person_properties, all_group_properties + def _setup_sse_connection(self): + """ + Establish a real-time connection using Server-Sent Events to receive feature flag updates. + """ + if not self.personal_api_key: + self.log.warning( + "[FEATURE FLAGS] Cannot establish real-time connection without personal_api_key" + ) + return + + with self._sse_lock: + if self.sse_connected: + self.log.debug("[FEATURE FLAGS] SSE connection already established") + return + self.sse_connected = True + + try: + # Use requests with stream=True for SSE + import requests # type: ignore[import-untyped] + + url = f"{self.host}/flags/definitions/stream" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Accept": "text/event-stream", + } + + def sse_listener(): + """Background thread to listen for SSE messages""" + import json + + try: + response = requests.get( + url, headers=headers, stream=True, timeout=None + ) + self.sse_response = response + + if response.status_code != 200: + self.log.warning( + f"[FEATURE FLAGS] SSE connection failed with status {response.status_code}" + ) + with self._sse_lock: + self.sse_connected = False + return + + self.log.debug("[FEATURE FLAGS] SSE connection established") + + # Process the stream line by line, checking sse_connected periodically + for line in response.iter_lines(): + with self._sse_lock: + if not self.sse_connected: + break + + if not line: + continue + + line = line.decode("utf-8") + + # SSE format: "data: {...}" + if line.startswith("data: "): + data_str = line[6:] # Remove "data: " prefix + try: + flag_data = json.loads(data_str) + self._process_flag_update(flag_data) + except json.JSONDecodeError as e: + self.log.warning( + f"[FEATURE FLAGS] Failed to parse SSE message: {e}" + ) + except Exception as e: + self.log.warning( + f"[FEATURE FLAGS] SSE connection error: {e}. Reconnecting in 5 seconds..." + ) + with self._sse_lock: + self.sse_connected = False + + # Attempt to reconnect after 5 seconds if realtime_flags is still enabled + if self.realtime_flags: + import time + + time.sleep(5) + self._setup_sse_connection() + finally: + if self.sse_response: + try: + self.sse_response.close() + except Exception: + pass + self.sse_response = None + + # Start the SSE listener in a daemon thread + sse_thread = threading.Thread(target=sse_listener, daemon=True) + sse_thread.start() + self.sse_connection = sse_thread + + except ImportError: + self.log.warning( + "[FEATURE FLAGS] requests library required for real-time flags" + ) + except Exception as e: + self.log.exception(f"[FEATURE FLAGS] Failed to setup SSE connection: {e}") + + def _close_sse_connection(self): + """ + Close the active SSE connection and prevent reconnection. + """ + if self.sse_connection: + self.log.debug("[FEATURE FLAGS] Closing SSE connection") + # Disable realtime flags to prevent reconnection + self.realtime_flags = False + + with self._sse_lock: + self.sse_connected = False + + # Close the response to interrupt iter_lines() + if self.sse_response: + try: + self.sse_response.close() + except Exception: + pass + self.sse_response = None + + self.sse_connection = None + + def _process_flag_update(self, flag_data): + """ + Process incoming flag updates from SSE messages. + + Args: + flag_data: The flag data from the SSE message + """ + try: + flag_key = flag_data.get("key") + if not flag_key: + self.log.warning("[FEATURE FLAGS] Received flag update without key") + return + + is_deleted = flag_data.get("deleted", False) + + with self._flags_lock: + # Handle flag deletion + if is_deleted: + self.log.debug(f"[FEATURE FLAGS] Deleting flag: {flag_key}") + if ( + self.feature_flags_by_key + and flag_key in self.feature_flags_by_key + ): + del self.feature_flags_by_key[flag_key] + + # Also remove from the array + if self.feature_flags: + self.feature_flags = [ + f for f in self.feature_flags if f.get("key") != flag_key + ] + + # Invalidate cache for this flag + if self.flag_cache: + old_version = self.flag_definition_version + self.flag_definition_version += 1 + self.flag_cache.invalidate_version(old_version) + + else: + # Update or add flag + self.log.debug(f"[FEATURE FLAGS] Updating flag: {flag_key}") + + if self.feature_flags_by_key is None: + self.feature_flags_by_key = {} + + if self.feature_flags is None: + self.feature_flags = [] + + # Update the lookup table + # mypy doesn't track that the setter ensures feature_flags_by_key is a dict + assert self.feature_flags_by_key is not None + self.feature_flags_by_key[flag_key] = flag_data + + # Update or add to the array + flag_exists = False + for i, f in enumerate(self.feature_flags): + if f.get("key") == flag_key: + self.feature_flags[i] = flag_data + flag_exists = True + break + + if not flag_exists: + self.feature_flags.append(flag_data) + + # Invalidate cache when flag definitions change + if self.flag_cache: + old_version = self.flag_definition_version + self.flag_definition_version += 1 + self.flag_cache.invalidate_version(old_version) + + # Call the user's callback if provided + if self.on_feature_flags_update: + try: + self.on_feature_flags_update( + flag_key=flag_key, + flag_data=flag_data, + ) + except Exception as e: + self.log.exception( + f"[FEATURE FLAGS] Error in on_feature_flags_update callback: {e}" + ) + + except Exception as e: + self.log.exception(f"[FEATURE FLAGS] Error processing flag update: {e}") + def stringify_id(val): if val is None: diff --git a/posthog/test/test_realtime_feature_flags.py b/posthog/test/test_realtime_feature_flags.py new file mode 100644 index 00000000..ad4bea1a --- /dev/null +++ b/posthog/test/test_realtime_feature_flags.py @@ -0,0 +1,415 @@ +import time +import unittest +from unittest import mock + +from posthog.client import Client +from posthog.request import GetResponse +from posthog.test.test_utils import FAKE_TEST_API_KEY + + +class TestRealtimeFeatureFlags(unittest.TestCase): + @classmethod + def setUpClass(cls): + # This ensures no real HTTP POST requests are made + cls.capture_patch = mock.patch.object(Client, "capture") + cls.capture_patch.start() + + @classmethod + def tearDownClass(cls): + cls.capture_patch.stop() + + def setUp(self): + self.failed = False + + def set_fail(self, e, batch): + """Mark the failure handler""" + print("FAIL", e, batch) + self.failed = True + + @mock.patch("posthog.client.get") + @mock.patch("requests.get") + def test_sse_connection_setup(self, mock_requests_get, mock_get): + """Test that SSE connection is established when realtime_flags is enabled""" + # Setup mock for initial flag loading + mock_get.return_value = GetResponse( + data={ + "flags": [ + { + "id": 1, + "name": "Test Flag", + "key": "test-flag", + "active": True, + "filters": {"groups": [{"rollout_percentage": 100}]}, + } + ], + "group_type_mapping": {}, + "cohorts": {}, + } + ) + + # Setup mock for SSE connection + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.iter_lines = mock.Mock(return_value=iter([])) + mock_response.__enter__ = mock.Mock(return_value=mock_response) + mock_response.__exit__ = mock.Mock(return_value=False) + mock_requests_get.return_value = mock_response + + # Create client with realtime_flags enabled + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test_personal_key", + on_error=self.set_fail, + realtime_flags=True, + ) + + # Load feature flags (which should trigger SSE connection) + client.load_feature_flags() + + # Give the SSE thread a moment to start + time.sleep(0.1) + + # Verify SSE connection was attempted + mock_requests_get.assert_called_once() + call_args = mock_requests_get.call_args + + # Check URL contains the stream endpoint + self.assertIn("stream", call_args[0][0]) + + # Check headers include authorization + headers = call_args[1]["headers"] + self.assertIn("Authorization", headers) + self.assertEqual(headers["Accept"], "text/event-stream") + + # Cleanup + client.shutdown() + + @mock.patch("posthog.client.get") + def test_process_flag_update(self, mock_get): + """Test that flag updates are processed correctly""" + # Setup mock for initial flag loading + mock_get.return_value = GetResponse( + data={ + "flags": [ + { + "id": 1, + "name": "Test Flag", + "key": "test-flag", + "active": True, + "filters": {"groups": [{"rollout_percentage": 100}]}, + } + ], + "group_type_mapping": {}, + "cohorts": {}, + } + ) + + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test_personal_key", + on_error=self.set_fail, + realtime_flags=True, + ) + + # Load initial flags + client.load_feature_flags() + + # Verify initial flag exists + self.assertIn("test-flag", client.feature_flags_by_key) + self.assertEqual(len(client.feature_flags), 1) + + # Simulate a flag update + updated_flag = { + "id": 1, + "name": "Updated Test Flag", + "key": "test-flag", + "active": False, + "filters": {"groups": [{"rollout_percentage": 50}]}, + } + client._process_flag_update(updated_flag) + + # Verify flag was updated + self.assertIn("test-flag", client.feature_flags_by_key) + self.assertEqual( + client.feature_flags_by_key["test-flag"]["name"], "Updated Test Flag" + ) + self.assertFalse(client.feature_flags_by_key["test-flag"]["active"]) + + # Cleanup + client.shutdown() + + @mock.patch("posthog.client.get") + def test_process_flag_deletion(self, mock_get): + """Test that flag deletions are processed correctly""" + # Setup mock for initial flag loading + mock_get.return_value = GetResponse( + data={ + "flags": [ + { + "id": 1, + "name": "Test Flag", + "key": "test-flag", + "active": True, + "filters": {"groups": [{"rollout_percentage": 100}]}, + } + ], + "group_type_mapping": {}, + "cohorts": {}, + } + ) + + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test_personal_key", + on_error=self.set_fail, + realtime_flags=True, + ) + + # Load initial flags + client.load_feature_flags() + + # Verify initial flag exists + self.assertIn("test-flag", client.feature_flags_by_key) + self.assertEqual(len(client.feature_flags), 1) + + # Simulate a flag deletion + deleted_flag = { + "key": "test-flag", + "deleted": True, + } + client._process_flag_update(deleted_flag) + + # Verify flag was deleted + self.assertNotIn("test-flag", client.feature_flags_by_key) + self.assertEqual(len(client.feature_flags), 0) + + # Cleanup + client.shutdown() + + @mock.patch("posthog.client.get") + def test_process_new_flag_addition(self, mock_get): + """Test that new flags are added correctly""" + # Setup mock for initial flag loading + mock_get.return_value = GetResponse( + data={ + "flags": [], + "group_type_mapping": {}, + "cohorts": {}, + } + ) + + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test_personal_key", + on_error=self.set_fail, + realtime_flags=True, + ) + + # Load initial flags (empty) + client.load_feature_flags() + + # Verify no flags initially + self.assertEqual(len(client.feature_flags), 0) + + # Simulate a new flag addition + new_flag = { + "id": 1, + "name": "New Test Flag", + "key": "new-flag", + "active": True, + "filters": {"groups": [{"rollout_percentage": 100}]}, + } + client._process_flag_update(new_flag) + + # Verify flag was added + self.assertIn("new-flag", client.feature_flags_by_key) + self.assertEqual(len(client.feature_flags), 1) + self.assertEqual(client.feature_flags[0]["name"], "New Test Flag") + + # Cleanup + client.shutdown() + + @mock.patch("posthog.client.get") + def test_sse_disabled_by_default(self, mock_get): + """Test that SSE connection is NOT established when realtime_flags is False""" + # Setup mock for initial flag loading + mock_get.return_value = GetResponse( + data={ + "flags": [], + "group_type_mapping": {}, + "cohorts": {}, + } + ) + + # Create client with realtime_flags disabled (default) + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test_personal_key", + on_error=self.set_fail, + realtime_flags=False, + ) + + # Load feature flags + client.load_feature_flags() + + # Verify SSE connection was NOT established + self.assertFalse(client.sse_connected) + self.assertIsNone(client.sse_connection) + + # Cleanup + client.shutdown() + + @mock.patch("posthog.client.get") + def test_sse_cleanup_on_shutdown(self, mock_get): + """Test that SSE connection is properly cleaned up on shutdown""" + # Setup mock for initial flag loading + mock_get.return_value = GetResponse( + data={ + "flags": [], + "group_type_mapping": {}, + "cohorts": {}, + } + ) + + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test_personal_key", + on_error=self.set_fail, + realtime_flags=True, + ) + + # Manually set up a mock SSE connection + client.sse_connection = mock.Mock() + client.sse_connected = True + + # Shutdown the client + client.shutdown() + + # Verify SSE connection was cleaned up + self.assertFalse(client.sse_connected) + self.assertIsNone(client.sse_connection) + + @mock.patch("posthog.client.get") + def test_on_feature_flags_update_callback(self, mock_get): + """Test that the callback is called when flags are updated""" + # Setup mock for initial flag loading + mock_get.return_value = GetResponse( + data={ + "flags": [], + "group_type_mapping": {}, + "cohorts": {}, + } + ) + + # Track callback invocations + callback_calls = [] + + def flag_update_callback(flag_key, flag_data): + callback_calls.append({"flag_key": flag_key, "flag_data": flag_data}) + + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test_personal_key", + on_error=self.set_fail, + realtime_flags=True, + on_feature_flags_update=flag_update_callback, + ) + + # Load initial flags + client.load_feature_flags() + + # Simulate a new flag addition + new_flag = { + "id": 1, + "name": "New Test Flag", + "key": "new-flag", + "active": True, + "filters": {"groups": [{"rollout_percentage": 100}]}, + } + client._process_flag_update(new_flag) + + # Verify callback was called for addition + self.assertEqual(len(callback_calls), 1) + self.assertEqual(callback_calls[0]["flag_key"], "new-flag") + self.assertEqual(callback_calls[0]["flag_data"]["name"], "New Test Flag") + self.assertIsNotNone(callback_calls[0]["flag_data"]) + + # Simulate a flag update + updated_flag = { + "id": 1, + "name": "Updated Test Flag", + "key": "new-flag", + "active": False, + "filters": {"groups": [{"rollout_percentage": 50}]}, + } + client._process_flag_update(updated_flag) + + # Verify callback was called for update + self.assertEqual(len(callback_calls), 2) + self.assertEqual(callback_calls[1]["flag_key"], "new-flag") + self.assertEqual(callback_calls[1]["flag_data"]["name"], "Updated Test Flag") + self.assertIsNotNone(callback_calls[1]["flag_data"]) + + # Simulate a flag deletion + deleted_flag = { + "key": "new-flag", + "deleted": True, + } + client._process_flag_update(deleted_flag) + + # Verify callback was called for deletion (flag_data contains deleted=True) + self.assertEqual(len(callback_calls), 3) + self.assertEqual(callback_calls[2]["flag_key"], "new-flag") + self.assertTrue(callback_calls[2]["flag_data"]["deleted"]) + + # Cleanup + client.shutdown() + + @mock.patch("posthog.client.get") + def test_callback_exception_doesnt_break_flag_processing(self, mock_get): + """Test that exceptions in the callback don't break flag processing""" + # Setup mock for initial flag loading + mock_get.return_value = GetResponse( + data={ + "flags": [], + "group_type_mapping": {}, + "cohorts": {}, + } + ) + + def bad_callback(flag_key, flag_data): + raise Exception("Callback error!") + + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test_personal_key", + on_error=self.set_fail, + realtime_flags=True, + on_feature_flags_update=bad_callback, + ) + + # Load initial flags + client.load_feature_flags() + + # Simulate a new flag addition + new_flag = { + "id": 1, + "name": "New Test Flag", + "key": "new-flag", + "active": True, + "filters": {"groups": [{"rollout_percentage": 100}]}, + } + + # This should not raise an exception even though the callback does + client._process_flag_update(new_flag) + + # Verify flag was still added despite callback error + self.assertIn("new-flag", client.feature_flags_by_key) + self.assertEqual(len(client.feature_flags), 1) + + # Cleanup + client.shutdown() + + +if __name__ == "__main__": + unittest.main() diff --git a/realtime_flags_example.py b/realtime_flags_example.py new file mode 100644 index 00000000..32857759 --- /dev/null +++ b/realtime_flags_example.py @@ -0,0 +1,59 @@ +""" +Example demonstrating real-time feature flags with callbacks. + +This example shows how to: +1. Enable real-time feature flags +2. Listen to flag updates with a callback +3. React to flag changes in your application +""" + +from posthog import Posthog +import time + + +def on_flag_update(flag_key, flag_data): + """ + Callback function that gets called whenever a feature flag is updated. + + Args: + flag_key: The key of the flag that was updated + flag_data: The full flag data (includes 'deleted' field if deleted) + """ + if flag_data.get("deleted"): + print(f"šŸ—‘ļø Flag '{flag_key}' was deleted") + else: + is_active = flag_data.get("active", False) + status = "āœ… active" if is_active else "āŒ inactive" + print(f"šŸ”„ Flag '{flag_key}' was updated - {status}") + print(f" Name: {flag_data.get('name')}") + print(f" ID: {flag_data.get('id')}") + + +# Initialize PostHog with real-time flags enabled +posthog = Posthog( + project_api_key="", + personal_api_key="", # Required for real-time flags + host="https://us.i.posthog.com", # Or your self-hosted instance + realtime_flags=True, + on_feature_flags_update=on_flag_update, + debug=True, # Enable debug logging to see connection status +) + +# Load feature flags (this will also establish the SSE connection) +posthog.load_feature_flags() + +print("šŸš€ Real-time feature flags enabled!") +print("šŸ“” Listening for flag updates...") +print("šŸ’” Try updating a flag in the PostHog UI and watch it update here in real-time!") +print("\nPress Ctrl+C to stop\n") + +try: + # Keep the script running to receive updates + while True: + # Your application logic here + # You can check flags as normal, they will be updated automatically + time.sleep(1) +except KeyboardInterrupt: + print("\n\nšŸ‘‹ Shutting down...") + posthog.shutdown() + print("āœ… Cleanup complete")