Skip to content

Commit 926baee

Browse files
authored
Merge pull request #115 from phcerdan/set_riesz_rotation_matrix_to_complex
ENH: Set RieszRotationMatrix ValueType to std::complex
2 parents 2549b15 + 31c4d79 commit 926baee

File tree

7 files changed

+130
-61
lines changed

7 files changed

+130
-61
lines changed

include/itkRieszRotationMatrix.h

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
*=========================================================================*/
1818
#ifndef itkRieszRotationMatrix_h
1919
#define itkRieszRotationMatrix_h
20-
#include "itkVariableSizeMatrix.h"
2120
#include <vector>
21+
#include "itkImage.h"
2222
#include "itkMatrix.h"
2323
#include "itkRieszUtilities.h"
24+
#include "itkVariableSizeMatrix.h"
25+
#include "itkVectorContainer.h"
2426

2527
namespace itk
2628
{
@@ -35,26 +37,28 @@ namespace itk
3537
*
3638
* \f[ M := p(N,d) = \frac{(N+d-1)!}{(d-1)! N!} \f]
3739
*
38-
* The rotation matrix is a dxd matrix.
40+
* The rotation matrix is a dxd matrix of real type.
3941
*
4042
* \sa RieszFrequencyFunction
4143
* \sa RieszFrequencyFilterBankGenerator
4244
*
4345
* \ingroup IsotropicWavelets
4446
*/
4547

46-
template <typename T = double, unsigned int VImageDimension = 3>
47-
class RieszRotationMatrix : public itk::VariableSizeMatrix<T>
48+
template <unsigned int VImageDimension>
49+
class RieszRotationMatrix : public itk::VariableSizeMatrix<std::complex<double>>
4850
{
4951
public:
5052
/** Standard type alias */
5153
using Self = RieszRotationMatrix;
52-
using Superclass = itk::VariableSizeMatrix<T>;
54+
using ValueType = std::complex<double>;
55+
using Superclass = itk::VariableSizeMatrix<ValueType>;
5356

5457
/** Component value type */
55-
using ValueType = typename Superclass::ValueType;
58+
using RealType = typename ValueType::value_type;
5659
using InternalMatrixType = typename Superclass::InternalMatrixType;
57-
using SpatialRotationMatrixType = itk::Matrix<T, VImageDimension, VImageDimension>;
60+
using SpatialRotationMatrixType = itk::Matrix<RealType, VImageDimension, VImageDimension>;
61+
using ComplexImageType = itk::Image<ValueType, VImageDimension>;
5862

5963
/** Matrix by std::vector<TImage> multiplication.
6064
* To perform the rotation with the output of
@@ -194,7 +198,7 @@ class RieszRotationMatrix : public itk::VariableSizeMatrix<T>
194198
return this->m_MaxAbsoluteDifferenceCloseToZero;
195199
}
196200
inline void
197-
SetMaxAbsoluteDifferenceCloseToZero(const ValueType & maxAbsoluteDifference)
201+
SetMaxAbsoluteDifferenceCloseToZero(const RealType & maxAbsoluteDifference)
198202
{
199203
this->m_MaxAbsoluteDifferenceCloseToZero = maxAbsoluteDifference;
200204
}
@@ -228,10 +232,11 @@ class RieszRotationMatrix : public itk::VariableSizeMatrix<T>
228232
#endif
229233

230234
private:
235+
using ResultValueType = std::complex<long double>;
231236
SpatialRotationMatrixType m_SpatialRotationMatrix;
232237
unsigned int m_Order{ 0 };
233238
unsigned int m_Components{ 0 };
234-
ValueType m_MaxAbsoluteDifferenceCloseToZero;
239+
RealType m_MaxAbsoluteDifferenceCloseToZero;
235240
bool m_Debug{ false };
236241

237242
}; // end of class

include/itkRieszRotationMatrix.hxx

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,42 +25,42 @@
2525

2626
namespace itk
2727
{
28-
template <typename T, unsigned int VImageDimension>
29-
RieszRotationMatrix<T, VImageDimension>::RieszRotationMatrix()
28+
template <unsigned int VImageDimension>
29+
RieszRotationMatrix<VImageDimension>::RieszRotationMatrix()
3030
: Superclass()
3131
, m_SpatialRotationMatrix()
3232
,
3333

34-
m_MaxAbsoluteDifferenceCloseToZero(1 * itk::NumericTraits<ValueType>::epsilon())
34+
m_MaxAbsoluteDifferenceCloseToZero(1 * itk::NumericTraits<RealType>::epsilon())
3535

3636
{}
3737

38-
template <typename T, unsigned int VImageDimension>
39-
RieszRotationMatrix<T, VImageDimension>::RieszRotationMatrix(const Self & rieszMatrix)
38+
template <unsigned int VImageDimension>
39+
RieszRotationMatrix<VImageDimension>::RieszRotationMatrix(const Self & rieszMatrix)
4040
: Superclass(rieszMatrix)
4141
, m_SpatialRotationMatrix(rieszMatrix.GetSpatialRotationMatrix())
4242
, m_Order(rieszMatrix.GetOrder())
4343
, m_Components(rieszMatrix.GetComponents())
44-
, m_MaxAbsoluteDifferenceCloseToZero(1 * itk::NumericTraits<ValueType>::epsilon())
44+
, m_MaxAbsoluteDifferenceCloseToZero(1 * itk::NumericTraits<RealType>::epsilon())
4545

4646
{}
4747

48-
template <typename T, unsigned int VImageDimension>
49-
RieszRotationMatrix<T, VImageDimension>::RieszRotationMatrix(const SpatialRotationMatrixType & spatialRotationMatrix,
50-
const unsigned int & order)
48+
template <unsigned int VImageDimension>
49+
RieszRotationMatrix<VImageDimension>::RieszRotationMatrix(const SpatialRotationMatrixType & spatialRotationMatrix,
50+
const unsigned int & order)
5151
: Superclass()
5252
, m_SpatialRotationMatrix(spatialRotationMatrix)
53-
, m_MaxAbsoluteDifferenceCloseToZero(1 * itk::NumericTraits<ValueType>::epsilon())
53+
, m_MaxAbsoluteDifferenceCloseToZero(1 * itk::NumericTraits<RealType>::epsilon())
5454

5555
{
5656
this->SetOrder(order);
5757
this->ComputeSteerableMatrix();
5858
}
5959

60-
template <typename T, unsigned int VImageDimension>
60+
template <unsigned int VImageDimension>
6161
template <typename TInputValue>
6262
std::vector<TInputValue>
63-
RieszRotationMatrix<T, VImageDimension>::MultiplyWithVector(const std::vector<TInputValue> & vect) const
63+
RieszRotationMatrix<VImageDimension>::MultiplyWithVector(const std::vector<TInputValue> & vect) const
6464
{
6565
unsigned int rows = this->Rows();
6666
unsigned int cols = this->Cols();
@@ -76,10 +76,10 @@ RieszRotationMatrix<T, VImageDimension>::MultiplyWithVector(const std::vector<TI
7676
return resultVector;
7777
}
7878

79-
template <typename T, unsigned int VImageDimension>
79+
template <unsigned int VImageDimension>
8080
template <typename TInputValue>
8181
itk::VariableSizeMatrix<TInputValue>
82-
RieszRotationMatrix<T, VImageDimension>::MultiplyWithColumnMatrix(
82+
RieszRotationMatrix<VImageDimension>::MultiplyWithColumnMatrix(
8383
const itk::VariableSizeMatrix<TInputValue> & inputColumn) const
8484
{
8585
unsigned int rows = this->Rows();
@@ -100,10 +100,10 @@ RieszRotationMatrix<T, VImageDimension>::MultiplyWithColumnMatrix(
100100
return columnMatrix;
101101
}
102102

103-
template <typename T, unsigned int VImageDimension>
103+
template <unsigned int VImageDimension>
104104
template <typename TImage>
105105
std::vector<typename TImage::Pointer>
106-
RieszRotationMatrix<T, VImageDimension>::MultiplyWithVectorOfImages(
106+
RieszRotationMatrix<VImageDimension>::MultiplyWithVectorOfImages(
107107
const std::vector<typename TImage::Pointer> & vect) const
108108
{
109109
unsigned int rows = this->Rows();
@@ -148,9 +148,9 @@ RieszRotationMatrix<T, VImageDimension>::MultiplyWithVectorOfImages(
148148
return result;
149149
}
150150

151-
template <typename T, unsigned int VImageDimension>
152-
typename RieszRotationMatrix<T, VImageDimension>::IndicesMatrix
153-
RieszRotationMatrix<T, VImageDimension>::GenerateIndicesMatrix()
151+
template <unsigned int VImageDimension>
152+
typename RieszRotationMatrix<VImageDimension>::IndicesMatrix
153+
RieszRotationMatrix<VImageDimension>::GenerateIndicesMatrix()
154154
{
155155
using LocalIndicesArrayType = std::vector<unsigned int>;
156156
using LocalIndicesVector = std::vector<LocalIndicesArrayType>;
@@ -185,9 +185,9 @@ RieszRotationMatrix<T, VImageDimension>::GenerateIndicesMatrix()
185185
return allIndicesPairs;
186186
}
187187

188-
template <typename T, unsigned int VImageDimension>
189-
const typename RieszRotationMatrix<T, VImageDimension>::InternalMatrixType &
190-
RieszRotationMatrix<T, VImageDimension>::ComputeSteerableMatrix()
188+
template <unsigned int VImageDimension>
189+
const typename RieszRotationMatrix<VImageDimension>::InternalMatrixType &
190+
RieszRotationMatrix<VImageDimension>::ComputeSteerableMatrix()
191191
{
192192
// precondition
193193
if (this->m_Order == 0)
@@ -198,7 +198,13 @@ RieszRotationMatrix<T, VImageDimension>::ComputeSteerableMatrix()
198198
InternalMatrixType & S = this->GetVnlMatrix();
199199
if (this->m_Order == 1)
200200
{
201-
S = this->GetSpatialRotationMatrix().GetVnlMatrix().as_matrix();
201+
for (unsigned int i = 0; i < this->m_Components; ++i)
202+
{
203+
for (unsigned int j = 0; j < this->m_Components; ++j)
204+
{
205+
S.put(i, j, static_cast<ValueType>(m_SpatialRotationMatrix.GetVnlMatrix().get(i, j)));
206+
}
207+
}
202208
return this->GetVnlMatrix();
203209
}
204210

@@ -291,19 +297,19 @@ RieszRotationMatrix<T, VImageDimension>::ComputeSteerableMatrix()
291297
// matrix R = [r1,...,rd], where r_i are columns of the rotation matrix.
292298
// we sum and normalize them.
293299
S[i][j] = 0;
294-
long double result = 0;
295-
long nFactorial = 1;
296-
long mFactorial = 1;
300+
ResultValueType result = 0;
301+
long nFactorial = 1;
302+
long mFactorial = 1;
297303
for (unsigned int dim = 0; dim < VImageDimension; ++dim)
298304
{
299305
nFactorial *= itk::utils::Factorial(n[dim]);
300306
mFactorial *= itk::utils::Factorial(m[dim]);
301307
}
302-
auto nFactorialReal = static_cast<double>(nFactorial);
308+
auto nFactorialReal = static_cast<RealType>(nFactorial);
303309
for (auto & kValidIndex : kValidIndices)
304310
{
305-
long double rotationFactor = 1;
306-
long kFactorialMultiplication = 1;
311+
ValueType rotationFactor = 1;
312+
long kFactorialMultiplication = 1;
307313
// There are always VImageDimension indices. (k1,k2,...,kd)
308314
for (unsigned int kIndex = 0; kIndex < VImageDimension; ++kIndex)
309315
{
@@ -335,12 +341,19 @@ RieszRotationMatrix<T, VImageDimension>::ComputeSteerableMatrix()
335341
result *= sqrt(mFactorial / nFactorialReal);
336342
S[i][j] = static_cast<ValueType>(result);
337343
// Try to fix close to zero float errors
338-
if (itk::Math::FloatAlmostEqual(S[i][j],
339-
static_cast<ValueType>(0),
344+
if (itk::Math::FloatAlmostEqual(S[i][j].real(),
345+
static_cast<typename ValueType::value_type>(0),
346+
4, // default maxULPS from Math::AlmostFloatEqual
347+
this->m_MaxAbsoluteDifferenceCloseToZero))
348+
{
349+
S[i][j].real(0);
350+
}
351+
if (itk::Math::FloatAlmostEqual(S[i][j].imag(),
352+
static_cast<typename ValueType::value_type>(0),
340353
4, // default maxULPS from Math::AlmostFloatEqual
341354
this->m_MaxAbsoluteDifferenceCloseToZero))
342355
{
343-
S[i][j] = 0;
356+
S[i][j].imag(0);
344357
}
345358

346359
if (this->GetDebug())

test/itkRieszRotationMatrixTest.cxx

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,27 @@
1616
*
1717
*=========================================================================*/
1818

19-
#include "itkRieszRotationMatrix.h"
2019
#include <complex>
20+
#include "itkComplexToRealImageFilter.h"
2121
#include "itkImage.h"
2222
#include "itkImageDuplicator.h"
2323
#include "itkMath.h"
24+
#include "itkRieszRotationMatrix.h"
2425
#include "itkTestingComparisonImageFilter.h"
2526

2627
template <typename TImage>
2728
int
28-
compareImagesAndReport(typename TImage::Pointer validImage, typename TImage::Pointer testImage)
29+
compareRealImagesAndReport(typename TImage::Pointer validImage, typename TImage::Pointer testImage)
2930
{
3031
using ImageType = TImage;
3132
using ComparisonType = itk::Testing::ComparisonImageFilter<ImageType, ImageType>;
3233
auto diff = ComparisonType::New();
3334
diff->SetValidInput(validImage);
3435
diff->SetTestInput(testImage);
3536
diff->UpdateLargestPossibleRegion();
36-
bool differenceFailed = false;
37-
const double averageIntensityDifference = diff->GetTotalDifference();
38-
const unsigned long numberOfPixelsWithDifferences = diff->GetNumberOfPixelsWithDifferences();
37+
bool differenceFailed = false;
38+
const typename TImage::PixelType averageIntensityDifference = diff->GetTotalDifference();
39+
const unsigned long numberOfPixelsWithDifferences = diff->GetNumberOfPixelsWithDifferences();
3940
if (averageIntensityDifference > 0.0)
4041
{
4142
if (static_cast<int>(numberOfPixelsWithDifferences) > 0)
@@ -67,8 +68,8 @@ runRieszRotationMatrixInterfaceWithRieszFrequencyFilterBankGeneratorTest()
6768
{
6869
constexpr unsigned int Dimension = 2;
6970
// Create a rotation matrix
70-
using ValueType = double;
71-
using SteerableMatrix = itk::RieszRotationMatrix<ValueType, Dimension>;
71+
using ValueType = std::complex<double>;
72+
using SteerableMatrix = itk::RieszRotationMatrix<Dimension>;
7273
using SpatialRotationMatrix = SteerableMatrix::SpatialRotationMatrixType;
7374
SpatialRotationMatrix R;
7475
double angle = itk::Math::pi_over_2;
@@ -94,9 +95,12 @@ runRieszRotationMatrixInterfaceWithRieszFrequencyFilterBankGeneratorTest()
9495
// emulating the output of RieszFrequencyFilterBankGenerator
9596
// of Order = 2, Dimension = 2
9697

97-
using PixelType = double;
98+
using PixelType = std::complex<double>;
99+
using RealPixelType = double;
98100
using ImageType = itk::Image<PixelType, Dimension>;
99101
using ImagePointer = ImageType::Pointer;
102+
using RealImageType = itk::Image<RealPixelType, Dimension>;
103+
using RealImagePointer = RealImageType::Pointer;
100104
ImagePointer image = ImageType::New();
101105
ImageType::IndexType start;
102106
start.Fill(0);
@@ -137,7 +141,7 @@ runRieszRotationMatrixInterfaceWithRieszFrequencyFilterBankGeneratorTest()
137141
expectedMultiplyResult[2] = 1;
138142
for (unsigned int i = 0; i < M; ++i)
139143
{
140-
if (!itk::Math::FloatAlmostEqual(vectorRotated[i], expectedMultiplyResult[i]))
144+
if (!itk::Math::FloatAlmostEqual(vectorRotated[i].real(), expectedMultiplyResult[i].real()))
141145
{
142146
std::cout << "vectorRotated not equal!: ";
143147
std::cout << vectorRotated[i] << " != " << expectedMultiplyResult[i] << std::endl;
@@ -152,7 +156,7 @@ runRieszRotationMatrixInterfaceWithRieszFrequencyFilterBankGeneratorTest()
152156
auto columnMatrixRotated = S.MultiplyWithColumnMatrix(inputColumnMatrix);
153157
for (unsigned int i = 0; i < M; ++i)
154158
{
155-
if (!itk::Math::FloatAlmostEqual(columnMatrixRotated.GetVnlMatrix()(i, 0), expectedMultiplyResult[i]))
159+
if (!itk::Math::FloatAlmostEqual(columnMatrixRotated.GetVnlMatrix()(i, 0).real(), expectedMultiplyResult[i].real()))
156160
{
157161
std::cout << "columnMatrixRotated not Equal!: ";
158162
std::cout << columnMatrixRotated.GetVnlMatrix()(i, 0) << " != " << expectedMultiplyResult[i] << std::endl;
@@ -164,22 +168,50 @@ runRieszRotationMatrixInterfaceWithRieszFrequencyFilterBankGeneratorTest()
164168
std::cout << "Size: ";
165169
std::cout << imagesMultipliedByRieszRotationMatrix.size() << std::endl;
166170

171+
// Convert input images to real to perform comparison.
172+
std::vector<RealImagePointer> realImages(M);
173+
{
174+
for (unsigned int i = 0; i < M; ++i)
175+
{
176+
using ComplexToRealFilter = itk::ComplexToRealImageFilter<ImageType, RealImageType>;
177+
auto complexToRealFilter = ComplexToRealFilter::New();
178+
complexToRealFilter->SetInput(images[i]);
179+
complexToRealFilter->Update();
180+
realImages[i] = complexToRealFilter->GetOutput();
181+
}
182+
}
183+
std::vector<RealImagePointer> realImagesMultipliedByRieszRotationMatrix(M);
184+
{
185+
for (unsigned int i = 0; i < M; ++i)
186+
{
187+
using ComplexToRealFilter = itk::ComplexToRealImageFilter<ImageType, RealImageType>;
188+
auto complexToRealFilter = ComplexToRealFilter::New();
189+
complexToRealFilter->SetInput(imagesMultipliedByRieszRotationMatrix[i]);
190+
complexToRealFilter->Update();
191+
realImagesMultipliedByRieszRotationMatrix[i] = complexToRealFilter->GetOutput();
192+
}
193+
}
194+
167195
// First
168196
// g(0, 0) = 0; g(0, 1) = 0; g(0, 2) = 1;
169197
// result = 1.0 * images[2]
170-
int firstComponentStatus = compareImagesAndReport<ImageType>(images[2], imagesMultipliedByRieszRotationMatrix[0]);
198+
int firstComponentStatus =
199+
compareRealImagesAndReport<RealImageType>(realImages[2], realImagesMultipliedByRieszRotationMatrix[0]);
171200
// Second
172201
// g(1, 0) = 0; g(1, 1) = -1; g(1, 2) = 0;
173202
// expectedresult = -1.0 * images[1]
174-
typename MultiplyImageFilterType::Pointer multiplyImageFilterInvert = MultiplyImageFilterType::New();
175-
multiplyImageFilterInvert->SetInput(imagesMultipliedByRieszRotationMatrix[1]);
203+
using MultiplyRealImageFilterType = itk::MultiplyImageFilter<RealImageType>;
204+
typename MultiplyRealImageFilterType::Pointer multiplyImageFilterInvert = MultiplyRealImageFilterType::New();
205+
multiplyImageFilterInvert->SetInput(realImagesMultipliedByRieszRotationMatrix[1]);
176206
multiplyImageFilterInvert->SetConstant(-1.0);
177207
multiplyImageFilterInvert->Update();
178-
int secondComponentStatus = compareImagesAndReport<ImageType>(images[1], multiplyImageFilterInvert->GetOutput());
208+
int secondComponentStatus =
209+
compareRealImagesAndReport<RealImageType>(realImages[1], multiplyImageFilterInvert->GetOutput());
179210
// Third
180211
// g(2, 0) = 1; g(2, 1) = 0; g(2, 2) = 0;
181212
// result = 1.0 * images[0]
182-
int thirdComponentStatus = compareImagesAndReport<ImageType>(images[0], imagesMultipliedByRieszRotationMatrix[2]);
213+
int thirdComponentStatus =
214+
compareRealImagesAndReport<RealImageType>(realImages[0], realImagesMultipliedByRieszRotationMatrix[2]);
183215

184216
if (!multiplyWithSomethingPassed)
185217
{
@@ -195,8 +227,8 @@ runRieszRotationMatrixTest()
195227
bool testPassed = true;
196228
const unsigned int Dimension = VDimension;
197229

198-
using ValueType = double;
199-
using SteerableMatrix = itk::RieszRotationMatrix<ValueType, Dimension>;
230+
using ValueType = std::complex<double>;
231+
using SteerableMatrix = itk::RieszRotationMatrix<Dimension>;
200232
using SpatialRotationMatrix = typename SteerableMatrix::SpatialRotationMatrixType;
201233

202234
// Define a spatial rotation matrix.

wrapping/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
itk_wrap_module(IsotropicWavelets)
22
set(WRAPPER_SUBMODULE_ORDER
3+
itkMatrixComplex
4+
itkVariableSizeMatrixComplex
35
itkStructureTensor
46
itkMonogenicSignalFrequencyImageFilter
57
itkVectorInverseFFTImageFilter

0 commit comments

Comments
 (0)