Skip to content

Commit 9c7455d

Browse files
authored
Update AAD scope variable. (#42228)
1 parent 0497631 commit 9c7455d

File tree

5 files changed

+72
-2
lines changed

5 files changed

+72
-2
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class _Constants:
5353
MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000
5454
CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"
5555
CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False"
56+
AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"
5657
# Only applicable when circuit breaker is enabled -------------------------
5758
CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ"
5859
CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT: int = 10

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,11 @@ def __init__( # pylint: disable=too-many-statements
201201

202202
credentials_policy = None
203203
if self.aad_credentials:
204-
scope = base.create_scope_from_url(self.url_connection)
204+
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
205+
if scope_override:
206+
scope = scope_override
207+
else:
208+
scope = base.create_scope_from_url(self.url_connection)
205209
credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)
206210

207211
policies = [

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,11 @@ def __init__( # pylint: disable=too-many-statements
211211

212212
credentials_policy = None
213213
if self.aad_credentials:
214-
scope = base.create_scope_from_url(self.url_connection)
214+
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
215+
if scope_override:
216+
scope = scope_override
217+
else:
218+
scope = base.create_scope_from_url(self.url_connection)
215219
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)
216220

217221
policies = [

sdk/cosmos/azure-cosmos/tests/test_aad.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import base64
55
import json
6+
import os
67
import time
78
import unittest
89
from io import StringIO
@@ -117,6 +118,33 @@ def test_aad_credentials(self):
117118
assert e.status_code == 403
118119
print("403 error assertion success")
119120

121+
def test_aad_scope_override(self):
122+
override_scope = "https://my.custom.scope/.default"
123+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope
124+
125+
scopes_captured = []
126+
original_get_token = CosmosEmulatorCredential.get_token
127+
128+
def capturing_get_token(self, *scopes, **kwargs):
129+
scopes_captured.extend(scopes)
130+
return original_get_token(self, *scopes, **kwargs)
131+
132+
CosmosEmulatorCredential.get_token = capturing_get_token
133+
134+
try:
135+
credential = CosmosEmulatorCredential()
136+
client = cosmos_client.CosmosClient(self.host, credential)
137+
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
138+
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
139+
container.create_item(get_test_item(1))
140+
assert override_scope in scopes_captured
141+
finally:
142+
CosmosEmulatorCredential.get_token = original_get_token
143+
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
144+
try:
145+
container.delete_item(item='Item_1', partition_key='pk')
146+
except Exception:
147+
pass
120148

121149
if __name__ == "__main__":
122150
unittest.main()

sdk/cosmos/azure-cosmos/tests/test_aad_async.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import base64
55
import json
66
import time
7+
import os
78
import unittest
89
from io import StringIO
910

@@ -130,6 +131,38 @@ async def test_aad_credentials_async(self):
130131
assert e.status_code == 403
131132
print("403 error assertion success")
132133

134+
async def test_aad_scope_override_async(self):
135+
override_scope = "https://my.custom.scope/.default"
136+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope
137+
138+
scopes_captured = []
139+
original_get_token = CosmosEmulatorCredential.get_token
140+
141+
async def capturing_get_token(self, *scopes, **kwargs):
142+
scopes_captured.extend(scopes)
143+
# Await the original method!
144+
return await original_get_token(self, *scopes, **kwargs)
145+
146+
CosmosEmulatorCredential.get_token = capturing_get_token
147+
148+
try:
149+
credential = CosmosEmulatorCredential()
150+
client = CosmosClient(self.host, credential)
151+
database = client.get_database_client(self.configs.TEST_DATABASE_ID)
152+
container = database.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
153+
154+
await container.create_item(get_test_item(1))
155+
item = await container.read_item(item='Item_1', partition_key='pk')
156+
assert item["id"] == "Item_1"
157+
assert override_scope in scopes_captured
158+
finally:
159+
CosmosEmulatorCredential.get_token = original_get_token
160+
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
161+
try:
162+
await container.delete_item(item='Item_1', partition_key='pk')
163+
except Exception:
164+
pass
165+
await client.close()
133166

134167
if __name__ == "__main__":
135168
unittest.main()

0 commit comments

Comments
 (0)