Skip to content

Commit 1e60750

Browse files
committed
adds unit test
1 parent 82d0be2 commit 1e60750

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
Unit tests for token federation functionality in the Databricks SQL connector.
5+
"""
6+
7+
import unittest
8+
from unittest.mock import patch, MagicMock
9+
import json
10+
from datetime import datetime, timezone, timedelta
11+
12+
from databricks.sql.auth.token_federation import (
13+
Token,
14+
DatabricksTokenFederationProvider,
15+
SimpleCredentialsProvider,
16+
create_token_federation_provider
17+
)
18+
19+
20+
class TestToken(unittest.TestCase):
21+
"""Tests for the Token class."""
22+
23+
def test_token_initialization(self):
24+
"""Test Token initialization."""
25+
token = Token("access_token_value", "Bearer", "refresh_token_value")
26+
self.assertEqual(token.access_token, "access_token_value")
27+
self.assertEqual(token.token_type, "Bearer")
28+
self.assertEqual(token.refresh_token, "refresh_token_value")
29+
30+
def test_token_is_expired(self):
31+
"""Test Token is_expired method."""
32+
# Token with expiry in the past
33+
past = datetime.now(tz=timezone.utc) - timedelta(hours=1)
34+
token = Token("access_token", "Bearer", expiry=past)
35+
self.assertTrue(token.is_expired())
36+
37+
# Token with expiry in the future
38+
future = datetime.now(tz=timezone.utc) + timedelta(hours=1)
39+
token = Token("access_token", "Bearer", expiry=future)
40+
self.assertFalse(token.is_expired())
41+
42+
def test_token_needs_refresh(self):
43+
"""Test Token needs_refresh method."""
44+
# Token with expiry in the past
45+
past = datetime.now(tz=timezone.utc) - timedelta(hours=1)
46+
token = Token("access_token", "Bearer", expiry=past)
47+
self.assertTrue(token.needs_refresh())
48+
49+
# Token with expiry in the near future (within refresh buffer)
50+
near_future = datetime.now(tz=timezone.utc) + timedelta(minutes=4)
51+
token = Token("access_token", "Bearer", expiry=near_future)
52+
self.assertTrue(token.needs_refresh())
53+
54+
# Token with expiry far in the future
55+
far_future = datetime.now(tz=timezone.utc) + timedelta(hours=1)
56+
token = Token("access_token", "Bearer", expiry=far_future)
57+
self.assertFalse(token.needs_refresh())
58+
59+
60+
class TestSimpleCredentialsProvider(unittest.TestCase):
61+
"""Tests for the SimpleCredentialsProvider class."""
62+
63+
def test_simple_credentials_provider(self):
64+
"""Test SimpleCredentialsProvider."""
65+
provider = SimpleCredentialsProvider("token_value", "Bearer", "custom_auth_type")
66+
self.assertEqual(provider.auth_type(), "custom_auth_type")
67+
68+
header_factory = provider()
69+
headers = header_factory()
70+
self.assertEqual(headers, {"Authorization": "Bearer token_value"})
71+
72+
73+
class TestTokenFederationProvider(unittest.TestCase):
74+
"""Tests for the DatabricksTokenFederationProvider class."""
75+
76+
def test_host_property(self):
77+
"""Test the host property of DatabricksTokenFederationProvider."""
78+
creds_provider = SimpleCredentialsProvider("token")
79+
federation_provider = DatabricksTokenFederationProvider(
80+
creds_provider, "example.com", "client_id"
81+
)
82+
self.assertEqual(federation_provider.host, "example.com")
83+
self.assertEqual(federation_provider.hostname, "example.com")
84+
85+
@patch('databricks.sql.auth.token_federation.requests.get')
86+
@patch('databricks.sql.auth.token_federation.get_oauth_endpoints')
87+
def test_init_oidc_discovery(self, mock_get_endpoints, mock_requests_get):
88+
"""Test _init_oidc_discovery method."""
89+
# Mock the get_oauth_endpoints function
90+
mock_endpoints = MagicMock()
91+
mock_endpoints.get_openid_config_url.return_value = "https://example.com/openid-config"
92+
mock_get_endpoints.return_value = mock_endpoints
93+
94+
# Mock the requests.get response
95+
mock_response = MagicMock()
96+
mock_response.status_code = 200
97+
mock_response.json.return_value = {"token_endpoint": "https://example.com/token"}
98+
mock_requests_get.return_value = mock_response
99+
100+
# Create the provider
101+
creds_provider = SimpleCredentialsProvider("token")
102+
federation_provider = DatabricksTokenFederationProvider(
103+
creds_provider, "example.com", "client_id"
104+
)
105+
106+
# Call the method
107+
federation_provider._init_oidc_discovery()
108+
109+
# Check if the token endpoint was set correctly
110+
self.assertEqual(federation_provider.token_endpoint, "https://example.com/token")
111+
112+
# Test fallback when discovery fails
113+
mock_requests_get.side_effect = Exception("Connection error")
114+
federation_provider.token_endpoint = None
115+
federation_provider._init_oidc_discovery()
116+
self.assertEqual(federation_provider.token_endpoint, "https://example.com/oidc/v1/token")
117+
118+
119+
class TestTokenFederationFactory(unittest.TestCase):
120+
"""Tests for the token federation factory function."""
121+
122+
def test_create_token_federation_provider(self):
123+
"""Test create_token_federation_provider function."""
124+
provider = create_token_federation_provider(
125+
"token_value", "example.com", "client_id", "Bearer"
126+
)
127+
128+
self.assertIsInstance(provider, DatabricksTokenFederationProvider)
129+
self.assertEqual(provider.hostname, "example.com")
130+
self.assertEqual(provider.identity_federation_client_id, "client_id")
131+
132+
# Test that the underlying credentials provider was set up correctly
133+
self.assertEqual(provider.credentials_provider.token, "token_value")
134+
self.assertEqual(provider.credentials_provider.token_type, "Bearer")
135+
136+
137+
if __name__ == "__main__":
138+
unittest.main()

0 commit comments

Comments
 (0)