1+ #!/usr/bin/env python3
2+
3+ """
4+ Unit tests for token federation functionality in the Databricks SQL connector.
5+ """
6+
7+ import unittest
8+ from unittest .mock import patch , MagicMock
9+ import json
10+ from datetime import datetime , timezone , timedelta
11+
12+ from databricks .sql .auth .token_federation import (
13+ Token ,
14+ DatabricksTokenFederationProvider ,
15+ SimpleCredentialsProvider ,
16+ create_token_federation_provider
17+ )
18+
19+
20+ class TestToken (unittest .TestCase ):
21+ """Tests for the Token class."""
22+
23+ def test_token_initialization (self ):
24+ """Test Token initialization."""
25+ token = Token ("access_token_value" , "Bearer" , "refresh_token_value" )
26+ self .assertEqual (token .access_token , "access_token_value" )
27+ self .assertEqual (token .token_type , "Bearer" )
28+ self .assertEqual (token .refresh_token , "refresh_token_value" )
29+
30+ def test_token_is_expired (self ):
31+ """Test Token is_expired method."""
32+ # Token with expiry in the past
33+ past = datetime .now (tz = timezone .utc ) - timedelta (hours = 1 )
34+ token = Token ("access_token" , "Bearer" , expiry = past )
35+ self .assertTrue (token .is_expired ())
36+
37+ # Token with expiry in the future
38+ future = datetime .now (tz = timezone .utc ) + timedelta (hours = 1 )
39+ token = Token ("access_token" , "Bearer" , expiry = future )
40+ self .assertFalse (token .is_expired ())
41+
42+ def test_token_needs_refresh (self ):
43+ """Test Token needs_refresh method."""
44+ # Token with expiry in the past
45+ past = datetime .now (tz = timezone .utc ) - timedelta (hours = 1 )
46+ token = Token ("access_token" , "Bearer" , expiry = past )
47+ self .assertTrue (token .needs_refresh ())
48+
49+ # Token with expiry in the near future (within refresh buffer)
50+ near_future = datetime .now (tz = timezone .utc ) + timedelta (minutes = 4 )
51+ token = Token ("access_token" , "Bearer" , expiry = near_future )
52+ self .assertTrue (token .needs_refresh ())
53+
54+ # Token with expiry far in the future
55+ far_future = datetime .now (tz = timezone .utc ) + timedelta (hours = 1 )
56+ token = Token ("access_token" , "Bearer" , expiry = far_future )
57+ self .assertFalse (token .needs_refresh ())
58+
59+
60+ class TestSimpleCredentialsProvider (unittest .TestCase ):
61+ """Tests for the SimpleCredentialsProvider class."""
62+
63+ def test_simple_credentials_provider (self ):
64+ """Test SimpleCredentialsProvider."""
65+ provider = SimpleCredentialsProvider ("token_value" , "Bearer" , "custom_auth_type" )
66+ self .assertEqual (provider .auth_type (), "custom_auth_type" )
67+
68+ header_factory = provider ()
69+ headers = header_factory ()
70+ self .assertEqual (headers , {"Authorization" : "Bearer token_value" })
71+
72+
73+ class TestTokenFederationProvider (unittest .TestCase ):
74+ """Tests for the DatabricksTokenFederationProvider class."""
75+
76+ def test_host_property (self ):
77+ """Test the host property of DatabricksTokenFederationProvider."""
78+ creds_provider = SimpleCredentialsProvider ("token" )
79+ federation_provider = DatabricksTokenFederationProvider (
80+ creds_provider , "example.com" , "client_id"
81+ )
82+ self .assertEqual (federation_provider .host , "example.com" )
83+ self .assertEqual (federation_provider .hostname , "example.com" )
84+
85+ @patch ('databricks.sql.auth.token_federation.requests.get' )
86+ @patch ('databricks.sql.auth.token_federation.get_oauth_endpoints' )
87+ def test_init_oidc_discovery (self , mock_get_endpoints , mock_requests_get ):
88+ """Test _init_oidc_discovery method."""
89+ # Mock the get_oauth_endpoints function
90+ mock_endpoints = MagicMock ()
91+ mock_endpoints .get_openid_config_url .return_value = "https://example.com/openid-config"
92+ mock_get_endpoints .return_value = mock_endpoints
93+
94+ # Mock the requests.get response
95+ mock_response = MagicMock ()
96+ mock_response .status_code = 200
97+ mock_response .json .return_value = {"token_endpoint" : "https://example.com/token" }
98+ mock_requests_get .return_value = mock_response
99+
100+ # Create the provider
101+ creds_provider = SimpleCredentialsProvider ("token" )
102+ federation_provider = DatabricksTokenFederationProvider (
103+ creds_provider , "example.com" , "client_id"
104+ )
105+
106+ # Call the method
107+ federation_provider ._init_oidc_discovery ()
108+
109+ # Check if the token endpoint was set correctly
110+ self .assertEqual (federation_provider .token_endpoint , "https://example.com/token" )
111+
112+ # Test fallback when discovery fails
113+ mock_requests_get .side_effect = Exception ("Connection error" )
114+ federation_provider .token_endpoint = None
115+ federation_provider ._init_oidc_discovery ()
116+ self .assertEqual (federation_provider .token_endpoint , "https://example.com/oidc/v1/token" )
117+
118+
119+ class TestTokenFederationFactory (unittest .TestCase ):
120+ """Tests for the token federation factory function."""
121+
122+ def test_create_token_federation_provider (self ):
123+ """Test create_token_federation_provider function."""
124+ provider = create_token_federation_provider (
125+ "token_value" , "example.com" , "client_id" , "Bearer"
126+ )
127+
128+ self .assertIsInstance (provider , DatabricksTokenFederationProvider )
129+ self .assertEqual (provider .hostname , "example.com" )
130+ self .assertEqual (provider .identity_federation_client_id , "client_id" )
131+
132+ # Test that the underlying credentials provider was set up correctly
133+ self .assertEqual (provider .credentials_provider .token , "token_value" )
134+ self .assertEqual (provider .credentials_provider .token_type , "Bearer" )
135+
136+
137+ if __name__ == "__main__" :
138+ unittest .main ()
0 commit comments