|
10 | 10 | TelemetryClientFactory, |
11 | 11 | TelemetryHelper, |
12 | 12 | ) |
| 13 | +from databricks.sql.common.feature_flag import ( |
| 14 | + FeatureFlagsContextFactory, |
| 15 | + FeatureFlagsContext, |
| 16 | +) |
13 | 17 | from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow, DatabricksClientType |
14 | 18 | from databricks.sql.telemetry.models.event import ( |
15 | 19 | TelemetryEvent, |
@@ -805,7 +809,67 @@ def test_cf_proxy_fields_default_to_false_none(self, mock_setup_pools, mock_sess |
805 | 809 |
|
806 | 810 | mock_export.assert_called_once() |
807 | 811 | driver_params = mock_export.call_args.kwargs.get("driver_connection_params") |
808 | | - |
| 812 | + |
809 | 813 | # CF proxy not yet supported - should be False/None |
810 | 814 | assert driver_params.use_cf_proxy is False |
811 | 815 | assert driver_params.cf_proxy_host_info is None |
| 816 | + |
| 817 | + |
| 818 | +class TestFeatureFlagsContextFactory: |
| 819 | + """Tests for FeatureFlagsContextFactory host-level caching.""" |
| 820 | + |
| 821 | + @pytest.fixture(autouse=True) |
| 822 | + def reset_factory(self): |
| 823 | + """Reset factory state before/after each test.""" |
| 824 | + FeatureFlagsContextFactory._context_map.clear() |
| 825 | + if FeatureFlagsContextFactory._executor: |
| 826 | + FeatureFlagsContextFactory._executor.shutdown(wait=False) |
| 827 | + FeatureFlagsContextFactory._executor = None |
| 828 | + yield |
| 829 | + FeatureFlagsContextFactory._context_map.clear() |
| 830 | + if FeatureFlagsContextFactory._executor: |
| 831 | + FeatureFlagsContextFactory._executor.shutdown(wait=False) |
| 832 | + FeatureFlagsContextFactory._executor = None |
| 833 | + |
| 834 | + @pytest.mark.parametrize( |
| 835 | + "hosts,expected_contexts", |
| 836 | + [ |
| 837 | + (["host1.com", "host1.com"], 1), # Same host shares context |
| 838 | + (["host1.com", "host2.com"], 2), # Different hosts get separate contexts |
| 839 | + (["host1.com", "host1.com", "host2.com"], 2), # Mixed scenario |
| 840 | + ], |
| 841 | + ) |
| 842 | + def test_host_level_caching(self, hosts, expected_contexts): |
| 843 | + """Test that contexts are cached by host correctly.""" |
| 844 | + contexts = [] |
| 845 | + for host in hosts: |
| 846 | + conn = MagicMock() |
| 847 | + conn.session.host = host |
| 848 | + conn.session.http_client = MagicMock() |
| 849 | + contexts.append(FeatureFlagsContextFactory.get_instance(conn)) |
| 850 | + |
| 851 | + assert len(FeatureFlagsContextFactory._context_map) == expected_contexts |
| 852 | + if expected_contexts == 1: |
| 853 | + assert all(ctx is contexts[0] for ctx in contexts) |
| 854 | + |
| 855 | + def test_remove_instance_and_executor_cleanup(self): |
| 856 | + """Test removal uses host key and cleans up executor when empty.""" |
| 857 | + conn1 = MagicMock() |
| 858 | + conn1.session.host = "host1.com" |
| 859 | + conn1.session.http_client = MagicMock() |
| 860 | + |
| 861 | + conn2 = MagicMock() |
| 862 | + conn2.session.host = "host2.com" |
| 863 | + conn2.session.http_client = MagicMock() |
| 864 | + |
| 865 | + FeatureFlagsContextFactory.get_instance(conn1) |
| 866 | + FeatureFlagsContextFactory.get_instance(conn2) |
| 867 | + assert FeatureFlagsContextFactory._executor is not None |
| 868 | + |
| 869 | + FeatureFlagsContextFactory.remove_instance(conn1) |
| 870 | + assert len(FeatureFlagsContextFactory._context_map) == 1 |
| 871 | + assert FeatureFlagsContextFactory._executor is not None |
| 872 | + |
| 873 | + FeatureFlagsContextFactory.remove_instance(conn2) |
| 874 | + assert len(FeatureFlagsContextFactory._context_map) == 0 |
| 875 | + assert FeatureFlagsContextFactory._executor is None |
0 commit comments