diff --git a/test/asynchronous/test_auth_oidc.py b/test/asynchronous/test_auth_oidc.py index 639c155e73..ff604f55ae 100644 --- a/test/asynchronous/test_auth_oidc.py +++ b/test/asynchronous/test_auth_oidc.py @@ -30,7 +30,7 @@ sys.path[0:0] = [""] -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import EventListener, OvertCommandListener from bson import SON @@ -54,14 +54,13 @@ _IS_SYNC = False ROOT = Path(__file__).parent.parent.resolve() -TEST_PATH = ROOT / "auth" / "unified" ENVIRON = os.environ.get("OIDC_ENV", "test") DOMAIN = os.environ.get("OIDC_DOMAIN", "") TOKEN_DIR = os.environ.get("OIDC_TOKEN_DIR", "") TOKEN_FILE = os.environ.get("OIDC_TOKEN_FILE", "") # Generate unified tests. -globals().update(generate_test_classes(str(TEST_PATH), module=__name__)) +globals().update(generate_test_classes(get_test_path("auth", "unified"), module=__name__)) pytestmark = pytest.mark.auth_oidc diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index 7c659c6d93..a40687348c 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -27,7 +27,7 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from pymongo import AsyncMongoClient from pymongo.auth_oidc_shared import OIDCCallback @@ -35,8 +35,7 @@ pytestmark = pytest.mark.auth _IS_SYNC = False - -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") +_TEST_PATH = get_test_path("auth") class TestAuthSpec(AsyncPyMongoTestCase): diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 3fb8b517f3..3c687dcb90 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -35,7 +35,7 @@ async_client_context, unittest, ) -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( AllowListEventListener, EventListener, @@ -1143,12 +1143,9 @@ def asyncTearDown(self): self.listener.reset() -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "change_streams") - - globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "unified"), + get_test_path("change_streams", "unified"), module=__name__, ) ) diff --git a/test/asynchronous/test_client_metadata.py b/test/asynchronous/test_client_metadata.py index 2f175cceed..45c1bd1b3b 100644 --- a/test/asynchronous/test_client_metadata.py +++ b/test/asynchronous/test_client_metadata.py @@ -19,7 +19,7 @@ import time import unittest from test.asynchronous import AsyncIntegrationTest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import CMAPListener from typing import Any, Optional @@ -40,16 +40,8 @@ _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "handshake", "unified") -else: - _TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "handshake", "unified" - ) - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("handshake", "unified"), module=__name__)) def _get_handshake_driver_info(request): diff --git a/test/asynchronous/test_collection_management.py b/test/asynchronous/test_collection_management.py index c0edf91581..7a142dc65b 100644 --- a/test/asynchronous/test_collection_management.py +++ b/test/asynchronous/test_collection_management.py @@ -22,20 +22,12 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "collection_management") -else: - _TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "collection_management" - ) - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("collection_management"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_command_logging.py b/test/asynchronous/test_command_logging.py index f9b459c152..831dd0c109 100644 --- a/test/asynchronous/test_command_logging.py +++ b/test/asynchronous/test_command_logging.py @@ -22,20 +22,13 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("command_logging"), module=__name__, ) ) diff --git a/test/asynchronous/test_command_monitoring.py b/test/asynchronous/test_command_monitoring.py index 311fd1fdc1..a04ba449ae 100644 --- a/test/asynchronous/test_command_monitoring.py +++ b/test/asynchronous/test_command_monitoring.py @@ -22,20 +22,13 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("command_monitoring"), module=__name__, ) ) diff --git a/test/asynchronous/test_connection_logging.py b/test/asynchronous/test_connection_logging.py index 945c6c59b5..4d03391dd2 100644 --- a/test/asynchronous/test_connection_logging.py +++ b/test/asynchronous/test_connection_logging.py @@ -22,20 +22,13 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("connection_logging"), module=__name__, ) ) diff --git a/test/asynchronous/test_crud_unified.py b/test/asynchronous/test_crud_unified.py index 8b1f9b8e38..94e47a26e7 100644 --- a/test/asynchronous/test_crud_unified.py +++ b/test/asynchronous/test_crud_unified.py @@ -22,18 +22,12 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("crud", "unified"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_csot.py b/test/asynchronous/test_csot.py index a978d1ccc0..547ee20a54 100644 --- a/test/asynchronous/test_csot.py +++ b/test/asynchronous/test_csot.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.asynchronous.utils import flaky import pymongo @@ -31,14 +31,8 @@ _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "csot") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "csot") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("csot"), module=__name__)) class TestCSOT(AsyncIntegrationTest): diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 5820d00c48..0bbf471d87 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -40,7 +40,7 @@ unittest, ) from test.asynchronous.pymongo_mocks import DummyMonitor -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.asynchronous.utils import ( async_get_pool, ) @@ -76,14 +76,7 @@ _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") -else: - SDAM_PATH = os.path.join( - Path(__file__).resolve().parent.parent, - "discovery_and_monitoring", - ) +SDAM_PATH = get_test_path("discovery_and_monitoring") async def create_mock_topology(uri, monitor_class=DummyMonitor): diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 74c0136ad0..5a404196e9 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -54,7 +54,7 @@ unittest, ) from test.asynchronous.test_bulk import AsyncBulkTestBase -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.asynchronous.utils_spec_runner import AsyncSpecRunner from test.helpers_shared import ( ALL_KMS_PROVIDERS, @@ -275,11 +275,7 @@ def unmanaged_create_client_encryption( # Location of JSON test files. -if _IS_SYNC: - BASE = os.path.join(pathlib.Path(__file__).resolve().parent, "client-side-encryption") -else: - BASE = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-side-encryption") - +BASE = get_test_path("client-side-encryption") SPEC_PATH = os.path.join(BASE, "spec") OPTS = CodecOptions() diff --git a/test/asynchronous/test_gridfs_spec.py b/test/asynchronous/test_gridfs_spec.py index f3dc14fbdc..ab1c8a0ebb 100644 --- a/test/asynchronous/test_gridfs_spec.py +++ b/test/asynchronous/test_gridfs_spec.py @@ -22,18 +22,12 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("gridfs"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_index_management.py b/test/asynchronous/test_index_management.py index 890788fc56..ac096ec099 100644 --- a/test/asynchronous/test_index_management.py +++ b/test/asynchronous/test_index_management.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure @@ -40,12 +40,6 @@ pytestmark = pytest.mark.search_index -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management") - _NAME = "test-search-index" @@ -370,7 +364,7 @@ async def test_case_7(self): globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("index_management"), module=__name__, ) ) diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py index 17d85841f9..8e1ee3e797 100644 --- a/test/asynchronous/test_load_balancer.py +++ b/test/asynchronous/test_load_balancer.py @@ -30,7 +30,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( async_wait_until, create_async_event, @@ -40,14 +40,8 @@ pytestmark = pytest.mark.load_balancer -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("load_balancer"), module=__name__)) class TestLB(AsyncIntegrationTest): diff --git a/test/asynchronous/test_read_write_concern_spec.py b/test/asynchronous/test_read_write_concern_spec.py index b5cb32932f..2d08de7804 100644 --- a/test/asynchronous/test_read_write_concern_spec.py +++ b/test/asynchronous/test_read_write_concern_spec.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import OvertCommandListener from pymongo import DESCENDING @@ -42,11 +42,7 @@ _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") +TEST_PATH = get_test_path("read_write_concern") class TestReadWriteConcernSpec(AsyncIntegrationTest): diff --git a/test/asynchronous/test_retryable_reads_unified.py b/test/asynchronous/test_retryable_reads_unified.py index e62d606810..3de8aa96a3 100644 --- a/test/asynchronous/test_retryable_reads_unified.py +++ b/test/asynchronous/test_retryable_reads_unified.py @@ -22,21 +22,15 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified") - # Generate unified tests. # PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects. globals().update( generate_test_classes( - TEST_PATH, + get_test_path("retryable_reads", "unified"), module=__name__, expected_failures=["ListDatabaseObjects .*", "ListCollectionObjects .*", "MapReduce .*"], ) diff --git a/test/asynchronous/test_retryable_writes_unified.py b/test/asynchronous/test_retryable_writes_unified.py index bb493e6010..7d33c5252a 100644 --- a/test/asynchronous/test_retryable_writes_unified.py +++ b/test/asynchronous/test_retryable_writes_unified.py @@ -22,18 +22,14 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update( + generate_test_classes(get_test_path("retryable_writes", "unified"), module=__name__) +) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_run_command.py b/test/asynchronous/test_run_command.py index 3ac8c32706..cfd1adfab7 100644 --- a/test/asynchronous/test_run_command.py +++ b/test/asynchronous/test_run_command.py @@ -18,20 +18,13 @@ import os import unittest from pathlib import Path -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command") - - globals().update( generate_test_classes( - os.path.join(TEST_PATH, "unified"), + get_test_path("run_command", "unified"), module=__name__, ) ) diff --git a/test/asynchronous/test_server_selection_logging.py b/test/asynchronous/test_server_selection_logging.py index 6b0975318a..6f3ea207f4 100644 --- a/test/asynchronous/test_server_selection_logging.py +++ b/test/asynchronous/test_server_selection_logging.py @@ -22,20 +22,14 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging") - globals().update( generate_test_classes( - TEST_PATH, + get_test_path("server_selection_logging"), module=__name__, ) ) diff --git a/test/asynchronous/test_sessions_unified.py b/test/asynchronous/test_sessions_unified.py index b4cbac5704..ee2b4d418a 100644 --- a/test/asynchronous/test_sessions_unified.py +++ b/test/asynchronous/test_sessions_unified.py @@ -22,19 +22,12 @@ sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions") - - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("sessions"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_transactions_unified.py b/test/asynchronous/test_transactions_unified.py index 8e5b1ae181..5f9b5d0225 100644 --- a/test/asynchronous/test_transactions_unified.py +++ b/test/asynchronous/test_transactions_unified.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import client_context, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False @@ -31,25 +31,13 @@ def setUpModule(): pass -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) - -# Location of JSON test specifications for transactions-convenient-api. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified") -else: - TEST_PATH = os.path.join( - Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified" - ) +globals().update(generate_test_classes(get_test_path("transactions/unified"), module=__name__)) # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update( + generate_test_classes(get_test_path("transactions-convenient-api/unified"), module=__name__) +) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_unified_format.py b/test/asynchronous/test_unified_format.py index 58a1ea3326..8136641236 100644 --- a/test/asynchronous/test_unified_format.py +++ b/test/asynchronous/test_unified_format.py @@ -21,18 +21,18 @@ sys.path[0:0] = [""] from test import UnitTest, unittest -from test.asynchronous.unified_format import MatchEvaluatorUtil, generate_test_classes +from test.asynchronous.unified_format import ( + MatchEvaluatorUtil, + generate_test_classes, + get_test_path, +) from bson import ObjectId _IS_SYNC = False # Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format") - +TEST_PATH = get_test_path("unified-test-format") globals().update( generate_test_classes( diff --git a/test/asynchronous/test_versioned_api_integration.py b/test/asynchronous/test_versioned_api_integration.py index 0f6b544465..7228b945ab 100644 --- a/test/asynchronous/test_versioned_api_integration.py +++ b/test/asynchronous/test_versioned_api_integration.py @@ -16,7 +16,7 @@ import os import sys from pathlib import Path -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path sys.path[0:0] = [""] @@ -27,15 +27,8 @@ _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api") - - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("versioned-api"), module=__name__)) class TestServerApiIntegration(AsyncIntegrationTest): diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 0c9e8c10c8..e308833367 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -29,6 +29,7 @@ import traceback from collections import defaultdict from inspect import iscoroutinefunction +from pathlib import Path from test.asynchronous import ( AsyncIntegrationTest, async_client_context, @@ -1564,6 +1565,14 @@ async def test_case(self): } +def get_test_path(*args): + if _IS_SYNC: + root_dir = Path(__file__).resolve().parent + else: + root_dir = Path(__file__).resolve().parent.parent + return os.path.join(root_dir, *args) + + def generate_test_classes( test_path, module=__name__, @@ -1596,10 +1605,12 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore return base + found_any = False for dirpath, _, filenames in os.walk(test_path): dirname = os.path.split(dirpath)[-1] for filename in filenames: + found_any = True fpath = os.path.join(dirpath, filename) with open(fpath) as scenario_stream: # Use tz_aware=False to match how CodecOptions decodes @@ -1637,4 +1648,7 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore continue raise + if not found_any: + raise ValueError(f"No test files found in {test_path}") + return test_klasses diff --git a/test/test_auth_oidc.py b/test/test_auth_oidc.py index 877a5ca981..1defe82006 100644 --- a/test/test_auth_oidc.py +++ b/test/test_auth_oidc.py @@ -30,7 +30,7 @@ sys.path[0:0] = [""] -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import EventListener, OvertCommandListener from bson import SON @@ -54,14 +54,13 @@ _IS_SYNC = True ROOT = Path(__file__).parent.parent.resolve() -TEST_PATH = ROOT / "auth" / "unified" ENVIRON = os.environ.get("OIDC_ENV", "test") DOMAIN = os.environ.get("OIDC_DOMAIN", "") TOKEN_DIR = os.environ.get("OIDC_TOKEN_DIR", "") TOKEN_FILE = os.environ.get("OIDC_TOKEN_FILE", "") # Generate unified tests. -globals().update(generate_test_classes(str(TEST_PATH), module=__name__)) +globals().update(generate_test_classes(get_test_path("auth", "unified"), module=__name__)) pytestmark = pytest.mark.auth_oidc diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index ac6411cd89..93c5e7666d 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -27,7 +27,7 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from pymongo import MongoClient from pymongo.auth_oidc_shared import OIDCCallback @@ -35,8 +35,7 @@ pytestmark = pytest.mark.auth _IS_SYNC = True - -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") +_TEST_PATH = get_test_path("auth") class TestAuthSpec(PyMongoTestCase): diff --git a/test/test_change_stream.py b/test/test_change_stream.py index ad51f91873..ae00b90a85 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -35,7 +35,7 @@ client_context, unittest, ) -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( AllowListEventListener, EventListener, @@ -1123,12 +1123,9 @@ def tearDown(self): self.listener.reset() -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "change_streams") - - globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "unified"), + get_test_path("change_streams", "unified"), module=__name__, ) ) diff --git a/test/test_client_metadata.py b/test/test_client_metadata.py index a94c5aa25e..5f103f739a 100644 --- a/test/test_client_metadata.py +++ b/test/test_client_metadata.py @@ -19,7 +19,7 @@ import time import unittest from test import IntegrationTest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import CMAPListener from typing import Any, Optional @@ -40,16 +40,8 @@ _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "handshake", "unified") -else: - _TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "handshake", "unified" - ) - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("handshake", "unified"), module=__name__)) def _get_handshake_driver_info(request): diff --git a/test/test_collection_management.py b/test/test_collection_management.py index 063c20df8f..deb43677a6 100644 --- a/test/test_collection_management.py +++ b/test/test_collection_management.py @@ -22,20 +22,12 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "collection_management") -else: - _TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "collection_management" - ) - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("collection_management"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/test_command_logging.py b/test/test_command_logging.py index cf865920ca..17bc319d9c 100644 --- a/test/test_command_logging.py +++ b/test/test_command_logging.py @@ -22,20 +22,13 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("command_logging"), module=__name__, ) ) diff --git a/test/test_command_monitoring.py b/test/test_command_monitoring.py index 4f5ef06f28..eaa2af5ee8 100644 --- a/test/test_command_monitoring.py +++ b/test/test_command_monitoring.py @@ -22,20 +22,13 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("command_monitoring"), module=__name__, ) ) diff --git a/test/test_connection_logging.py b/test/test_connection_logging.py index 253193cc43..9f5da0c436 100644 --- a/test/test_connection_logging.py +++ b/test/test_connection_logging.py @@ -22,20 +22,13 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("connection_logging"), module=__name__, ) ) diff --git a/test/test_crud_unified.py b/test/test_crud_unified.py index 1b1abf3600..45af155bd7 100644 --- a/test/test_crud_unified.py +++ b/test/test_crud_unified.py @@ -22,18 +22,12 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("crud", "unified"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/test_csot.py b/test/test_csot.py index 981af1ed03..d6dec51d37 100644 --- a/test/test_csot.py +++ b/test/test_csot.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils import flaky import pymongo @@ -31,14 +31,8 @@ _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "csot") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "csot") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("csot"), module=__name__)) class TestCSOT(IntegrationTest): diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 67a82996bd..8375d63e97 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -40,7 +40,7 @@ unittest, ) from test.pymongo_mocks import DummyMonitor -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils import ( get_pool, ) @@ -76,14 +76,7 @@ _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") -else: - SDAM_PATH = os.path.join( - Path(__file__).resolve().parent.parent, - "discovery_and_monitoring", - ) +SDAM_PATH = get_test_path("discovery_and_monitoring") def create_mock_topology(uri, monitor_class=DummyMonitor): diff --git a/test/test_encryption.py b/test/test_encryption.py index 04e61b7bad..90773972df 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -66,7 +66,7 @@ LOCAL_MASTER_KEY, ) from test.test_bulk import BulkTestBase -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( AllowListEventListener, OvertCommandListener, @@ -275,11 +275,7 @@ def unmanaged_create_client_encryption( # Location of JSON test files. -if _IS_SYNC: - BASE = os.path.join(pathlib.Path(__file__).resolve().parent, "client-side-encryption") -else: - BASE = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-side-encryption") - +BASE = get_test_path("client-side-encryption") SPEC_PATH = os.path.join(BASE, "spec") OPTS = CodecOptions() diff --git a/test/test_gridfs_spec.py b/test/test_gridfs_spec.py index e84e19725e..8e1a37364e 100644 --- a/test/test_gridfs_spec.py +++ b/test/test_gridfs_spec.py @@ -22,18 +22,12 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("gridfs"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/test_index_management.py b/test/test_index_management.py index dea8c0e2be..2d723bb4a3 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, PyMongoTestCase, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure @@ -40,12 +40,6 @@ pytestmark = pytest.mark.search_index -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management") - _NAME = "test-search-index" @@ -370,7 +364,7 @@ def test_case_7(self): globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("index_management"), module=__name__, ) ) diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 472ef51da3..41663d988b 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -30,7 +30,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( create_event, wait_until, @@ -40,14 +40,8 @@ pytestmark = pytest.mark.load_balancer -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("load_balancer"), module=__name__)) class TestLB(IntegrationTest): diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 4b816b7af9..54946c3ea0 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import OvertCommandListener from pymongo import DESCENDING @@ -42,11 +42,7 @@ _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") +TEST_PATH = get_test_path("read_write_concern") class TestReadWriteConcernSpec(IntegrationTest): diff --git a/test/test_retryable_reads_unified.py b/test/test_retryable_reads_unified.py index b1c6435c9a..c47d89d045 100644 --- a/test/test_retryable_reads_unified.py +++ b/test/test_retryable_reads_unified.py @@ -22,21 +22,15 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified") - # Generate unified tests. # PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects. globals().update( generate_test_classes( - TEST_PATH, + get_test_path("retryable_reads", "unified"), module=__name__, expected_failures=["ListDatabaseObjects .*", "ListCollectionObjects .*", "MapReduce .*"], ) diff --git a/test/test_retryable_writes_unified.py b/test/test_retryable_writes_unified.py index 036c410e24..d06ee206fc 100644 --- a/test/test_retryable_writes_unified.py +++ b/test/test_retryable_writes_unified.py @@ -22,18 +22,14 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update( + generate_test_classes(get_test_path("retryable_writes", "unified"), module=__name__) +) if __name__ == "__main__": unittest.main() diff --git a/test/test_run_command.py b/test/test_run_command.py index d2ef43b97e..df835fb6d7 100644 --- a/test/test_run_command.py +++ b/test/test_run_command.py @@ -18,20 +18,13 @@ import os import unittest from pathlib import Path -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command") - - globals().update( generate_test_classes( - os.path.join(TEST_PATH, "unified"), + get_test_path("run_command", "unified"), module=__name__, ) ) diff --git a/test/test_server_selection_logging.py b/test/test_server_selection_logging.py index d53d8dc84f..c48e166a19 100644 --- a/test/test_server_selection_logging.py +++ b/test/test_server_selection_logging.py @@ -22,20 +22,14 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging") - globals().update( generate_test_classes( - TEST_PATH, + get_test_path("server_selection_logging"), module=__name__, ) ) diff --git a/test/test_sessions_unified.py b/test/test_sessions_unified.py index 3c80c70d38..3d15fac85f 100644 --- a/test/test_sessions_unified.py +++ b/test/test_sessions_unified.py @@ -22,19 +22,12 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions") - - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("sessions"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/test_transactions_unified.py b/test/test_transactions_unified.py index 4ab4885e2a..05e4a1e5c3 100644 --- a/test/test_transactions_unified.py +++ b/test/test_transactions_unified.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import client_context, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True @@ -31,25 +31,13 @@ def setUpModule(): pass -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) - -# Location of JSON test specifications for transactions-convenient-api. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified") -else: - TEST_PATH = os.path.join( - Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified" - ) +globals().update(generate_test_classes(get_test_path("transactions/unified"), module=__name__)) # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update( + generate_test_classes(get_test_path("transactions-convenient-api/unified"), module=__name__) +) if __name__ == "__main__": unittest.main() diff --git a/test/test_unified_format.py b/test/test_unified_format.py index f1cfd0139b..a55f810473 100644 --- a/test/test_unified_format.py +++ b/test/test_unified_format.py @@ -21,18 +21,18 @@ sys.path[0:0] = [""] from test import UnitTest, unittest -from test.unified_format import MatchEvaluatorUtil, generate_test_classes +from test.unified_format import ( + MatchEvaluatorUtil, + generate_test_classes, + get_test_path, +) from bson import ObjectId _IS_SYNC = True # Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format") - +TEST_PATH = get_test_path("unified-test-format") globals().update( generate_test_classes( diff --git a/test/test_versioned_api_integration.py b/test/test_versioned_api_integration.py index 066a1935ca..c4ee7856f3 100644 --- a/test/test_versioned_api_integration.py +++ b/test/test_versioned_api_integration.py @@ -16,7 +16,7 @@ import os import sys from pathlib import Path -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path sys.path[0:0] = [""] @@ -27,15 +27,8 @@ _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api") - - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("versioned-api"), module=__name__)) class TestServerApiIntegration(IntegrationTest): diff --git a/test/unified_format.py b/test/unified_format.py index 0c5f68edd3..277783a9a3 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -29,6 +29,7 @@ import traceback from collections import defaultdict from inspect import iscoroutinefunction +from pathlib import Path from test import ( IntegrationTest, client_context, @@ -1549,6 +1550,14 @@ def test_case(self): } +def get_test_path(*args): + if _IS_SYNC: + root_dir = Path(__file__).resolve().parent + else: + root_dir = Path(__file__).resolve().parent.parent + return os.path.join(root_dir, *args) + + def generate_test_classes( test_path, module=__name__, @@ -1581,10 +1590,12 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore return base + found_any = False for dirpath, _, filenames in os.walk(test_path): dirname = os.path.split(dirpath)[-1] for filename in filenames: + found_any = True fpath = os.path.join(dirpath, filename) with open(fpath) as scenario_stream: # Use tz_aware=False to match how CodecOptions decodes @@ -1622,4 +1633,7 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore continue raise + if not found_any: + raise ValueError(f"No test files found in {test_path}") + return test_klasses