diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8899a05e5a..e87804a3e3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ - [Introduction](#introduction) - [The contribution process](#the-contribution-process) - * [Preparing pull requests](#preparing-pull-requests) + - [Preparing pull requests](#preparing-pull-requests) 1. [Checking the coding style](#checking-the-coding-style) 1. [Unit testing](#unit-testing) 1. [Building the documentation](#building-the-documentation) @@ -9,15 +9,14 @@ 1. [Signing your work](#signing-your-work) 1. [Utility functions](#utility-functions) 1. [Backwards compatibility](#backwards-compatibility) - * [Submitting pull requests](#submitting-pull-requests) + - [Submitting pull requests](#submitting-pull-requests) - [The code reviewing process (for the maintainers)](#the-code-reviewing-process) - * [Reviewing pull requests](#reviewing-pull-requests) + - [Reviewing pull requests](#reviewing-pull-requests) - [Admin tasks (for the maintainers)](#admin-tasks) - * [Releasing a new version](#release-a-new-version) + - [Releasing a new version](#release-a-new-version) ## Introduction - Welcome to Project MONAI! We're excited you're here and want to contribute. This documentation is intended for individuals and institutions interested in contributing to MONAI. MONAI is an open-source project and, as such, its success relies on its community of contributors willing to keep improving it. Your contribution will be a valued addition to the code base; we simply ask that you read this page and understand our contribution process, whether you are a seasoned open-source contributor or whether you are a first-time contributor. ### Communicate with us @@ -30,26 +29,28 @@ MONAI is part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/), and mainly ## The contribution process -_Pull request early_ +*Pull request early* We encourage you to create pull requests early. It helps us track the contributions under development, whether they are ready to be merged or not. [Create a draft pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/changing-the-stage-of-a-pull-request) until it is ready for formal review. Please note that, as per PyTorch, MONAI uses American English spelling. This means classes and variables should be: normali**z**e, visuali**z**e, colo~~u~~r, etc. ### Preparing pull requests + To ensure the code quality, MONAI relies on several linting tools ([flake8 and its plugins](https://gitlab.com/pycqa/flake8), [black](https://github.com/psf/black), [isort](https://github.com/timothycrosley/isort), [ruff](https://github.com/astral-sh/ruff)), static type analysis tools ([mypy](https://github.com/python/mypy), [pytype](https://github.com/google/pytype)), as well as a set of unit/integration tests. This section highlights all the necessary preparation steps required before sending a pull request. To collaborate efficiently, please read through this section and follow them. 
-* [Checking the coding style](#checking-the-coding-style)
-* [Licensing information](#licensing-information)
-* [Unit testing](#unit-testing)
-* [Building documentation](#building-the-documentation)
-* [Signing your work](#signing-your-work)
+- [Checking the coding style](#checking-the-coding-style)
+- [Licensing information](#licensing-information)
+- [Unit testing](#unit-testing)
+- [Building documentation](#building-the-documentation)
+- [Signing your work](#signing-your-work)

#### Checking the coding style
+
Coding style is checked and enforced by flake8, black, isort, and ruff, using [a flake8 configuration](./setup.cfg) similar to [PyTorch's](https://github.com/pytorch/pytorch/blob/master/.flake8).

Before submitting a pull request, we recommend that all linting should pass by running the following command locally:

@@ -66,12 +67,14 @@ python -m pip install -U -r requirements-dev.txt
```

Full linting and type checking may take some time. If you need a quick check, run
+
```bash
# run ruff only
./runtests.sh --ruff
```

#### Licensing information
+
All source code files should start with this paragraph:

```
@@ -96,6 +99,7 @@ If you intend for any variables/functions/classes to be available outside of the
- Add to the `__init__.py` file.

#### Unit testing
+
MONAI tests are located under `tests/`.

- The unit test file names currently follow `test_[module_name].py` or `test_[module_name]_dist.py`.
@@ -106,6 +110,7 @@ A bash script (`runtests.sh`) is provided to
run all tests locally. Please run ``./runtests.sh -h`` to see all options.

To run a particular test, for example `tests/losses/test_dice_loss.py`:
+
```
python -m tests.losses.test_dice_loss
```
@@ -116,6 +121,7 @@ should pass, by running the following command locally:
```bash
./runtests.sh -f -u --net --coverage
```
+
or (for new features that would not break existing functionality):

```bash
@@ -125,17 +131,19 @@ or (for new features that would not break existing functionality):

It is recommended that the new test `test_[module_name].py` be constructed using only Python 3.9+ built-in functions and the `torch`, `numpy`, `coverage` (for reporting code coverage) and `parameterized` (for organizing test cases) packages.
If it requires any other external packages, please make sure:
+
- the packages are listed in [`requirements-dev.txt`](requirements-dev.txt)
- the new test `test_[module_name].py` is added to the `exclude_cases` in [`./tests/min_tests.py`](./tests/min_tests.py) so that the minimal CI runner will not execute it.

##### Testing data
+
Testing data such as images and binary files should not be placed in the source code repository. Please deploy them to a reliable file sharing location (the current preferred one is [https://github.com/Project-MONAI/MONAI-extra-test-data/releases](https://github.com/Project-MONAI/MONAI-extra-test-data/releases)).
At test time, the URLs within `tests/testing_data/data_config.json` are accessible via the APIs provided in `tests.utils`: `tests.utils.testing_data_config` and `tests.utils.download_url_or_skip_test`.

-_If it's not tested, it's broken_
+*If it's not tested, it's broken*

All new functionality should be accompanied by an appropriate set of tests.
MONAI functionality has plenty of unit tests from which you can draw inspiration,
and you can reach out to us if you are unsure of how to proceed with testing.

MONAI's code coverage report is available at [CodeCov](https://codecov.io/gh/Project-MONAI/MONAI).
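For orientation, a new test module typically pairs `parameterized.expand` with a list of pre-computed cases. The following is a minimal sketch only; the network, class name, and shapes are illustrative, not a prescribed template:

```py
import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import BasicUNet

# one entry per case: constructor arguments and a matching input shape
TEST_CASES = [[{"spatial_dims": d}, (1, 1, *([16] * d))] for d in (2, 3)]


class TestBasicUNetShape(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_shape):
        net = BasicUNet(**input_param)
        with eval_mode(net):  # eval() + no_grad for the forward pass
            result = net(torch.randn(input_shape))
        # the default BasicUNet preserves the spatial size of its input
        self.assertEqual(result.shape[2:], input_shape[2:])


if __name__ == "__main__":
    unittest.main()
```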
#### Building the documentation
+
MONAI's documentation is located at `docs/`.

```bash
@@ -155,6 +164,7 @@ pip install -r docs/requirements.txt
cd docs/
make html
```
+
The above commands build the HTML documentation; they are also used to automatically generate [https://docs.monai.io](https://docs.monai.io).

The Python code docstrings are written in
@@ -162,6 +172,7 @@ The Python code docstring are written in
the documentation pages can be in either [reStructuredText](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) or [Markdown](https://en.wikipedia.org/wiki/Markdown).
In general, the Python docstrings follow the [Google style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings).

Before submitting a pull request, it is recommended to:
+
- edit the relevant `.rst` files in [`docs/source`](./docs/source) accordingly.
- build the HTML documentation locally
- check the auto-generated documentation (by browsing `./docs/build/html/index.html` with a web browser)
@@ -170,11 +181,13 @@ Before submitting a pull request, it is recommended to:

Please type `make help` in the `docs/` folder to see all supported format options.

#### Automatic code formatting
+
MONAI provides support for automatic Python code formatting via [a customized GitHub action](https://github.com/Project-MONAI/monai-code-formatter).
This makes the project's Python coding style consistent and reduces the maintenance burden.
Commenting a pull request with `/black` triggers the formatting action based on [`psf/Black`](https://github.com/psf/black) (this is implemented with [`slash command dispatch`](https://github.com/marketplace/actions/slash-command-dispatch)).

Steps for the formatting process:
+
- After submitting a pull request or pushing to an existing pull request, add a comment to the pull request to trigger the formatting action.
The first line of the comment must be `/black` so that it will be interpreted by [the comment parser](https://github.com/marketplace/actions/slash-command-dispatch#how-are-comments-parsed-for-slash-commands).
@@ -183,11 +196,13 @@ The first line of the comment must be `/black` so that it will be interpreted by
- Repeat the above steps if necessary.

#### Adding new optional dependencies
+
In addition to the minimal requirements of PyTorch and NumPy, MONAI's core modules are built optionally based on 3rd-party packages.
The current set of dependencies is listed in [installing dependencies](https://docs.monai.io/en/stable/installation.html#installing-the-recommended-dependencies).
To allow for flexible integration of MONAI with other systems and environments,
the optional dependency APIs are always invoked lazily. For example,
+
```py
from monai.utils import optional_import
itk, _ = optional_import("itk", ...)
@@ -197,11 +212,13 @@ class ITKReader(ImageReader):
    def read(self, ...):
        return itk.imread(...)
```
+
The availability of the external `itk.imread` API is not required unless `monai.data.ITKReader.read` is called by the user.
Integration tests with minimal requirements are deployed to ensure this strategy.
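The boolean flag returned by `optional_import` can also guard a fallback at call time. A small hypothetical sketch (the helper below is illustrative and not part of the MONAI API):

```py
from monai.utils import optional_import

nib, has_nib = optional_import("nibabel")


def save_volume(img, path):
    # the missing dependency surfaces here, at call time, not at module import
    if not has_nib:
        raise RuntimeError("nibabel is required to save NIfTI files.")
    # no affine information in this illustrative example
    nib.save(nib.Nifti1Image(img, affine=None), path)
```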
To add new optional dependencies, please communicate with the core team during pull request reviews,
and add the necessary information (at least) to the following files:
+
- [setup.cfg](https://github.com/Project-MONAI/MONAI/blob/dev/setup.cfg) (for package's `[options.extras_require]` config)
- [requirements-dev.txt](https://github.com/Project-MONAI/MONAI/blob/dev/requirements-dev.txt) (pip requirements file)
- [docs/requirements.txt](https://github.com/Project-MONAI/MONAI/blob/dev/docs/requirements.txt) (docs pip requirements file)
@@ -211,6 +228,7 @@ and add the necessary information (at least) to the following files:

When writing unit tests that use 3rd-party packages, it is a good practice to always consider
an appropriate fallback default behavior when the packages are not installed in the testing environment.
For example:
+
```py
from monai.utils import optional_import
plt, has_matplotlib = optional_import("matplotlib.pyplot")

@@ -218,22 +236,25 @@ plt, has_matplotlib = optional_import("matplotlib.pyplot")
@skipUnless(has_matplotlib, "Matplotlib required")
class TestBlendImages(unittest.TestCase):
```
+
It skips the test cases when `matplotlib.pyplot` APIs are not available.
Alternatively, add the test file name to the ``exclude_cases`` in `tests/min_tests.py` to completely skip the test cases when running in a minimal setup.

-
-
#### Signing your work
+
MONAI enforces the [Developer Certificate of Origin](https://developercertificate.org/) (DCO) on all pull requests.
All commit messages should contain the `Signed-off-by` line with an email address.
The [GitHub DCO app](https://github.com/apps/dco) is deployed on MONAI. The pull request's status will be `failed` if commits do not contain a valid `Signed-off-by` line.

Git has a `-s` (or `--signoff`) command-line option to append this automatically to your commit message:
+
```bash
git commit -s -m 'a new commit'
```
+
The commit message will be:
+
```
a new commit

@@ -241,6 +262,7 @@ The commit message will be:
```

Full text of the DCO:
+
```
Developer Certificate of Origin
Version 1.1

@@ -282,11 +304,13 @@ By making a contribution to this project, I certify that:
```

#### Utility functions
+
MONAI provides a set of generic utility functions and frequently used routines.
These are located in [``monai/utils``](./monai/utils/) and in the module folders such as [``networks/utils.py``](./monai/networks/).
Users are encouraged to use these common routines to improve code readability and reduce the code maintenance burden.

Notably,
+
- the ``monai.module.export`` decorator can make the module name shorter when importing;
for example, ``monai.transforms.Spacing`` is the equivalent of ``monai.transforms.spatial.array.Spacing`` if
``class Spacing`` defined in file `monai/transforms/spatial/array.py` is decorated with ``@export("monai.transforms")``.
@@ -294,12 +318,14 @@ for example, ``import monai.transforms.Spacing`` is the equivalent of ``monai.tr

For string formatting, [f-strings](https://www.python.org/dev/peps/pep-0498/) are recommended over `%` formatting and `str.format`; please use f-strings whenever you need to build a string.

#### Backwards compatibility
+
MONAI in general follows [PyTorch's policy for backward compatibility](https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy).
Utility functions are provided in `monai.utils.deprecated` to help migrate from deprecated to new APIs. The use of these utilities is encouraged.
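As an illustration, deprecating a function could look like the following minimal sketch (the argument values are made up; see `monai/utils/deprecate_utils.py` for the exact decorator signatures):

```py
from monai.utils import deprecated


@deprecated(since="1.3", removed="1.5", msg_suffix="use `new_api` instead.")
def old_api():
    """Calling this emits a deprecation warning that forecasts the removal schedule."""
```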
The pull request [template contains checkboxes](https://github.com/Project-MONAI/MONAI/blame/dev/.github/pull_request_template.md#L11-L12) that the contributor should tick accordingly to clearly indicate breaking changes.

The process of releasing backwards incompatible API changes is as follows:
+
1. discuss the breaking changes during pull requests or in dev meetings with a feature proposal if needed.
1. add a warning message in the upcoming release (version `X.Y`); the warning message should include a forecast of removing the deprecated API in:
    1. `X+1.0` (major version `X+1`, minor version `0`) -- the next major version -- if it's a significant change,
@@ -313,10 +339,10 @@ The process of releasing backwards incompatible API changes is as follows:
1. collect feedback from the users during the subsequent few releases, and reconsider step 1 if needed.
1. before each release, review the APIs being deprecated and the relevant tests, and clean up the APIs scheduled for removal in step 2.

-
-
### Submitting pull requests
+
All code changes to the dev branch must be done via [pull requests](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests).
+
1. Create a new ticket or take a known ticket from [the issue list][monai issue list].
1. Check if there's already a branch dedicated to the task.
1. If the task has not been taken, [create a new branch in your fork](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request-from-a-fork)
@@ -334,9 +360,10 @@ Ideally, the new branch should be based on the latest `dev` branch.

## The code reviewing process

-
### Reviewing pull requests
+
All code review comments should be specific, constructive, and actionable.
+
1. Check [the CI/CD status of the pull request][github ci]; make sure all CI/CD tests passed before reviewing (contact the branch owner if needed).
1. Carefully read the description of the pull request and the files changed; write comments if needed.
1. Make in-line comments on specific code segments, and [request changes](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-request-reviews) if needed.
@@ -349,14 +376,15 @@ All code review comments should be specific, constructive, and actionable.

[github ci]: https://github.com/Project-MONAI/MONAI/actions
[monai issue list]: https://github.com/Project-MONAI/MONAI/issues

-
## Admin tasks

### Release a new version
+
The `dev` branch's `HEAD` always corresponds to the latest MONAI Docker image tag: `projectmonai/monai:latest`.
The `main` branch's `HEAD` always corresponds to the latest MONAI milestone release.

When major features are ready for a milestone, to prepare for a new release:
+
- Prepare [a release note](https://github.com/Project-MONAI/MONAI/releases) and release checklist.
- Check out or cherry-pick a new branch `releasing/[version number]` locally from the `dev` branch and push to the codebase.
- Create a release candidate tag, for example, `git tag -a 0.1.0rc1 -m "release candidate 1 of version 0.1.0"`.
- Push the tag to the codebase, for example, `git push origin 0.1.0rc1`.
  This step will trigger package building and testing.
  The resultant packages are automatically uploaded to
  [TestPyPI](https://test.pypi.org/project/monai/). The packages are also available for downloading as
- repository's artifacts (e.g. the file at https://github.com/Project-MONAI/MONAI/actions/runs/66570977).
+ repository's artifacts (e.g. the files attached to the corresponding workflow run).
- Check the release test at [TestPyPI](https://test.pypi.org/project/monai/), download the artifacts when the CI finishes.
- Optionally run [the cron testing jobs](https://github.com/Project-MONAI/MONAI/blob/dev/.github/workflows/cron.yml) on `releasing/[version number]`.
- Rebase `releasing/[version number]` to `main`, make sure all the test pipelines succeed.
@@ -384,7 +412,6 @@ If any error occurs after the release process, first check out a new hotfix branch
make a patch version release following semantic versioning, for example, `releasing/0.1.1`.
Make sure the `releasing/0.1.1` is merged back into both `dev` and `main` and all the test pipelines succeed.
-


diff --git a/tests/apps/detection/networks/test_retinanet.py b/tests/apps/detection/networks/test_retinanet.py
index 240fd3a9e2..9def63c1d3 100644
--- a/tests/apps/detection/networks/test_retinanet.py
+++ b/tests/apps/detection/networks/test_retinanet.py
@@ -20,7 +20,7 @@
 from monai.networks import eval_mode
 from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
 from monai.utils import ensure_tuple, optional_import
-from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_onnx_save, test_script_save
+from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_onnx_save, test_script_save

 _, has_torchvision = optional_import("torchvision")

@@ -86,15 +86,12 @@
     (2, 1, 32, 64),
 ]

-TEST_CASES = []
-for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
-    for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
-        TEST_CASES.append([model, *case])
+# Create all test case combinations using dict_product
+CASE_LIST = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]
+MODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]

-TEST_CASES_TS = []
-for case in [TEST_CASE_1]:
-    for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
-        TEST_CASES_TS.append([model, *case])
+TEST_CASES = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)]
+TEST_CASES_TS = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])]


 @SkipIfBeforePyTorchVersion((1, 12))
diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py
index f52d70e7b6..1ee5531b76 100644
--- a/tests/data/meta_tensor/test_meta_tensor.py
+++ b/tests/data/meta_tensor/test_meta_tensor.py
@@ -32,13 +32,13 @@
 from monai.data.utils import decollate_batch, list_data_collate
 from monai.transforms import BorderPadd, Compose, DivisiblePadd, FromMetaTensord, ToMetaTensord
 from monai.utils.enums import PostFix
-from tests.test_utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda
+from tests.test_utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, dict_product, skip_if_no_cuda

 DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32], [None]]
-TESTS = []
-for _device in TEST_DEVICES:
-    for _dtype in DTYPES:
-        TESTS.append((*_device, *_dtype))  # type: ignore
+
+# all combinations of device and dtype
+
+TESTS = [(*params["device"], *params["dtype"]) for params in dict_product(device=TEST_DEVICES, dtype=DTYPES)]


 def rand_string(min_len=5, max_len=10):
diff --git a/tests/networks/blocks/test_CABlock.py b/tests/networks/blocks/test_CABlock.py
index 42531131c5..132910504b 100644
--- a/tests/networks/blocks/test_CABlock.py
+++ b/tests/networks/blocks/test_CABlock.py
@@ -20,28 +20,18 @@
 from monai.networks import eval_mode
 from monai.networks.blocks.cablock import CABlock, FeedForward
 from monai.utils import optional_import
-from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose
+from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product

 einops, has_einops = optional_import("einops")

-
-TEST_CASES_CAB = []
-for spatial_dims in [2, 3]:
-    for dim in [32, 64, 128]:
-        for num_heads in [2, 4, 8]:
-            for bias in [True, False]:
-                test_case = [
-                    {
-                        "spatial_dims": spatial_dims,
- 
"dim": dim, - "num_heads": num_heads, - "bias": bias, - "flash_attention": False, - }, - (2, dim, *([16] * spatial_dims)), - (2, dim, *([16] * spatial_dims)), - ] - TEST_CASES_CAB.append(test_case) +TEST_CASES_CAB = [ + [ + {**params, "flash_attention": False}, + (2, params["dim"], *([16] * params["spatial_dims"])), + (2, params["dim"], *([16] * params["spatial_dims"])), + ] + for params in dict_product(spatial_dims=[2, 3], dim=[32, 64, 128], num_heads=[2, 4, 8], bias=[True, False]) +] TEST_CASES_FEEDFORWARD = [ @@ -53,7 +43,6 @@ class TestFeedForward(unittest.TestCase): - @parameterized.expand(TEST_CASES_FEEDFORWARD) def test_shape(self, input_param, input_shape): net = FeedForward(**input_param) @@ -69,7 +58,6 @@ def test_gating_mechanism(self): class TestCABlock(unittest.TestCase): - @parameterized.expand(TEST_CASES_CAB) @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): diff --git a/tests/networks/blocks/test_crossattention.py b/tests/networks/blocks/test_crossattention.py index 32cd655d4c..741d5e3b53 100644 --- a/tests/networks/blocks/test_crossattention.py +++ b/tests/networks/blocks/test_crossattention.py @@ -22,30 +22,28 @@ from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose +from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product einops, has_einops = optional_import("einops") -TEST_CASE_CABLOCK = [] -for dropout_rate in np.linspace(0, 1, 4): - for hidden_size in [360, 480, 600, 768]: - for num_heads in [4, 6, 8, 12]: - for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: - for input_size in [(16, 32), (8, 8, 8)]: - for flash_attn in [True, False]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, - "input_size": input_size, - "use_flash_attention": flash_attn, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_CABLOCK.append(test_case) +TEST_CASE_CABLOCK = [ + [ + { + **{k: v for k, v in params.items() if k not in ["rel_pos_embedding_val"]}, + "rel_pos_embedding": params["rel_pos_embedding_val"] if not params["use_flash_attention"] else None, + }, + (2, 512, params["hidden_size"]), + (2, 512, params["hidden_size"]), + ] + for params in dict_product( + dropout_rate=np.linspace(0, 1, 4), + hidden_size=[360, 480, 600, 768], + num_heads=[4, 6, 8, 12], + rel_pos_embedding_val=[None, RelPosEmbedding.DECOMPOSED], + input_size=[(16, 32), (8, 8, 8)], + use_flash_attention=[True, False], + ) +] class TestResBlock(unittest.TestCase): diff --git a/tests/networks/blocks/test_dynunet_block.py b/tests/networks/blocks/test_dynunet_block.py index d469c6f3e9..4cf68912be 100644 --- a/tests/networks/blocks/test_dynunet_block.py +++ b/tests/networks/blocks/test_dynunet_block.py @@ -18,58 +18,55 @@ from monai.networks import eval_mode from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding -from tests.test_utils import test_script_save +from tests.test_utils import dict_product, test_script_save TEST_CASE_RES_BASIC_BLOCK = [] -for spatial_dims in range(2, 4): - for kernel_size in [1, 3]: - for stride in [1, 2]: - for norm_name in [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"]: - for in_size 
in [15, 16]: - padding = get_padding(kernel_size, stride) - if not isinstance(padding, int): - padding = padding[0] - out_size = int((in_size + 2 * padding - kernel_size) / stride) + 1 - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": 16, - "out_channels": 16, - "kernel_size": kernel_size, - "norm_name": norm_name, - "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}), - "stride": stride, - }, - (1, 16, *([in_size] * spatial_dims)), - (1, 16, *([out_size] * spatial_dims)), - ] - TEST_CASE_RES_BASIC_BLOCK.append(test_case) +for params in dict_product( + spatial_dims=range(2, 4), + kernel_size=[1, 3], + stride=[1, 2], + norm_name=[("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"], + in_size=[15, 16], +): + padding = get_padding(params["kernel_size"], params["stride"]) + if not isinstance(padding, int): + padding = padding[0] + out_size = int((params["in_size"] + 2 * padding - params["kernel_size"]) / params["stride"]) + 1 + test_case = [ + { + **{k: v for k, v in params.items() if k != "in_size"}, + "in_channels": 16, + "out_channels": 16, + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}), + }, + (1, 16, *([params["in_size"]] * params["spatial_dims"])), + (1, 16, *([out_size] * params["spatial_dims"])), + ] + TEST_CASE_RES_BASIC_BLOCK.append(test_case) TEST_UP_BLOCK = [] in_channels, out_channels = 4, 2 -for spatial_dims in range(2, 4): - for kernel_size in [1, 3]: - for stride in [1, 2]: - for norm_name in ["batch", "instance"]: - for in_size in [15, 16]: - for trans_bias in [True, False]: - out_size = in_size * stride - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "norm_name": norm_name, - "stride": stride, - "upsample_kernel_size": stride, - "trans_bias": trans_bias, - }, - (1, in_channels, *([in_size] * spatial_dims)), - (1, out_channels, *([out_size] * spatial_dims)), - (1, out_channels, *([in_size * stride] * spatial_dims)), - ] - TEST_UP_BLOCK.append(test_case) +for params in dict_product( + spatial_dims=range(2, 4), + kernel_size=[1, 3], + stride=[1, 2], + norm_name=["batch", "instance"], + in_size=[15, 16], + trans_bias=[True, False], +): + out_size = params["in_size"] * params["stride"] + test_case = [ + { + **{k: v for k, v in params.items() if k != "in_size"}, + "in_channels": in_channels, + "out_channels": out_channels, + "upsample_kernel_size": params["stride"], + }, + (1, in_channels, *([params["in_size"]] * params["spatial_dims"])), + (1, out_channels, *([out_size] * params["spatial_dims"])), + (1, out_channels, *([params["in_size"] * params["stride"]] * params["spatial_dims"])), + ] + TEST_UP_BLOCK.append(test_case) class TestResBasicBlock(unittest.TestCase): diff --git a/tests/networks/blocks/test_patchembedding.py b/tests/networks/blocks/test_patchembedding.py index 95ca95b36a..de6b486279 100644 --- a/tests/networks/blocks/test_patchembedding.py +++ b/tests/networks/blocks/test_patchembedding.py @@ -21,58 +21,41 @@ from monai.networks import eval_mode from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock from monai.utils import optional_import -from tests.test_utils import SkipIfBeforePyTorchVersion +from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product einops, has_einops = optional_import("einops") -TEST_CASE_PATCHEMBEDDINGBLOCK = [] -for dropout_rate in (0.5,): - for in_channels in [1, 4]: - for hidden_size in [96, 288]: - for img_size in 
[32, 64]:
-                for patch_size in [8, 16]:
-                    for num_heads in [8, 12]:
-                        for proj_type in ["conv", "perceptron"]:
-                            for pos_embed_type in ["none", "learnable", "sincos"]:
-                                # for classification in (False, True):  # TODO: add classification tests
-                                for nd in (2, 3):
-                                    test_case = [
-                                        {
-                                            "in_channels": in_channels,
-                                            "img_size": (img_size,) * nd,
-                                            "patch_size": (patch_size,) * nd,
-                                            "hidden_size": hidden_size,
-                                            "num_heads": num_heads,
-                                            "proj_type": proj_type,
-                                            "pos_embed_type": pos_embed_type,
-                                            "dropout_rate": dropout_rate,
-                                        },
-                                        (2, in_channels, *([img_size] * nd)),
-                                        (2, (img_size // patch_size) ** nd, hidden_size),
-                                    ]
-                                    if nd == 2:
-                                        test_case[0]["spatial_dims"] = 2  # type: ignore
-                                    TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)
-
-TEST_CASE_PATCHEMBED = []
-for patch_size in [2]:
-    for in_chans in [1, 4]:
-        for img_size in [96]:
-            for embed_dim in [6, 12]:
-                for norm_layer in [nn.LayerNorm]:
-                    for nd in [2, 3]:
-                        test_case = [
-                            {
-                                "patch_size": (patch_size,) * nd,
-                                "in_chans": in_chans,
-                                "embed_dim": embed_dim,
-                                "norm_layer": norm_layer,
-                                "spatial_dims": nd,
-                            },
-                            (2, in_chans, *([img_size] * nd)),
-                            (2, embed_dim, *([img_size // patch_size] * nd)),
-                        ]
-                        TEST_CASE_PATCHEMBED.append(test_case)
+
+TEST_CASE_PATCHEMBEDDINGBLOCK = [
+    [
+        params,
+        (2, params["in_channels"], *([params["img_size"]] * params["spatial_dims"])),
+        (2, (params["img_size"] // params["patch_size"]) ** params["spatial_dims"], params["hidden_size"]),
+    ]
+    for params in dict_product(
+        dropout_rate=[0.5],
+        in_channels=[1, 4],
+        hidden_size=[96, 288],
+        img_size=[32, 64],
+        patch_size=[8, 16],
+        num_heads=[8, 12],
+        proj_type=["conv", "perceptron"],
+        pos_embed_type=["none", "learnable", "sincos"],
+        spatial_dims=[2, 3],
+    )
+]
+
+img_size = 96
+TEST_CASE_PATCHEMBED = [
+    [
+        params,
+        (2, params["in_chans"], *([img_size] * params["spatial_dims"])),
+        (2, params["embed_dim"], *([img_size // params["patch_size"]] * params["spatial_dims"])),
+    ]
+    for params in dict_product(
+        patch_size=[2], in_chans=[1, 4], embed_dim=[6, 12], norm_layer=[nn.LayerNorm], spatial_dims=[2, 3]
+    )
+]


 @SkipIfBeforePyTorchVersion((1, 11, 1))
diff --git a/tests/networks/blocks/test_segresnet_block.py b/tests/networks/blocks/test_segresnet_block.py
index 633507a06a..444a171345 100644
--- a/tests/networks/blocks/test_segresnet_block.py
+++ b/tests/networks/blocks/test_segresnet_block.py
@@ -18,27 +18,24 @@
 from monai.networks import eval_mode
 from monai.networks.blocks.segresnet_block import ResBlock
-
-TEST_CASE_RESBLOCK = []
-for spatial_dims in range(2, 4):
-    for in_channels in range(1, 4):
-        for kernel_size in [1, 3]:
-            for norm in [("group", {"num_groups": 1}), "batch", "instance"]:
-                test_case = [
-                    {
-                        "spatial_dims": spatial_dims,
-                        "in_channels": in_channels,
-                        "kernel_size": kernel_size,
-                        "norm": norm,
-                    },
-                    (2, in_channels, *([16] * spatial_dims)),
-                    (2, in_channels, *([16] * spatial_dims)),
-                ]
-                TEST_CASE_RESBLOCK.append(test_case)
+from tests.test_utils import dict_product
+
+TEST_CASE_RESBLOCK = [
+    [
+        params,
+        (2, params["in_channels"], *([16] * params["spatial_dims"])),
+        (2, params["in_channels"], *([16] * params["spatial_dims"])),
+    ]
+    for params in dict_product(
+        spatial_dims=range(2, 4),
+        in_channels=range(1, 4),
+        kernel_size=[1, 3],
+        norm=[("group", {"num_groups": 1}), "batch", "instance"],
+    )
+]


 class TestResBlock(unittest.TestCase):
-
     @parameterized.expand(TEST_CASE_RESBLOCK)
     def test_shape(self, input_param, input_shape, expected_shape):
         net = ResBlock(**input_param)
diff --git 
a/tests/networks/blocks/test_transformerblock.py b/tests/networks/blocks/test_transformerblock.py index a850cc6f74..b977a38e73 100644 --- a/tests/networks/blocks/test_transformerblock.py +++ b/tests/networks/blocks/test_transformerblock.py @@ -21,26 +21,19 @@ from monai.networks import eval_mode from monai.networks.blocks.transformerblock import TransformerBlock from monai.utils import optional_import +from tests.test_utils import dict_product einops, has_einops = optional_import("einops") -TEST_CASE_TRANSFORMERBLOCK = [] -for dropout_rate in np.linspace(0, 1, 4): - for hidden_size in [360, 480, 600, 768]: - for num_heads in [4, 8, 12]: - for mlp_dim in [1024, 3072]: - for cross_attention in [False, True]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "mlp_dim": mlp_dim, - "dropout_rate": dropout_rate, - "with_cross_attention": cross_attention, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_TRANSFORMERBLOCK.append(test_case) +TEST_CASE_TRANSFORMERBLOCK = [ + [params, (2, 512, params["hidden_size"]), (2, 512, params["hidden_size"])] + for params in dict_product( + dropout_rate=np.linspace(0, 1, 4), + hidden_size=[360, 480, 600, 768], + num_heads=[4, 8, 12], + mlp_dim=[1024, 3072], + with_cross_attention=[False, True], + ) +] class TestTransformerBlock(unittest.TestCase): diff --git a/tests/networks/blocks/test_unetr_block.py b/tests/networks/blocks/test_unetr_block.py index 1396a08193..0073efc609 100644 --- a/tests/networks/blocks/test_unetr_block.py +++ b/tests/networks/blocks/test_unetr_block.py @@ -19,86 +19,82 @@ from monai.networks import eval_mode from monai.networks.blocks.dynunet_block import get_padding from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock -from tests.test_utils import test_script_save +from tests.test_utils import dict_product, test_script_save + +def _get_out_size(params): + in_size = params["in_size"] + kernel_size = params["kernel_size"] + stride = params["stride"] + padding = get_padding(kernel_size, stride) + if not isinstance(padding, int): + padding = padding[0] + return int((in_size + 2 * padding - kernel_size) / stride) + 1 + + +norm_names = [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"] +param_dicts = dict_product( + spatial_dims=range(1, 4), kernel_size=[1, 3], stride=[2], norm_name=norm_names, in_size=[15, 16] +) TEST_CASE_UNETR_BASIC_BLOCK = [] -for spatial_dims in range(1, 4): - for kernel_size in [1, 3]: - for stride in [2]: - for norm_name in [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"]: - for in_size in [15, 16]: - padding = get_padding(kernel_size, stride) - if not isinstance(padding, int): - padding = padding[0] - out_size = int((in_size + 2 * padding - kernel_size) / stride) + 1 - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": 16, - "out_channels": 16, - "kernel_size": kernel_size, - "norm_name": norm_name, - "stride": stride, - }, - (1, 16, *([in_size] * spatial_dims)), - (1, 16, *([out_size] * spatial_dims)), - ] - TEST_CASE_UNETR_BASIC_BLOCK.append(test_case) - -TEST_UP_BLOCK = [] -in_channels, out_channels = 4, 2 -for spatial_dims in range(1, 4): - for kernel_size in [1, 3]: - for res_block in [False, True]: - for norm_name in ["instance"]: - for in_size in [15, 16]: - out_size = in_size * stride - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - 
"norm_name": norm_name, - "res_block": res_block, - "upsample_kernel_size": stride, - }, - (1, in_channels, *([in_size] * spatial_dims)), - (1, out_channels, *([out_size] * spatial_dims)), - (1, out_channels, *([in_size * stride] * spatial_dims)), - ] - TEST_UP_BLOCK.append(test_case) +for params in param_dicts: + input_param = {**{k: v for k, v in params.items() if k != "in_size"}, "in_channels": 16, "out_channels": 16} + input_shape = (1, 16, *([params["in_size"]] * params["spatial_dims"])) + expected_shape = (1, 16, *([_get_out_size(params)] * params["spatial_dims"])) + TEST_CASE_UNETR_BASIC_BLOCK.append([input_param, input_shape, expected_shape]) + + +TEST_UP_BLOCK = [ + [ + { + **{k: v for k, v in params.items() if k not in ["in_size", "stride", "upsample_kernel_size"]}, + "upsample_kernel_size": params["stride"], + }, + (1, params["in_channels"], *([params["in_size"]] * params["spatial_dims"])), + (1, params["out_channels"], *([params["in_size"] * params["stride"]] * params["spatial_dims"])), + (1, params["out_channels"], *([params["in_size"] * params["stride"]] * params["spatial_dims"])), + ] + for params in dict_product( + spatial_dims=range(1, 4), + in_channels=[4], + out_channels=[2], + kernel_size=[1, 3], + norm_name=["instance"], + res_block=[False, True], + upsample_kernel_size=[2, 3], + stride=[1, 2], + in_size=[15, 16], + ) +] TEST_PRUP_BLOCK = [] in_channels, out_channels = 4, 2 -for spatial_dims in range(1, 4): - for kernel_size in [1, 3]: - for upsample_kernel_size in [2, 3]: - for stride in [1, 2]: - for res_block in [False, True]: - for norm_name in ["instance"]: - for in_size in [15, 16]: - for num_layer in [0, 2]: - in_size_tmp = in_size - for _ in range(num_layer + 1): - out_size = in_size_tmp * upsample_kernel_size - in_size_tmp = out_size - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "num_layer": num_layer, - "kernel_size": kernel_size, - "norm_name": norm_name, - "stride": stride, - "res_block": res_block, - "upsample_kernel_size": upsample_kernel_size, - }, - (1, in_channels, *([in_size] * spatial_dims)), - (1, out_channels, *([out_size] * spatial_dims)), - ] - TEST_PRUP_BLOCK.append(test_case) +for params in dict_product( + spatial_dims=range(1, 4), + kernel_size=[1, 3], + upsample_kernel_size=[2, 3], + stride=[1, 2], + res_block=[False, True], + norm_name=["instance"], + in_size_scalar=[15, 16], + num_layer=[0, 2], +): + in_size_tmp = params["in_size_scalar"] + out_size = 0 # Initialize out_size + for _ in range(params["num_layer"] + 1): + out_size = in_size_tmp * params["upsample_kernel_size"] + in_size_tmp = out_size + + test_case = [ + { + **{k: v for k, v in params.items() if k != "in_size_scalar"}, + "in_channels": in_channels, + "out_channels": out_channels, + }, + (1, in_channels, *([params["in_size_scalar"]] * params["spatial_dims"])), + (1, out_channels, *([out_size] * params["spatial_dims"])), + ] + TEST_PRUP_BLOCK.append(test_case) class TestResBasicBlock(unittest.TestCase): diff --git a/tests/networks/nets/test_dynunet.py b/tests/networks/nets/test_dynunet.py index a9dd588e13..c2c9369923 100644 --- a/tests/networks/nets/test_dynunet.py +++ b/tests/networks/nets/test_dynunet.py @@ -13,8 +13,6 @@ import platform import unittest -from collections.abc import Sequence -from typing import Any import torch from parameterized import parameterized @@ -22,7 +20,7 @@ from monai.networks import eval_mode from monai.networks.nets import DynUNet from monai.utils import optional_import -from 
tests.test_utils import assert_allclose, skip_if_no_cuda, skip_if_windows, test_script_save +from tests.test_utils import assert_allclose, dict_product, skip_if_no_cuda, skip_if_windows, test_script_save InstanceNorm3dNVFuser, _ = optional_import("apex.normalization", name="InstanceNorm3dNVFuser") @@ -34,86 +32,96 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -strides: Sequence[Sequence[int] | int] -kernel_size: Sequence[Any] -expected_shape: Sequence[Any] - TEST_CASE_DYNUNET_2D = [] -out_channels = 2 -in_size = 64 -spatial_dims = 2 -for kernel_size in [(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))]: - for strides in [(1, 1, 1, 1), (2, 2, 2, 1)]: - expected_shape = (1, out_channels, *[in_size // strides[0]] * spatial_dims) - for in_channels in [2, 3]: - for res_block in [True, False]: - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": kernel_size, - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": "batch", - "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}), - "deep_supervision": False, - "res_block": res_block, - "dropout": None, - }, - (1, in_channels, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_2D.append(test_case) +out_channels_2d = 2 +in_size_2d = 64 +spatial_dims_2d = 2 +for params in dict_product( + kernel_size=[(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))], + strides=[(1, 1, 1, 1), (2, 2, 2, 1)], + in_channels=[2, 3], + res_block=[True, False], +): + kernel_size = params["kernel_size"] + strides = params["strides"] + in_channels = params["in_channels"] + res_block = params["res_block"] + expected_shape = (1, out_channels_2d, *[in_size_2d // strides[0]] * spatial_dims_2d) + test_case = [ + { + "spatial_dims": spatial_dims_2d, + "in_channels": in_channels, + "out_channels": out_channels_2d, + "kernel_size": kernel_size, + "strides": strides, + "upsample_kernel_size": strides[1:], + "norm_name": "batch", + "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.2}), + "deep_supervision": False, + "res_block": res_block, + "dropout": None, + }, + (1, in_channels, in_size_2d, in_size_2d), + expected_shape, + ] + TEST_CASE_DYNUNET_2D.append(test_case) TEST_CASE_DYNUNET_3D = [] # in 3d cases, also test anisotropic kernel/strides -in_channels = 1 -in_size = 64 -for out_channels in [2, 3]: +in_channels_3d = 1 +in_size_3d = 64 +for params in dict_product(out_channels=[2, 3], res_block=[True, False]): + out_channels = params["out_channels"] + res_block = params["res_block"] expected_shape = (1, out_channels, 64, 32, 64) - for res_block in [True, False]: - test_case = [ - { - "spatial_dims": 3, - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": (3, (1, 1, 3), 3, 3), - "strides": ((1, 2, 1), 2, 2, 1), - "upsample_kernel_size": (2, 2, 1), - "filters": (64, 96, 128, 192), - "norm_name": ("INSTANCE", {"affine": True}), - "deep_supervision": True, - "res_block": res_block, - "dropout": ("alphadropout", {"p": 0.25}), - }, - (1, in_channels, in_size, in_size, in_size), - expected_shape, - ] - TEST_CASE_DYNUNET_3D.append(test_case) + test_case = [ + { + "spatial_dims": 3, + "in_channels": in_channels_3d, + "out_channels": out_channels, + "kernel_size": (3, (1, 1, 3), 3, 3), + "strides": ((1, 2, 1), 2, 2, 1), + "upsample_kernel_size": (2, 2, 1), + "filters": (64, 96, 128, 192), + "norm_name": ("INSTANCE", {"affine": True}), + "deep_supervision": True, + "res_block": res_block, + "dropout": ("alphadropout", {"p": 0.25}), + 
}, + (1, in_channels_3d, in_size_3d, in_size_3d, in_size_3d), + expected_shape, + ] + TEST_CASE_DYNUNET_3D.append(test_case) TEST_CASE_DEEP_SUPERVISION = [] -for spatial_dims in [2, 3]: - for res_block in [True, False]: - for deep_supr_num in [1, 2]: - for strides in [(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 2)]: - scale = strides[0] - test_case = [ - { - "spatial_dims": spatial_dims, - "in_channels": 1, - "out_channels": 2, - "kernel_size": [3] * len(strides), - "strides": strides, - "upsample_kernel_size": strides[1:], - "norm_name": ("group", {"num_groups": 16}), - "deep_supervision": True, - "deep_supr_num": deep_supr_num, - "res_block": res_block, - }, - (1, 1, *[in_size] * spatial_dims), - (1, 1 + deep_supr_num, 2, *[in_size // scale] * spatial_dims), - ] - TEST_CASE_DEEP_SUPERVISION.append(test_case) +in_size_ds = 64 +for params in dict_product( + spatial_dims=[2, 3], + res_block=[True, False], + deep_supr_num=[1, 2], + strides=[(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 2)], +): + spatial_dims = params["spatial_dims"] + res_block = params["res_block"] + deep_supr_num = params["deep_supr_num"] + strides = params["strides"] + scale = strides[0] + test_case = [ + { + "spatial_dims": spatial_dims, + "in_channels": 1, + "out_channels": 2, + "kernel_size": [3] * len(strides), + "strides": strides, + "upsample_kernel_size": strides[1:], + "norm_name": ("group", {"num_groups": 16}), + "deep_supervision": True, + "deep_supr_num": deep_supr_num, + "res_block": res_block, + }, + (1, 1, *[in_size_ds] * spatial_dims), + (1, 1 + deep_supr_num, 2, *[in_size_ds // scale] * spatial_dims), + ] + TEST_CASE_DEEP_SUPERVISION.append(test_case) class TestDynUNet(unittest.TestCase): diff --git a/tests/networks/nets/test_mednext.py b/tests/networks/nets/test_mednext.py index b4ba4f9939..53fd77ad5f 100644 --- a/tests/networks/nets/test_mednext.py +++ b/tests/networks/nets/test_mednext.py @@ -18,57 +18,41 @@ from monai.networks import eval_mode from monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS +from tests.test_utils import dict_product # Import dict_product device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_MEDNEXT = [] -for spatial_dims in range(2, 4): - for init_filters in [8, 16]: - for deep_supervision in [False, True]: - for do_res in [False, True]: - test_case = [ - { - "spatial_dims": spatial_dims, - "init_filters": init_filters, - "deep_supervision": deep_supervision, - "use_residual_connection": do_res, - }, - (2, 1, *([16] * spatial_dims)), - (2, 2, *([16] * spatial_dims)), - ] - TEST_CASE_MEDNEXT.append(test_case) - -TEST_CASE_MEDNEXT_2 = [] -for spatial_dims in range(2, 4): - for out_channels in [1, 2]: - for deep_supervision in [False, True]: - test_case = [ - { - "spatial_dims": spatial_dims, - "init_filters": 8, - "out_channels": out_channels, - "deep_supervision": deep_supervision, - }, - (2, 1, *([16] * spatial_dims)), - (2, out_channels, *([16] * spatial_dims)), - ] - TEST_CASE_MEDNEXT_2.append(test_case) - -TEST_CASE_MEDNEXT_VARIANTS = [] -for model in [MedNeXtS, MedNeXtM, MedNeXtL]: - for spatial_dims in range(2, 4): - for out_channels in [1, 2]: - test_case = [ - model, # type: ignore - {"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels}, - (2, 1, *([16] * spatial_dims)), - (2, out_channels, *([16] * spatial_dims)), - ] - TEST_CASE_MEDNEXT_VARIANTS.append(test_case) +TEST_CASE_MEDNEXT = [ + [params, (2, 1, *([16] * params["spatial_dims"])), (2, 2, *([16] * params["spatial_dims"]))] + for params in dict_product( + 
spatial_dims=range(2, 4), + init_filters=[8, 16], + deep_supervision=[False, True], + use_residual_connection=[False, True], + ) +] +TEST_CASE_MEDNEXT_2 = [ + [params, (2, 1, *([16] * params["spatial_dims"])), (2, params["out_channels"], *([16] * params["spatial_dims"]))] + for params in dict_product( + spatial_dims=range(2, 4), out_channels=[1, 2], deep_supervision=[False, True], init_filters=[8] + ) +] + + +TEST_CASE_MEDNEXT_VARIANTS = [ + [ + params["model"], + {"spatial_dims": params["spatial_dims"], "in_channels": 1, "out_channels": params["out_channels"]}, + (2, 1, *([16] * params["spatial_dims"])), + (2, params["out_channels"], *([16] * params["spatial_dims"])), + ] + for params in dict_product( + model=[MedNeXtS, MedNeXtM, MedNeXtL], spatial_dims=range(2, 4), out_channels=[1, 2], in_channels=[1] + ) +] class TestMedNeXt(unittest.TestCase): - @parameterized.expand(TEST_CASE_MEDNEXT) def test_shape(self, input_param, input_shape, expected_shape): net = MedNeXt(**input_param).to(device) diff --git a/tests/networks/nets/test_segresnet.py b/tests/networks/nets/test_segresnet.py index b3b3d1051a..1536d33853 100644 --- a/tests/networks/nets/test_segresnet.py +++ b/tests/networks/nets/test_segresnet.py @@ -19,67 +19,50 @@ from monai.networks import eval_mode from monai.networks.nets import SegResNet, SegResNetVAE from monai.utils import UpsampleMode -from tests.test_utils import test_script_save +from tests.test_utils import dict_product, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_SEGRESNET = [] -for spatial_dims in range(2, 4): - for init_filters in [8, 16]: - for dropout_prob in [None, 0.2]: - for norm in [("GROUP", {"num_groups": 8}), ("batch", {"track_running_stats": False}), "instance"]: - for upsample_mode in UpsampleMode: - test_case = [ - { - "spatial_dims": spatial_dims, - "init_filters": init_filters, - "dropout_prob": dropout_prob, - "norm": norm, - "upsample_mode": upsample_mode, - "use_conv_final": False, - }, - (2, 1, *([16] * spatial_dims)), - (2, init_filters, *([16] * spatial_dims)), - ] - TEST_CASE_SEGRESNET.append(test_case) - -TEST_CASE_SEGRESNET_2 = [] -for spatial_dims in range(2, 4): - for init_filters in [8, 16]: - for out_channels in range(1, 3): - for upsample_mode in UpsampleMode: - test_case = [ - { - "spatial_dims": spatial_dims, - "init_filters": init_filters, - "out_channels": out_channels, - "upsample_mode": upsample_mode, - }, - (2, 1, *([16] * spatial_dims)), - (2, out_channels, *([16] * spatial_dims)), - ] - TEST_CASE_SEGRESNET_2.append(test_case) - -TEST_CASE_SEGRESNET_VAE = [] -for spatial_dims in range(2, 4): - for init_filters in [8, 16]: - for out_channels in range(1, 3): - for upsample_mode in UpsampleMode: - for vae_estimate_std in [True, False]: - test_case = [ - { - "spatial_dims": spatial_dims, - "init_filters": init_filters, - "out_channels": out_channels, - "upsample_mode": upsample_mode, - "act": ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - "input_image_size": ([16] * spatial_dims), - "vae_estimate_std": vae_estimate_std, - }, - (2, 1, *([16] * spatial_dims)), - (2, out_channels, *([16] * spatial_dims)), - ] - TEST_CASE_SEGRESNET_VAE.append(test_case) +TEST_CASE_SEGRESNET = [ + [ + {**params, "use_conv_final": False}, + (2, 1, *([16] * params["spatial_dims"])), + (2, params["init_filters"], *([16] * params["spatial_dims"])), + ] + for params in dict_product( + spatial_dims=range(2, 4), + init_filters=[8, 16], + dropout_prob=[None, 0.2], + norm=[("GROUP", {"num_groups": 8}), ("batch", 
{"track_running_stats": False}), "instance"], + upsample_mode=list(UpsampleMode), + ) +] + +TEST_CASE_SEGRESNET_2 = [ + [params, (2, 1, *([16] * params["spatial_dims"])), (2, params["out_channels"], *([16] * params["spatial_dims"]))] + for params in dict_product( + spatial_dims=range(2, 4), init_filters=[8, 16], out_channels=range(1, 3), upsample_mode=list(UpsampleMode) + ) +] + +TEST_CASE_SEGRESNET_VAE = [ + [ + { + **params, + "act": ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + "input_image_size": ([16] * params["spatial_dims"]), + }, + (2, 1, *([16] * params["spatial_dims"])), + (2, params["out_channels"], *([16] * params["spatial_dims"])), + ] + for params in dict_product( + spatial_dims=range(2, 4), + init_filters=[8, 16], + out_channels=range(1, 3), + upsample_mode=list(UpsampleMode), + vae_estimate_std=[True, False], + ) +] class TestResNet(unittest.TestCase): diff --git a/tests/networks/nets/test_segresnet_ds.py b/tests/networks/nets/test_segresnet_ds.py index 064b2ba06c..ad97829f76 100644 --- a/tests/networks/nets/test_segresnet_ds.py +++ b/tests/networks/nets/test_segresnet_ds.py @@ -18,38 +18,29 @@ from monai.networks import eval_mode from monai.networks.nets import SegResNetDS, SegResNetDS2 -from tests.test_utils import SkipIfBeforePyTorchVersion, test_script_save +from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" -TEST_CASE_SEGRESNET_DS = [] -for spatial_dims in range(2, 4): - for init_filters in [8, 16]: - for act in ["relu", "leakyrelu"]: - for norm in ["BATCH", ("instance", {"affine": True})]: - for upsample_mode in ["deconv", "nontrainable"]: - test_case = [ - { - "spatial_dims": spatial_dims, - "init_filters": init_filters, - "act": act, - "norm": norm, - "upsample_mode": upsample_mode, - }, - (2, 1, *([16] * spatial_dims)), - (2, 2, *([16] * spatial_dims)), - ] - TEST_CASE_SEGRESNET_DS.append(test_case) - -TEST_CASE_SEGRESNET_DS2 = [] -for spatial_dims in range(2, 4): - for out_channels in [1, 2]: - for dsdepth in [1, 2, 3]: - test_case = [ - {"spatial_dims": spatial_dims, "init_filters": 8, "out_channels": out_channels, "dsdepth": dsdepth}, - (2, 1, *([16] * spatial_dims)), - (2, out_channels, *([16] * spatial_dims)), - ] - TEST_CASE_SEGRESNET_DS2.append(test_case) + +TEST_CASE_SEGRESNET_DS = [ + [params, (2, 1, *([16] * params["spatial_dims"])), (2, 2, *([16] * params["spatial_dims"]))] + for params in dict_product( + spatial_dims=range(2, 4), + init_filters=[8, 16], + act=["relu", "leakyrelu"], + norm=["BATCH", ("instance", {"affine": True})], + upsample_mode=["deconv", "nontrainable"], + ) +] + +TEST_CASE_SEGRESNET_DS2 = [ + [ + {**params, "init_filters": 8}, + (2, 1, *([16] * params["spatial_dims"])), + (2, params["out_channels"], *([16] * params["spatial_dims"])), + ] + for params in dict_product(spatial_dims=range(2, 4), out_channels=[1, 2], dsdepth=[1, 2, 3]) +] TEST_CASE_SEGRESNET_DS3 = [ ({"init_filters": 8, "dsdepth": 2, "resolution": None}, (2, 1, 16, 16, 16), ((2, 2, 16, 16, 16), (2, 2, 8, 8, 8))), diff --git a/tests/networks/nets/test_swin_unetr.py b/tests/networks/nets/test_swin_unetr.py index d321c777ba..cc15158b43 100644 --- a/tests/networks/nets/test_swin_unetr.py +++ b/tests/networks/nets/test_swin_unetr.py @@ -26,6 +26,7 @@ from monai.utils import optional_import from tests.test_utils import ( assert_allclose, + dict_product, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, @@ -34,35 +35,32 @@ einops, has_einops = 
optional_import("einops") -TEST_CASE_SWIN_UNETR = [] -case_idx = 0 test_merging_mode = ["mergingv2", "merging", PatchMerging, PatchMergingV2] checkpoint_vals = [True, False] -for attn_drop_rate in [0.4]: - for in_channels in [1]: - for depth in [[2, 1, 1, 1], [1, 2, 1, 1]]: - for out_channels in [2]: - for img_size in ((64, 32, 192), (96, 32)): - for feature_size in [12]: - for norm_name in ["instance"]: - for use_checkpoint in checkpoint_vals: - test_case = [ - { - "spatial_dims": len(img_size), - "in_channels": in_channels, - "out_channels": out_channels, - "feature_size": feature_size, - "depths": depth, - "norm_name": norm_name, - "attn_drop_rate": attn_drop_rate, - "downsample": test_merging_mode[case_idx % 4], - "use_checkpoint": use_checkpoint, - }, - (2, in_channels, *img_size), - (2, out_channels, *img_size), - ] - case_idx += 1 - TEST_CASE_SWIN_UNETR.append(test_case) + +TEST_CASE_SWIN_UNETR = [ + [ + { + **{k: v for k, v in params.items() if k != "img_size"}, + "spatial_dims": len(params["img_size"]), + "downsample": test_merging_mode[i % len(test_merging_mode)], + }, + (2, params["in_channels"], *params["img_size"]), + (2, params["out_channels"], *params["img_size"]), + ] + for i, params in enumerate( + dict_product( + attn_drop_rate=[0.4], + depths=[[2, 1, 1, 1], [1, 2, 1, 1]], + feature_size=[12], + img_size=((64, 32, 192), (96, 32)), + in_channels=[1], + norm_name=["instance"], + out_channels=[2], + use_checkpoint=checkpoint_vals, + ) + ) +] TEST_CASE_FILTER = [ [ diff --git a/tests/networks/nets/test_transchex.py b/tests/networks/nets/test_transchex.py index 1816bc2dd8..b20e604bbf 100644 --- a/tests/networks/nets/test_transchex.py +++ b/tests/networks/nets/test_transchex.py @@ -18,31 +18,28 @@ from monai.networks import eval_mode from monai.networks.nets.transchex import Transchex -from tests.test_utils import skip_if_downloading_fails, skip_if_quick +from tests.test_utils import dict_product, skip_if_downloading_fails, skip_if_quick -TEST_CASE_TRANSCHEX = [] -for drop_out in [0.4]: - for in_channels in [3]: - for img_size in [224]: - for patch_size in [16, 32]: - for num_language_layers in [2]: - for num_vision_layers in [4]: - for num_mixed_layers in [3]: - for num_classes in [8]: - test_case = [ - { - "in_channels": in_channels, - "img_size": (img_size,) * 2, - "patch_size": (patch_size,) * 2, - "num_vision_layers": num_vision_layers, - "num_mixed_layers": num_mixed_layers, - "num_language_layers": num_language_layers, - "num_classes": num_classes, - "drop_out": drop_out, - }, - (2, num_classes), - ] - TEST_CASE_TRANSCHEX.append(test_case) +TEST_CASE_TRANSCHEX = [ + [ + { + **{k: v for k, v in params.items() if k != "img_size"}, + "img_size": (params["img_size"],) * 2, + "patch_size": (params["patch_size"],) * 2, + }, + (2, params["num_classes"]), + ] + for params in dict_product( + drop_out=[0.4], + img_size=[224], + in_channels=[3], + num_classes=[8], + num_language_layers=[2], + num_mixed_layers=[3], + num_vision_layers=[4], + patch_size=[16, 32], + ) +] @skip_if_quick diff --git a/tests/networks/nets/test_unetr.py b/tests/networks/nets/test_unetr.py index 9e37750b4a..76ac3a7439 100644 --- a/tests/networks/nets/test_unetr.py +++ b/tests/networks/nets/test_unetr.py @@ -18,41 +18,34 @@ from monai.networks import eval_mode from monai.networks.nets.unetr import UNETR -from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_script_save +from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_script_save 
-TEST_CASE_UNETR = [] -for dropout_rate in [0.4]: - for in_channels in [1]: - for out_channels in [2]: - for hidden_size in [768]: - for img_size in [96, 128]: - for feature_size in [16]: - for num_heads in [8]: - for mlp_dim in [3072]: - for norm_name in ["instance"]: - for proj_type in ["perceptron"]: - for nd in (2, 3): - test_case = [ - { - "in_channels": in_channels, - "out_channels": out_channels, - "img_size": (img_size,) * nd, - "hidden_size": hidden_size, - "feature_size": feature_size, - "norm_name": norm_name, - "mlp_dim": mlp_dim, - "num_heads": num_heads, - "proj_type": proj_type, - "dropout_rate": dropout_rate, - "conv_block": True, - "res_block": False, - }, - (2, in_channels, *([img_size] * nd)), - (2, out_channels, *([img_size] * nd)), - ] - if nd == 2: - test_case[0]["spatial_dims"] = 2 # type: ignore - TEST_CASE_UNETR.append(test_case) +TEST_CASE_UNETR = [ + [ + { + **{k: v for k, v in params.items() if k not in ["img_size", "nd"]}, + "conv_block": True, + "res_block": False, + "img_size": (params["img_size"],) * params["nd"], + **({"spatial_dims": 2} if params["nd"] == 2 else {}), + }, + (2, params["in_channels"], *([params["img_size"]] * params["nd"])), + (2, params["out_channels"], *([params["img_size"]] * params["nd"])), + ] + for params in dict_product( + dropout_rate=[0.4], + feature_size=[16], + hidden_size=[768], + img_size=[96, 128], + in_channels=[1], + mlp_dim=[3072], + nd=[2, 3], + norm_name=["instance"], + num_heads=[8], + out_channels=[2], + proj_type=["perceptron"], + ) +] @skip_if_quick diff --git a/tests/networks/nets/test_vit.py b/tests/networks/nets/test_vit.py index 56b7807449..e7290f89d4 100644 --- a/tests/networks/nets/test_vit.py +++ b/tests/networks/nets/test_vit.py @@ -18,45 +18,39 @@ from monai.networks import eval_mode from monai.networks.nets.vit import ViT -from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_script_save +from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_script_save -TEST_CASE_Vit = [] -for dropout_rate in [0.6]: - for in_channels in [4]: - for hidden_size in [768]: - for img_size in [96, 128]: - for patch_size in [16]: - for num_heads in [12]: - for mlp_dim in [3072]: - for num_layers in [4]: - for num_classes in [8]: - for proj_type in ["conv", "perceptron"]: - for classification in [False, True]: - for nd in (2, 3): - test_case = [ - { - "in_channels": in_channels, - "img_size": (img_size,) * nd, - "patch_size": (patch_size,) * nd, - "hidden_size": hidden_size, - "mlp_dim": mlp_dim, - "num_layers": num_layers, - "num_heads": num_heads, - "proj_type": proj_type, - "classification": classification, - "num_classes": num_classes, - "dropout_rate": dropout_rate, - }, - (2, in_channels, *([img_size] * nd)), - (2, (img_size // patch_size) ** nd, hidden_size), - ] - if nd == 2: - test_case[0]["spatial_dims"] = 2 # type: ignore - if classification: - test_case[0]["post_activation"] = False # type: ignore - if test_case[0]["classification"]: # type: ignore - test_case[2] = (2, test_case[0]["num_classes"]) # type: ignore - TEST_CASE_Vit.append(test_case) +TEST_CASE_Vit = [ + ( + [ + { + **{k: v for k, v in params.items() if k not in ["nd"]}, + **({"spatial_dims": 2} if params["nd"] == 2 else {}), + **({"post_activation": False} if params["nd"] == 2 and params["classification"] else {}), + }, + (2, params["in_channels"], *([params["img_size"]] * params["nd"])), + ( + (2, params["num_classes"]) + if params["classification"] + else (2, (params["img_size"] // 
params["patch_size"]) ** params["nd"], params["hidden_size"]) + ), + ] + ) + for params in dict_product( + dropout_rate=[0.6], + in_channels=[4], + hidden_size=[768], + img_size=[96, 128], + patch_size=[16], + num_heads=[12], + mlp_dim=[3072], + num_layers=[4], + num_classes=[8], + proj_type=["conv", "perceptron"], + classification=[False, True], + nd=[2, 3], + ) +] @skip_if_quick diff --git a/tests/networks/nets/test_vitautoenc.py b/tests/networks/nets/test_vitautoenc.py index 97f144aa2d..3d6923910c 100644 --- a/tests/networks/nets/test_vitautoenc.py +++ b/tests/networks/nets/test_vitautoenc.py @@ -17,32 +17,29 @@ from monai.networks import eval_mode from monai.networks.nets.vitautoenc import ViTAutoEnc -from tests.test_utils import skip_if_quick, skip_if_windows +from tests.test_utils import dict_product, skip_if_quick, skip_if_windows -TEST_CASE_Vitautoenc = [] -for in_channels in [1, 4]: - for img_size in [64, 96, 128]: - for patch_size in [16]: - for proj_type in ["conv", "perceptron"]: - for nd in [2, 3]: - test_case = [ - { - "in_channels": in_channels, - "img_size": (img_size,) * nd, - "patch_size": (patch_size,) * nd, - "hidden_size": 768, - "mlp_dim": 3072, - "num_layers": 4, - "num_heads": 12, - "proj_type": proj_type, - "dropout_rate": 0.6, - "spatial_dims": nd, - }, - (2, in_channels, *([img_size] * nd)), - (2, 1, *([img_size] * nd)), - ] - - TEST_CASE_Vitautoenc.append(test_case) +TEST_CASE_Vitautoenc = [ + [ + { + "in_channels": params["in_channels"], + "img_size": (params["img_size"],) * params["nd"], + "patch_size": (params["patch_size"],) * params["nd"], + "hidden_size": 768, + "mlp_dim": 3072, + "num_layers": 4, + "num_heads": 12, + "proj_type": params["proj_type"], + "dropout_rate": 0.6, + "spatial_dims": params["nd"], + }, + (2, params["in_channels"], *([params["img_size"]] * params["nd"])), + (2, 1, *([params["img_size"]] * params["nd"])), + ] + for params in dict_product( + in_channels=[1, 4], img_size=[64, 96, 128], patch_size=[16], proj_type=["conv", "perceptron"], nd=[2, 3] + ) +] TEST_CASE_Vitautoenc.append( [ diff --git a/tests/test_masked_autoencoder_vit.py b/tests/test_masked_autoencoder_vit.py index b649c1266c..ca2275c81c 100644 --- a/tests/test_masked_autoencoder_vit.py +++ b/tests/test_masked_autoencoder_vit.py @@ -18,54 +18,52 @@ from monai.networks import eval_mode from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT -from tests.test_utils import skip_if_quick +from tests.test_utils import dict_product, skip_if_quick TEST_CASE_MaskedAutoEncoderViT = [] -for masking_ratio in [0.5]: - for dropout_rate in [0.6]: - for in_channels in [4]: - for hidden_size in [768]: - for img_size in [96, 128]: - for patch_size in [16]: - for num_heads in [12]: - for mlp_dim in [3072]: - for num_layers in [4]: - for decoder_hidden_size in [384]: - for decoder_mlp_dim in [512]: - for decoder_num_layers in [4]: - for decoder_num_heads in [16]: - for pos_embed_type in ["sincos", "learnable"]: - for proj_type in ["conv", "perceptron"]: - for nd in (2, 3): - test_case = [ - { - "in_channels": in_channels, - "img_size": (img_size,) * nd, - "patch_size": (patch_size,) * nd, - "hidden_size": hidden_size, - "mlp_dim": mlp_dim, - "num_layers": num_layers, - "decoder_hidden_size": decoder_hidden_size, - "decoder_mlp_dim": decoder_mlp_dim, - "decoder_num_layers": decoder_num_layers, - "decoder_num_heads": decoder_num_heads, - "pos_embed_type": pos_embed_type, - "masking_ratio": masking_ratio, - "decoder_pos_embed_type": pos_embed_type, - "num_heads": num_heads, - 
"proj_type": proj_type, - "dropout_rate": dropout_rate, - }, - (2, in_channels, *([img_size] * nd)), - ( - 2, - (img_size // patch_size) ** nd, - in_channels * (patch_size**nd), - ), - ] - if nd == 2: - test_case[0]["spatial_dims"] = 2 # type: ignore - TEST_CASE_MaskedAutoEncoderViT.append(test_case) + +for base_params in dict_product( + masking_ratio=[0.5], + dropout_rate=[0.6], + in_channels=[4], + hidden_size=[768], + img_size_scalar=[96, 128], + patch_size_scalar=[16], + num_heads=[12], + mlp_dim=[3072], + num_layers=[4], + decoder_hidden_size=[384], + decoder_mlp_dim=[512], + decoder_num_layers=[4], + decoder_num_heads=[16], + pos_embed_type=["sincos", "learnable"], + proj_type=["conv", "perceptron"], +): + img_size_scalar = base_params.pop("img_size_scalar") + patch_size_scalar = base_params.pop("patch_size_scalar") + for nd in (2, 3): + # Parameters for the MaskedAutoEncoderViT model + model_params = base_params.copy() + model_params["img_size"] = (img_size_scalar,) * nd + model_params["patch_size"] = (patch_size_scalar,) * nd + model_params["decoder_pos_embed_type"] = model_params["pos_embed_type"] + + # Expected input and output shapes + input_shape = (2, model_params["in_channels"], *([img_size_scalar] * nd)) + # N, num_patches, patch_dim_product + # num_patches = (img_size // patch_size) ** nd + # patch_dim_product = in_channels * (patch_size**nd) + expected_shape = ( + 2, + (img_size_scalar // patch_size_scalar) ** nd, + model_params["in_channels"] * (patch_size_scalar**nd), + ) + + if nd == 2: + model_params["spatial_dims"] = 2 + + test_case = [model_params, input_shape, expected_shape] + TEST_CASE_MaskedAutoEncoderViT.append(test_case) TEST_CASE_ill_args = [ [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}], diff --git a/tests/test_utils.py b/tests/test_utils.py index 97a3181c44..784b25f663 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -28,12 +28,13 @@ import traceback import unittest import warnings +from collections.abc import Iterable from contextlib import contextmanager from functools import partial, reduce from itertools import product from pathlib import Path from subprocess import PIPE, Popen -from typing import Callable, Literal +from typing import Any, Callable from urllib.error import ContentTooShortError, HTTPError import numpy as np @@ -864,18 +865,24 @@ def equal_state_dict(st_1, st_2): TEST_DEVICES.append([torch.device("cuda")]) -def dict_product(trailing=False, format: Literal["list", "dict"] = "dict", **items): +def dict_product(**items: Iterable[Any]) -> list[dict]: + """Create cartesian product, equivalent to a nested for-loop, combinations of the items dict. + + Args: + items: dict of items to be combined. + + Returns: + list: list of dictionaries with the combinations of the input items. 
+
+    Example:
+        >>> dict_product(x=[1, 2], y=[3, 4])
+        [{'x': 1, 'y': 3}, {'x': 1, 'y': 4}, {'x': 2, 'y': 3}, {'x': 2, 'y': 4}]
+    """
     keys = items.keys()
     values = items.values()
-    for pvalues in product(*values):
-        dict_comb = dict(zip(keys, pvalues))
-        if format == "dict":
-            if trailing:
-                yield [dict_comb] + list(pvalues)
-            else:
-                yield dict_comb
-        else:
-            yield pvalues
+    prod_values = product(*values)
+    prod_dict = [dict(zip(keys, v)) for v in prod_values]
+    return prod_dict


 if __name__ == "__main__":
diff --git a/tests/transforms/spatial/test_spatial_resampled.py b/tests/transforms/spatial/test_spatial_resampled.py
index 12d54cabfc..a03986159a 100644
--- a/tests/transforms/spatial/test_spatial_resampled.py
+++ b/tests/transforms/spatial/test_spatial_resampled.py
@@ -22,7 +22,7 @@
 from monai.data.utils import to_affine_nd
 from monai.transforms.spatial.dictionary import SpatialResampled
 from tests.lazy_transforms_utils import test_resampler_lazy
-from tests.test_utils import TEST_DEVICES, assert_allclose
+from tests.test_utils import TEST_DEVICES, assert_allclose, dict_product

 ON_AARCH64 = platform.machine() == "aarch64"
 if ON_AARCH64:
@@ -42,27 +42,28 @@
 ]

 for dst, expct in zip(destinations_3d, expected_3d):
-    for device in TEST_DEVICES:
-        for align in (True, False):
-            for dtype in (torch.float32, torch.float64):
-                interp = ("nearest", "bilinear")
-                for interp_mode in interp:
-                    for padding_mode in ("zeros", "border", "reflection"):
-                        TESTS.append(
-                            [
-                                np.arange(12).reshape((1, 2, 2, 3)) + 1.0,  # data
-                                *device,
-                                dst,
-                                {
-                                    "dst_keys": "dst_affine",
-                                    "dtype": dtype,
-                                    "align_corners": align,
-                                    "mode": interp_mode,
-                                    "padding_mode": padding_mode,
-                                },
-                                expct,
-                            ]
-                        )
+    TESTS.extend(
+        [
+            [
+                np.arange(12).reshape((1, 2, 2, 3)) + 1.0,  # data
+                *params["device"],
+                dst,
+                {
+                    # align_corners, dtype and padding_mode pass through from the product
+                    **{k: v for k, v in params.items() if k not in ["device", "interp_mode"]},
+                    "dst_keys": "dst_affine",
+                    "mode": params["interp_mode"],
+                },
+                expct,
+            ]
+            for params in dict_product(
+                device=TEST_DEVICES,
+                align_corners=[True, False],
+                dtype=[torch.float32, torch.float64],
+                interp_mode=["nearest", "bilinear"],
+                padding_mode=["zeros", "border", "reflection"],
+            )
+        ]
+    )

 destinations_2d = [
     torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]),  # flip the second
@@ -72,25 +73,27 @@
 expected_2d = [torch.tensor([[[2.0, 1.0], [4.0, 3.0]]]), torch.tensor([[[3.0, 4.0], [1.0, 2.0]]])]

 for dst, expct in zip(destinations_2d, expected_2d):
-    for device in TEST_DEVICES:
-        for align in (False, True):
-            for dtype in (torch.float32, torch.float64):
-                for interp_mode in ("nearest", "bilinear"):
-                    TESTS.append(
-                        [
-                            np.arange(4).reshape((1, 2, 2)) + 1.0,  # data
-                            *device,
-                            dst,
-                            {
-                                "dst_keys": "dst_affine",
-                                "dtype": dtype,
-                                "align_corners": align,
-                                "mode": interp_mode,
-                                "padding_mode": "zeros",
-                            },
-                            expct,
-                        ]
-                    )
+    TESTS += [
+        [
+            np.arange(4).reshape((1, 2, 2)) + 1.0,  # data
+            *params.pop("device"),  # evaluated before the dict below, so "device" is already removed
+            dst,
+            {
+                **{k: v for k, v in params.items() if k not in ["align", "interp_mode"]},  # keeps dtype
+                "dst_keys": "dst_affine",
+                "align_corners": params["align"],
+                "mode": params["interp_mode"],
+                "padding_mode": "zeros",
+            },
+            expct,
+        ]
+        for params in dict_product(
+            device=TEST_DEVICES,
+            align=[False, True],
+            dtype=[torch.float32, torch.float64],
+            interp_mode=["nearest", "bilinear"],
+        )
+    ]


 class TestSpatialResample(unittest.TestCase):
diff --git a/tests/transforms/test_gibbs_noise.py b/tests/transforms/test_gibbs_noise.py
index 1f96595a26..889222fdb3 100644
--- a/tests/transforms/test_gibbs_noise.py
+++ b/tests/transforms/test_gibbs_noise.py
@@ -25,8 +25,9 @@

 _, has_torch_fft 
= optional_import("torch.fft", name="fftshift") -params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]} -TEST_CASES = list(dict_product(format="list", **params)) +shapes = ((128, 64), (64, 48, 80)) +input_types = TEST_NDARRAYS if has_torch_fft else [np.array] +TEST_CASES = [[p_dict["shape"], p_dict["input_type"]] for p_dict in dict_product(shape=shapes, input_type=input_types)] class TestGibbsNoise(unittest.TestCase): diff --git a/tests/transforms/test_spacing.py b/tests/transforms/test_spacing.py index f3ef25d1d2..3862472753 100644 --- a/tests/transforms/test_spacing.py +++ b/tests/transforms/test_spacing.py @@ -24,249 +24,200 @@ from monai.transforms import Spacing from monai.utils import fall_back_tuple from tests.lazy_transforms_utils import test_resampler_lazy -from tests.test_utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, skip_if_quick +from tests.test_utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, dict_product, skip_if_quick -TESTS: list[list] = [] -for device in TEST_DEVICES: - TESTS.append( - [ - {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float}, - torch.arange(4).reshape((1, 2, 2)) + 1.0, # data - torch.eye(4), - {}, - torch.tensor([[[1.0, 1.0], [3.0, 2.0]]]), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float}, - torch.ones((1, 2, 1, 2)), # data - torch.eye(4), - {}, - torch.tensor([[[[1.0, 1.0]], [[1.0, 1.0]]]]), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": 2.0, "padding_mode": "zeros", "dtype": float}, - torch.arange(4).reshape((1, 2, 2)) + 1.0, # data - torch.eye(4), - {}, - torch.tensor([[[1.0, 0.0], [0.0, 0.0]]]), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float}, - torch.ones((1, 2, 1, 2)), # data - torch.eye(4), - {}, - torch.tensor([[[[1.0, 1.0]], [[1.0, 1.0]]]]), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True}, - torch.ones((1, 2, 1, 2)), # data - torch.tensor([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]]), - {}, - ( - torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) - if USE_COMPILED - else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]) - ), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, - torch.arange(24).reshape((2, 3, 4)), # data - torch.as_tensor(np.diag([-3.0, 0.2, 1.5, 1])), - {}, - torch.tensor([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, - torch.arange(24).reshape((2, 3, 4)), # data - torch.eye(4), - {}, - torch.tensor([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (1.0, 1.0), "align_corners": True}, - torch.arange(24).reshape((2, 3, 4)), # data - torch.eye(4), - {}, - torch.tensor( - [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] - ), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (4.0, 5.0, 6.0)}, - torch.arange(24).reshape((1, 2, 3, 4)), # data - torch.tensor([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]]), - {}, - torch.arange(24).reshape((1, 2, 3, 4)), # data - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (4.0, 5.0, 6.0), "diagonal": True}, - torch.arange(24).reshape((1, 2, 3, 4)), # data - torch.tensor([[-4, 0, 0, 
4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), - {}, - torch.tensor( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, - torch.arange(24).reshape((1, 2, 3, 4)), # data - torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), - {}, - torch.tensor( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, - torch.arange(24).reshape((1, 2, 3, 4)), # data - torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), - {"mode": "nearest"}, - torch.tensor( - [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] - ), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True}, - torch.arange(24).reshape((1, 4, 6)), # data - torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), - {"mode": "nearest"}, - torch.tensor( - [ - [ - [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], - [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], - [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], - [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], - [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], - [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], - [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0], - ] - ] - ), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": torch.float32}, - torch.arange(24).reshape((1, 4, 6)), # data - torch.tensor([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), - {"mode": "bilinear"}, - torch.tensor( +# Define the static parts of each test case +_template_5_expected_output = ( + torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]]) + if USE_COMPILED + else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]]) +) + +all_template_parts = [ + [ + {"pixdim": (1.0, 1.5), "padding_mode": "zeros", "dtype": float}, + torch.arange(4).reshape((1, 2, 2)) + 1.0, + torch.eye(4), + {}, + torch.tensor([[[1.0, 1.0], [3.0, 2.0]]]), + ], + [ + {"pixdim": 1.0, "padding_mode": "zeros", "dtype": float}, + torch.ones((1, 2, 1, 2)), + torch.eye(4), + {}, + torch.tensor([[[[1.0, 1.0]], [[1.0, 1.0]]]]), + ], + [ + {"pixdim": 2.0, "padding_mode": "zeros", "dtype": float}, + torch.arange(4).reshape((1, 2, 2)) + 1.0, + torch.eye(4), + {}, + torch.tensor([[[1.0, 0.0], [0.0, 0.0]]]), + ], + [ + {"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float}, + torch.ones((1, 2, 1, 2)), + torch.eye(4), + {}, + torch.tensor([[[[1.0, 1.0]], [[1.0, 1.0]]]]), + ], + [ + {"pixdim": (1.0, 0.2, 1.5), "diagonal": False, "padding_mode": "zeros", "align_corners": True}, + torch.ones((1, 2, 1, 2)), + torch.tensor([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]]), + {}, + _template_5_expected_output, + ], + [ + {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, + torch.arange(24).reshape((2, 3, 4)), + torch.as_tensor(np.diag([-3.0, 0.2, 1.5, 1])), + {}, + torch.tensor([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]), + ], + [ + {"pixdim": (3.0, 1.0), "padding_mode": "zeros"}, + torch.arange(24).reshape((2, 3, 4)), + torch.eye(4), + {}, + torch.tensor([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]), + ], + [ + {"pixdim": (1.0, 1.0), "align_corners": True}, + 
torch.arange(24).reshape((2, 3, 4)), + torch.eye(4), + {}, + torch.tensor( + [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + ), + ], + [ + {"pixdim": (4.0, 5.0, 6.0)}, + torch.arange(24).reshape((1, 2, 3, 4)), + torch.tensor([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]]), + {}, + torch.arange(24).reshape((1, 2, 3, 4)), + ], + [ + {"pixdim": (4.0, 5.0, 6.0), "diagonal": True}, + torch.arange(24).reshape((1, 2, 3, 4)), + torch.tensor([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), + {}, + torch.tensor( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ], + [ + {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, + torch.arange(24).reshape((1, 2, 3, 4)), + torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), + {}, + torch.tensor( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ], + [ + {"pixdim": (4.0, 5.0, 6.0), "padding_mode": "border", "diagonal": True}, + torch.arange(24).reshape((1, 2, 3, 4)), + torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), + {"mode": "nearest"}, + torch.tensor( + [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]] + ), + ], + [ + {"pixdim": (1.9, 4.0), "padding_mode": "zeros", "diagonal": True}, + torch.arange(24).reshape((1, 4, 6)), + torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), + {"mode": "nearest"}, + torch.tensor( + [ [ - [ - [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8], - [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3], - [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8], - ] + [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], + [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0], + [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], + [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0], + [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], + [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0], + [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0], ] - ), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": torch.float32}, - torch.arange(24).reshape((1, 4, 6)), # data - torch.tensor([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), - {"mode": "bilinear"}, - torch.tensor( + ] + ), + ], + [ + {"pixdim": (5.0, 3.0), "padding_mode": "border", "diagonal": True, "dtype": torch.float32}, + torch.arange(24).reshape((1, 4, 6)), + torch.tensor([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), + {"mode": "bilinear"}, + torch.tensor( + [ [ - [ - [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000], - [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000], - [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000], - ] + [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8], + [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3], + [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8], ] - ), - *device, - ] - ) - TESTS.append( - [ - {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, - torch.ones((1, 2, 1, 2)), # data - torch.eye(4), - {}, - torch.tensor([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), - *device, - ] - ) - TESTS.append( # 5D input - [ - {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float, 
"align_corners": True}, - torch.ones((1, 2, 2, 2, 1)), # data - torch.eye(4), - {}, - torch.ones((1, 2, 2, 3, 1)), - *device, - ] - ) - TESTS.append( # 5D input - [ - {"pixdim": 0.5, "padding_mode": "constant", "mode": "nearest", "scale_extent": True}, - torch.ones((1, 368, 336, 368)), # data - torch.tensor( + ] + ), + ], + [ + {"pixdim": (5.0, 3.0), "padding_mode": "zeros", "diagonal": True, "dtype": torch.float32}, + torch.arange(24).reshape((1, 4, 6)), + torch.tensor([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]), + {"mode": "bilinear"}, + torch.tensor( + [ [ - [0.41, 0.005, 0.008, -79.7], - [-0.0049, 0.592, 0.0664, -57.4], - [-0.0073, -0.0972, 0.404, -32.1], - [0.0, 0.0, 0.0, 1.0], + [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000], + [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000], + [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000], ] - ), - {}, - torch.ones((1, 302, 403, 301)), - *device, - ] - ) + ] + ), + ], + [ + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float}, + torch.ones((1, 2, 1, 2)), + torch.eye(4), + {}, + torch.tensor([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]), + ], + [ # 5D input + {"pixdim": [-1, -1, 0.5], "padding_mode": "zeros", "dtype": float, "align_corners": True}, + torch.ones((1, 2, 2, 2, 1)), + torch.eye(4), + {}, + torch.ones((1, 2, 2, 3, 1)), + ], + [ # 5D input + {"pixdim": 0.5, "padding_mode": "constant", "mode": "nearest", "scale_extent": True}, + torch.ones((1, 368, 336, 368)), + torch.tensor( + [ + [0.41, 0.005, 0.008, -79.7], + [-0.0049, 0.592, 0.0664, -57.4], + [-0.0073, -0.0972, 0.404, -32.1], + [0.0, 0.0, 0.0, 1.0], + ] + ), + {}, + torch.ones((1, 302, 403, 301)), + ], +] +TESTS: list[list] = [ + params["template"] + [*params["device_val"]] + for params in dict_product(template=all_template_parts, device_val=TEST_DEVICES) +] -TESTS_TORCH = [] -for track_meta in (False, True): - for p in TEST_NDARRAYS_ALL: - TESTS_TORCH.append([[1.2, 1.3, 0.9], p(torch.zeros((1, 3, 4, 5))), track_meta]) +TESTS_TORCH = [ + [[1.2, 1.3, 0.9], params["p"](torch.zeros((1, 3, 4, 5))), params["track_meta"]] + for params in dict_product(track_meta=[False, True], p=TEST_NDARRAYS_ALL) +] -TEST_INVERSE = [] -for d in TEST_DEVICES: - for recompute in (False, True): - for align in (False, True): - for scale_extent in (False, True): - TEST_INVERSE.append([*d, recompute, align, scale_extent]) +TEST_INVERSE = [ + [*params["d"], params["recompute"], params["align"], params["scale_extent"]] + for params in dict_product(d=TEST_DEVICES, recompute=[False, True], align=[False, True], scale_extent=[False, True]) +] @skip_if_quick diff --git a/tests/transforms/test_spatial_resample.py b/tests/transforms/test_spatial_resample.py index 7962c77f1c..becd909048 100644 --- a/tests/transforms/test_spatial_resample.py +++ b/tests/transforms/test_spatial_resample.py @@ -24,7 +24,7 @@ from monai.transforms import SpatialResample from monai.utils import optional_import from tests.lazy_transforms_utils import test_resampler_lazy -from tests.test_utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose +from tests.test_utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, dict_product TESTS = [] @@ -68,23 +68,23 @@ expected_2d = [torch.tensor([[[2.0, 1.0], [4.0, 3.0]]]), torch.tensor([[[3.0, 4.0], [1.0, 2.0]]])] for dst, expct in zip(destinations_2d, expected_2d): - for device in TEST_DEVICES: - for align in (False, True): - for interp_mode in ("nearest", "bilinear"): - 
TESTS.append(
-                    [
-                        torch.arange(4).reshape((1, 2, 2)) + 1.0,
-                        *device,
-                        {
-                            "dst_affine": dst,
-                            "dtype": torch.float32,
-                            "align_corners": align,
-                            "mode": interp_mode,
-                            "padding_mode": "zeros",
-                        },
-                        expct,
-                    ]
-                )
+    TESTS.extend(
+        [
+            [
+                torch.arange(4).reshape((1, 2, 2)) + 1.0,
+                *params["device"],
+                {
+                    "dst_affine": dst,
+                    "dtype": torch.float32,
+                    "align_corners": params["align"],
+                    "mode": params["interp_mode"],
+                    "padding_mode": "zeros",
+                },
+                expct,
+            ]
+            for params in dict_product(device=TEST_DEVICES, align=[False, True], interp_mode=["nearest", "bilinear"])
+        ]
+    )

 TEST_4_5_D = []
 for device in TEST_DEVICES:
diff --git a/tests/transforms/utility/test_splitdimd.py b/tests/transforms/utility/test_splitdimd.py
index 6e221d3d52..9f0858f7e8 100644
--- a/tests/transforms/utility/test_splitdimd.py
+++ b/tests/transforms/utility/test_splitdimd.py
@@ -21,14 +21,14 @@
 from monai.data.meta_tensor import MetaTensor
 from monai.transforms import LoadImaged
 from monai.transforms.utility.dictionary import SplitDimd
-from tests.test_utils import TEST_NDARRAYS, assert_allclose, make_nifti_image, make_rand_affine
+from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product, make_nifti_image, make_rand_affine

-TESTS = []
-for p in TEST_NDARRAYS:
-    for keepdim in (True, False):
-        for update_meta in (True, False):
-            for list_output in (True, False):
-                TESTS.append((keepdim, p, update_meta, list_output))
+# parameterized.expand needs positional-argument tuples, not dicts, so unpack each
+# combination in the (keepdim, p, update_meta, list_output) order of the kwargs below
+TESTS = [
+    tuple(params.values())
+    for params in dict_product(keepdim=[True, False], p=TEST_NDARRAYS, update_meta=[True, False], list_output=[True, False])
+]


 class TestSplitDimd(unittest.TestCase):
diff --git a/tests/utils/test_pad_mode.py b/tests/utils/test_pad_mode.py
index a4a4012fc5..e13ec8aacd 100644
--- a/tests/utils/test_pad_mode.py
+++ b/tests/utils/test_pad_mode.py
@@ -18,21 +18,24 @@

 from monai.transforms import CastToType, Pad
 from monai.utils import NumpyPadMode, PytorchPadMode
-from tests.test_utils import SkipIfBeforePyTorchVersion
+from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product


 @SkipIfBeforePyTorchVersion((1, 10, 1))
 class TestPadMode(unittest.TestCase):
     def test_pad(self):
         expected_shapes = {3: (1, 15, 10), 4: (1, 10, 6, 7)}
-        for t in (float, int, np.uint8, np.int16, np.float32, bool):
-            for d in ("cuda:0", "cpu") if torch.cuda.is_available() else ("cpu",):
-                for s in ((1, 10, 10), (1, 5, 6, 7)):
-                    for m in list(PytorchPadMode) + list(NumpyPadMode):
-                        a = torch.rand(s)
-                        to_pad = [(0, 0), (2, 3)] if len(s) == 3 else [(0, 0), (2, 3), (0, 0), (0, 0)]
-                        out = Pad(to_pad=to_pad, mode=m)(CastToType(dtype=t)(a).to(d))
-                        self.assertEqual(out.shape, expected_shapes[len(s)])
+        devices = ("cuda:0", "cpu") if torch.cuda.is_available() else ("cpu",)
+        shapes = ((1, 10, 10), (1, 5, 6, 7))
+        types = (float, int, np.uint8, np.int16, np.float32, bool)
+        modes = list(PytorchPadMode) + list(NumpyPadMode)
+
+        for params in dict_product(t=types, d=devices, s=shapes, m=modes):
+            t, d, s, m = params["t"], params["d"], params["s"], params["m"]
+            a = torch.rand(s)
+            to_pad = [(0, 0), (2, 3)] if len(s) == 3 else [(0, 0), (2, 3), (0, 0), (0, 0)]
+            out = 
Pad(to_pad=to_pad, mode=m)(CastToType(dtype=t)(a).to(d)) + self.assertEqual(out.shape, expected_shapes[len(s)]) if __name__ == "__main__":
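
A minimal, self-contained sketch of the pattern applied throughout this patch (the helper below mirrors the `tests.test_utils.dict_product` defined above; the parameter names are illustrative only, not taken from any one test):

```python
from itertools import product


def dict_product(**items):
    """Mirror of the tests.test_utils.dict_product added in this patch."""
    keys = items.keys()
    return [dict(zip(keys, values)) for values in product(*items.values())]


# old style: nested for-loops, one append per combination
expected = []
for align in (False, True):
    for interp_mode in ("nearest", "bilinear"):
        expected.append({"align": align, "interp_mode": interp_mode})

# new style: a single comprehension over dict_product
actual = dict_product(align=[False, True], interp_mode=["nearest", "bilinear"])

# same cases in the same order: the last keyword varies fastest
assert actual == expected
```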