Skip to content

Commit e39bad9

Browse files
advcu987ericspod
andauthored
8267 fix normalize intensity (#8286)
Fixes #8267 . ### Description Fix channel-wise intensity normalization for integer type inputs. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: advcu987 <adrianvoicu.tm@gmail.com> Signed-off-by: advcu <65158236+advcu987@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 56d1f62 commit e39bad9

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

monai/transforms/intensity/array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ class NormalizeIntensity(Transform):
821821
mean and std on each channel separately.
822822
When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
823823
be the number of image channels if they are not None.
824+
If the input is not of floating point type, it will be converted to float32
824825
825826
Args:
826827
subtrahend: the amount to subtract by (usually the mean).
@@ -907,6 +908,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
907908
if self.divisor is not None and len(self.divisor) != len(img):
908909
raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")
909910

911+
if not img.dtype.is_floating_point:
912+
img, *_ = convert_data_type(img, dtype=torch.float32)
913+
910914
for i, d in enumerate(img):
911915
img[i] = self._normalize( # type: ignore
912916
d,

tests/test_normalize_intensity.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,27 @@ def test_channel_wise(self, im_type):
108108
normalized = normalizer(input_data)
109109
assert_allclose(normalized, im_type(expected), type_test="tensor")
110110

111+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
112+
def test_channel_wise_int(self, im_type):
113+
normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)
114+
input_data = im_type(torch.arange(1, 25).reshape(2, 3, 4))
115+
expected = np.array(
116+
[
117+
[
118+
[-1.593255, -1.3035723, -1.0138896, -0.7242068],
119+
[-0.4345241, -0.1448414, 0.1448414, 0.4345241],
120+
[0.7242068, 1.0138896, 1.3035723, 1.593255],
121+
],
122+
[
123+
[-1.593255, -1.3035723, -1.0138896, -0.7242068],
124+
[-0.4345241, -0.1448414, 0.1448414, 0.4345241],
125+
[0.7242068, 1.0138896, 1.3035723, 1.593255],
126+
],
127+
]
128+
)
129+
normalized = normalizer(input_data)
130+
assert_allclose(normalized, im_type(expected), type_test="tensor", rtol=1e-7, atol=1e-7) # tolerance
131+
111132
@parameterized.expand([[p] for p in TEST_NDARRAYS])
112133
def test_value_errors(self, im_type):
113134
input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))

0 commit comments

Comments
 (0)