Skip to content

Commit ec0c7f2

Browse files
committed
adding max iterations threshold
1 parent 44f5a4a commit ec0c7f2

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

src/Algorithm.php

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,32 @@ public function registerIterationCallback(callable $callback): void
3030
$this->iterationCallbacks[] = $callback;
3131
}
3232

33-
public function clusterize(PointCollectionInterface $points, int $nbClusters): ClusterCollectionInterface
34-
{
33+
public function clusterize(
34+
PointCollectionInterface $points,
35+
int $nClusters,
36+
?int $maxIter = null
37+
): ClusterCollectionInterface {
38+
$maxIter ??= INF;
39+
40+
if ($maxIter < 1) {
41+
throw new \UnexpectedValueException(
42+
"Invalid maximum number of iterations: {$maxIter}"
43+
);
44+
}
45+
3546
// initialize clusters
36-
$clusters = $this->initScheme->initializeClusters($points, $nbClusters);
47+
$clusters = $this->initScheme->initializeClusters($points, $nClusters);
3748

3849
// iterate until convergence is reached
3950
do {
4051
$this->invokeIterationCallbacks($clusters);
41-
} while ($this->iterate($clusters));
52+
} while ($this->iterate($clusters) && --$maxIter);
4253

4354
// clustering is done.
4455
return $clusters;
4556
}
4657

47-
private function iterate(ClusterCollectionInterface $clusters): bool
58+
protected function iterate(ClusterCollectionInterface $clusters): bool
4859
{
4960
/** @var \SplObjectStorage<ClusterInterface, null> */
5061
$changed = new \SplObjectStorage();
@@ -78,13 +89,13 @@ private function iterate(ClusterCollectionInterface $clusters): bool
7889

7990
private function getClosestCluster(ClusterCollectionInterface $clusters, PointInterface $point): ClusterInterface
8091
{
81-
$min = null;
92+
$min = INF;
8293
$closest = null;
8394

8495
foreach ($clusters as $cluster) {
8596
$distance = $this->getDistanceBetween($point, $cluster->getCentroid());
8697

87-
if (is_null($min) || $distance < $min) {
98+
if ($distance < $min) {
8899
$min = $distance;
89100
$closest = $cluster;
90101
}

tests/Unit/Euclidean/AlgorithmTest.php

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
use Kmeans\Euclidean\Point;
77
use Kmeans\Euclidean\Space;
88
use Kmeans\Interfaces\AlgorithmInterface;
9+
use Kmeans\Interfaces\ClusterCollectionInterface;
910
use Kmeans\Interfaces\InitializationSchemeInterface;
1011
use Kmeans\Interfaces\PointCollectionInterface;
1112
use Kmeans\Interfaces\SpaceInterface;
1213
use Kmeans\Math;
1314
use Kmeans\PointCollection;
15+
use Kmeans\RandomInitialization;
1416
use Tests\Unit\AlgorithmTest as BaseAlgorithmTest;
1517

1618
/**
@@ -24,6 +26,7 @@
2426
* @uses \Kmeans\Euclidean\Space
2527
* @uses \Kmeans\Math
2628
* @uses \Kmeans\PointCollection
29+
* @uses \Kmeans\RandomInitialization
2730
* @phpstan-import-type ClusterizeScenarioData from BaseAlgorithmTest
2831
*/
2932
class AlgorithmTest extends BaseAlgorithmTest
@@ -146,4 +149,43 @@ public function testFindCentroidException(): void
146149
new PointCollection(new \Kmeans\Gps\Space(), [])
147150
);
148151
}
152+
153+
public function testMaxIterations(): void
154+
{
155+
$algorithm = new class (new RandomInitialization()) extends Algorithm
156+
{
157+
protected function iterate(ClusterCollectionInterface $clusters): bool
158+
{
159+
// do nothing and iterate indefinitely
160+
return true;
161+
}
162+
};
163+
164+
$iterations = 0;
165+
$algorithm->registerIterationCallback(function () use (&$iterations) {
166+
$iterations++;
167+
});
168+
169+
$space = new Space(1);
170+
$points = new PointCollection(
171+
$space,
172+
array_map([$space, 'makePoint'], [[1],[2],[3]])
173+
);
174+
175+
$algorithm->clusterize($points, 3, 300);
176+
177+
$this->assertEquals(
178+
300,
179+
$iterations
180+
);
181+
}
182+
183+
public function testMaxIterationsException(): void
184+
{
185+
$this->expectException(\UnexpectedValueException::class);
186+
$this->expectExceptionMessageMatches('/^Invalid maximum number of iterations/');
187+
188+
$algorithm = new Algorithm(new RandomInitialization());
189+
$algorithm->clusterize(new PointCollection(new Space(1), []), 3, 0);
190+
}
149191
}

0 commit comments

Comments
 (0)