Skip to content

Commit ff807b9

Browse files
authored
Merge pull request #116 from phcerdan/add_multiply_methods_to_riesz_rotation_matrix
ENH: Add multiplyWith methods to RieszRotationMatrix
2 parents 06ae631 + 889c914 commit ff807b9

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

include/itkRieszRotationMatrix.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ class RieszRotationMatrix:
6464
template <typename TImage>
6565
std::vector< typename TImage::Pointer > MultiplyWithVectorOfImages(const std::vector< typename TImage::Pointer > & vect) const;
6666

67+
template <typename TInputValue>
68+
std::vector< TInputValue > MultiplyWithVector(const std::vector< TInputValue > & vect) const;
69+
70+
template <typename TInputValue>
71+
VariableSizeMatrix< TInputValue > MultiplyWithColumnMatrix(const VariableSizeMatrix< TInputValue > & vect) const;
72+
6773
/**
6874
* Multi-index notation
6975
* S[n = (n1,...,nd)][m = (m1,...,md)]

include/itkRieszRotationMatrix.hxx

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,50 @@ RieszRotationMatrix< T, VImageDimension >
6464
this->ComputeSteerableMatrix();
6565
}
6666

67+
template< typename T, unsigned int VImageDimension >
68+
template <typename TInputValue>
69+
std::vector<TInputValue>
70+
RieszRotationMatrix< T, VImageDimension >
71+
::MultiplyWithVector(const std::vector<TInputValue> & vect) const
72+
{
73+
unsigned int rows = this->Rows();
74+
unsigned int cols = this->Cols();
75+
auto resultVector = std::vector<TInputValue>(rows, NumericTraits<TInputValue>::ZeroValue());
76+
77+
for ( unsigned int r = 0; r < rows; r++ )
78+
{
79+
for ( unsigned int c = 0; c < cols; c++ )
80+
{
81+
resultVector[r] += this->GetVnlMatrix()(r, c) * vect[c];
82+
}
83+
}
84+
return resultVector;
85+
}
86+
87+
template< typename T, unsigned int VImageDimension >
88+
template <typename TInputValue>
89+
itk::VariableSizeMatrix<TInputValue>
90+
RieszRotationMatrix< T, VImageDimension >
91+
::MultiplyWithColumnMatrix(const itk::VariableSizeMatrix<TInputValue> & inputColumn) const
92+
{
93+
unsigned int rows = this->Rows();
94+
unsigned int cols = this->Cols();
95+
using ColumnMatrix = VariableSizeMatrix<TInputValue>;
96+
ColumnMatrix columnMatrix(rows, 1);
97+
columnMatrix.Fill(NumericTraits<TInputValue>::ZeroValue());
98+
99+
for ( unsigned int r = 0; r < rows; r++ )
100+
{
101+
TInputValue sum = NumericTraits<TInputValue>::ZeroValue();
102+
for ( unsigned int c = 0; c < cols; c++ )
103+
{
104+
sum += this->GetVnlMatrix()(r, c) * inputColumn.GetVnlMatrix()(c, 0);
105+
}
106+
columnMatrix.GetVnlMatrix()(r, 0) = sum;
107+
}
108+
return columnMatrix;
109+
}
110+
67111
template< typename T, unsigned int VImageDimension >
68112
template <typename TImage>
69113
std::vector< typename TImage::Pointer >

test/itkRieszRotationMatrixTest.cxx

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,28 @@ runRieszRotationMatrixInterfaceWithRieszFrequencyFilterBankGeneratorTest()
129129
std::cout << images[i]->GetLargestPossibleRegion() << std::endl;
130130
}
131131

132+
std::vector<PixelType> inputVector(M, 0);
133+
inputVector[0] = 1;
134+
inputVector[1] = 2;
135+
inputVector[2] = 3;
136+
auto vectorRotated = S.MultiplyWithVector(inputVector);
137+
138+
itk::VariableSizeMatrix<PixelType> inputColumnMatrix(M, 1);
139+
inputColumnMatrix.GetVnlMatrix()(0,0) = 1;
140+
inputColumnMatrix.GetVnlMatrix()(1,0) = 2;
141+
inputColumnMatrix.GetVnlMatrix()(2,0) = 3;
142+
auto columnMatrixRotated = S.MultiplyWithColumnMatrix(inputColumnMatrix);
143+
bool multiplyWithSomethingPassed = true;
144+
for (unsigned int i = 0; i < M; ++i)
145+
{
146+
if(!itk::Math::FloatAlmostEqual(inputVector[i], columnMatrixRotated.GetVnlMatrix()(i, 0)))
147+
{
148+
std::cout << "Not Equal!: ";
149+
std::cout << inputVector[i] << " != " << columnMatrixRotated.GetVnlMatrix()(i,0) << std::endl;
150+
multiplyWithSomethingPassed = false;
151+
}
152+
}
153+
132154
auto imagesMultipliedByRieszRotationMatrix =
133155
S.MultiplyWithVectorOfImages<ImageType>(images);
134156
std::cout << "Size: ";
@@ -155,6 +177,10 @@ runRieszRotationMatrixInterfaceWithRieszFrequencyFilterBankGeneratorTest()
155177
int thirdComponentStatus = compareImagesAndReport<ImageType>(
156178
images[0], imagesMultipliedByRieszRotationMatrix[2]);
157179

180+
if(!multiplyWithSomethingPassed)
181+
{
182+
return EXIT_FAILURE;
183+
}
158184
return firstComponentStatus && secondComponentStatus && thirdComponentStatus;
159185

160186
}

0 commit comments

Comments
 (0)