Skip to content

Commit cedd259

Browse files
committed
Add batchesPerIteration
1 parent f59afea commit cedd259

File tree

5 files changed

+71
-34
lines changed

5 files changed

+71
-34
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,12 @@
7070
public class GraphSageModelTrainer {
7171
private final long randomSeed;
7272
private final boolean useWeights;
73-
private final double learningRate;
74-
private final double tolerance;
75-
private final int negativeSampleWeight;
76-
private final int concurrency;
77-
private final int epochs;
78-
private final int maxIterations;
79-
private final int maxSearchDepth;
8073
private final Function<Graph, List<LayerConfig>> layerConfigsFunction;
8174
private final FeatureFunction featureFunction;
8275
private final Collection<Weights<Matrix>> labelProjectionWeights;
8376
private final ExecutorService executor;
8477
private final ProgressTracker progressTracker;
85-
private final int batchSize;
78+
private final GraphSageTrainConfig config;
8679

8780
public GraphSageModelTrainer(GraphSageTrainConfig config, ExecutorService executor, ProgressTracker progressTracker) {
8881
this(config, executor, progressTracker, new SingleLabelFeatureFunction(), Collections.emptyList());
@@ -96,14 +89,7 @@ public GraphSageModelTrainer(
9689
Collection<Weights<Matrix>> labelProjectionWeights
9790
) {
9891
this.layerConfigsFunction = graph -> config.layerConfigs(firstLayerColumns(config, graph));
99-
this.batchSize = config.batchSize();
100-
this.learningRate = config.learningRate();
101-
this.tolerance = config.tolerance();
102-
this.negativeSampleWeight = config.negativeSampleWeight();
103-
this.concurrency = config.concurrency();
104-
this.epochs = config.epochs();
105-
this.maxIterations = config.maxIterations();
106-
this.maxSearchDepth = config.searchDepth();
92+
this.config = config;
10793
this.featureFunction = featureFunction;
10894
this.labelProjectionWeights = labelProjectionWeights;
10995
this.executor = executor;
@@ -141,7 +127,7 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
141127

142128
var batchTasks = PartitionUtils.rangePartitionWithBatchSize(
143129
graph.nodeCount(),
144-
batchSize,
130+
config.batchSize(),
145131
batch -> createBatchTask(graph, features, layers, weights, batch)
146132
);
147133

@@ -155,6 +141,7 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
155141

156142
progressTracker.beginSubTask("Train model");
157143

144+
int epochs = config.epochs();
158145
for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
159146
progressTracker.beginSubTask("Epoch");
160147
var epochResult = trainEpoch(() -> batchTasks.get(random.nextInt(batchTasks.size())), weights, prevEpochLoss);
@@ -195,36 +182,36 @@ private BatchTask createBatchTask(
195182
useWeights ? localGraph::relationshipProperty : UNWEIGHTED,
196183
embeddingVariable,
197184
totalBatch,
198-
negativeSampleWeight
185+
config.negativeSampleWeight()
199186
);
200187

201-
return new BatchTask(lossFunction, weights, tolerance, progressTracker);
188+
return new BatchTask(lossFunction, weights, progressTracker);
202189
}
203190

204191
private EpochResult trainEpoch(Supplier<BatchTask> batchTaskSupplier, List<Weights<? extends Tensor<?>>> weights, double prevEpochLoss) {
205-
var updater = new AdamOptimizer(weights, learningRate);
192+
var updater = new AdamOptimizer(weights, config.learningRate());
206193

207194
int iteration = 1;
208195
var iterationLosses = new ArrayList<Double>();
209196
double prevLoss = prevEpochLoss;
210197
var converged = false;
211198

212-
for (;iteration <= maxIterations; iteration++) {
199+
int maxIterations = config.maxIterations();
200+
for (; iteration <= maxIterations; iteration++) {
213201
progressTracker.beginSubTask("Iteration");
214202

215-
// TODO let the user configure the number of batches per iteration
216203
var batchTasks = IntStream
217-
.range(0, concurrency)
204+
.range(0, config.batchesPerIteration())
218205
.mapToObj(__ -> batchTaskSupplier.get())
219206
.collect(Collectors.toList());
220207

221208
// run forward + maybe backward for each Batch
222-
ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
209+
ParallelUtil.runWithConcurrency(config.concurrency(), batchTasks, executor);
223210
var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
224211
iterationLosses.add(avgLoss);
225212
progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));
226213

227-
if (Math.abs(prevLoss - avgLoss) < tolerance) {
214+
if (Math.abs(prevLoss - avgLoss) < config.tolerance()) {
228215
converged = true;
229216
progressTracker.endSubTask("Iteration");
230217
break;
@@ -258,19 +245,16 @@ static class BatchTask implements Runnable {
258245
private final Variable<Scalar> lossFunction;
259246
private final List<Weights<? extends Tensor<?>>> weightVariables;
260247
private List<? extends Tensor<?>> weightGradients;
261-
private final double tolerance;
262248
private final ProgressTracker progressTracker;
263249
private double prevLoss;
264250

265251
BatchTask(
266252
Variable<Scalar> lossFunction,
267253
List<Weights<? extends Tensor<?>>> weightVariables,
268-
double tolerance,
269254
ProgressTracker progressTracker
270255
) {
271256
this.lossFunction = lossFunction;
272257
this.weightVariables = weightVariables;
273-
this.tolerance = tolerance;
274258
this.progressTracker = progressTracker;
275259
}
276260

@@ -321,7 +305,7 @@ LongStream neighborBatch(Graph graph, Partition batch, long batchLocalSeed) {
321305
// sample a neighbor for each batchNode
322306
batch.consume(nodeId -> {
323307
// randomWalk with at most maxSearchDepth steps and only save last node
324-
int searchDepth = localRandom.nextInt(maxSearchDepth) + 1;
308+
int searchDepth = localRandom.nextInt(config.searchDepth()) + 1;
325309
AtomicLong currentNode = new AtomicLong(nodeId);
326310
while (searchDepth > 0) {
327311
NeighborhoodSampler neighborhoodSampler = new NeighborhoodSampler(currentNode.get() + searchDepth);

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainConfig.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ default int maxIterations() {
120120
return 10;
121121
}
122122

123+
@Configuration.Key("batchesPerIteration")
124+
Optional<Integer> maybeBatchesPerIteration();
125+
126+
@Configuration.Ignore
127+
@Value.Derived
128+
default int batchesPerIteration() {
129+
return maybeBatchesPerIteration().orElse(concurrency());
130+
}
131+
123132
@Value.Default
124133
default int searchDepth() {
125134
return 5;

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.junit.jupiter.api.BeforeEach;
2727
import org.junit.jupiter.api.Test;
2828
import org.junit.jupiter.params.ParameterizedTest;
29+
import org.junit.jupiter.params.provider.CsvSource;
2930
import org.junit.jupiter.params.provider.ValueSource;
3031
import org.neo4j.gds.Orientation;
3132
import org.neo4j.gds.api.Graph;
@@ -34,7 +35,7 @@
3435
import org.neo4j.gds.core.utils.partition.PartitionUtils;
3536
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3637
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
37-
import org.neo4j.gds.embeddings.graphsage.algo.ImmutableGraphSageTrainConfig;
38+
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfigImpl;
3839
import org.neo4j.gds.extension.GdlExtension;
3940
import org.neo4j.gds.extension.GdlGraph;
4041
import org.neo4j.gds.extension.Inject;
@@ -77,7 +78,7 @@ class GraphSageModelTrainerTest {
7778
@Inject
7879
private Graph arrayGraph;
7980
private HugeObjectArray<double[]> features;
80-
private ImmutableGraphSageTrainConfig.Builder configBuilder;
81+
private GraphSageTrainConfigImpl.Builder configBuilder;
8182

8283

8384
@BeforeEach
@@ -87,7 +88,8 @@ void setUp() {
8788

8889
Random random = new Random(19L);
8990
LongStream.range(0, nodeCount).forEach(n -> features.set(n, random.doubles(FEATURES_COUNT).toArray()));
90-
configBuilder = ImmutableGraphSageTrainConfig.builder()
91+
configBuilder = GraphSageTrainConfigImpl.builder()
92+
.username("DUMMY")
9193
.featureProperties(Collections.nCopies(FEATURES_COUNT, "dummyProp"))
9294
.embeddingDimension(EMBEDDING_DIMENSION);
9395
}
@@ -202,7 +204,7 @@ void testLosses() {
202204
.embeddingDimension(12)
203205
.epochs(10)
204206
.tolerance(1e-10)
205-
.addSampleSizes(5, 3)
207+
.sampleSizes(List.of(5, 3))
206208
.batchSize(5)
207209
.maxIterations(100)
208210
.randomSeed(42L)
@@ -250,7 +252,7 @@ void testLossesWithPoolAggregator() {
250252
.aggregator(AggregatorType.POOL)
251253
.epochs(10)
252254
.tolerance(1e-10)
253-
.addSampleSizes(5, 3)
255+
.sampleSizes(List.of(5, 3))
254256
.batchSize(5)
255257
.maxIterations(100)
256258
.randomSeed(42L)
@@ -306,6 +308,35 @@ void testConvergence() {
306308
assertThat(trainMetrics.ranIterationsPerEpoch()).containsExactly(2);
307309
}
308310

311+
@ParameterizedTest
312+
@CsvSource({
313+
"1, true, 8",
314+
"5, false, 10"
315+
})
316+
void batchesPerIteration(int batchesPerIteration, boolean expectedConvergence, int expectedRanEpochs) {
317+
var trainer = new GraphSageModelTrainer(
318+
configBuilder.modelName("convergingModel:)")
319+
.maybeBatchesPerIteration(batchesPerIteration)
320+
.embeddingDimension(12)
321+
.aggregator(AggregatorType.POOL)
322+
.epochs(10)
323+
.tolerance(1e-10)
324+
.sampleSizes(List.of(5, 3))
325+
.batchSize(5)
326+
.maxIterations(100)
327+
.randomSeed(42L)
328+
.build(),
329+
Pools.DEFAULT,
330+
ProgressTracker.NULL_TRACKER
331+
);
332+
333+
var trainResult = trainer.train(graph, features);
334+
335+
var trainMetrics = trainResult.metrics();
336+
assertThat(trainMetrics.didConverge()).isEqualTo(expectedConvergence);
337+
assertThat(trainMetrics.ranEpochs()).isEqualTo(expectedRanEpochs);
338+
}
339+
309340
@ParameterizedTest
310341
@ValueSource(longs = {20L, -100L, 30L})
311342
void seededSingleBatch(long seed) {

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageConfigTest.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.Map;
3131
import java.util.stream.Stream;
3232

33+
import static org.assertj.core.api.Assertions.assertThat;
3334
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3435
import static org.junit.jupiter.api.Assertions.assertFalse;
3536
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -50,6 +51,17 @@ private static Stream<Arguments> invalidAggregator() {
5051
);
5152
}
5253

54+
@Test
55+
void specifyBatchesPerIteration() {
56+
var mapWrapper = CypherMapWrapper.create(Map.of(
57+
"modelName", "foo",
58+
"featureProperties", List.of("a"),
59+
"batchesPerIteration", 42
60+
));
61+
62+
assertThat(GraphSageTrainConfig.of("user", mapWrapper).batchesPerIteration()).isEqualTo(42);
63+
}
64+
5365
@Test
5466
void shouldThrowIfNoPropertiesProvided() {
5567
var mapWrapper = CypherMapWrapper.create(Map.of("modelName", "foo"));

doc/asciidoc/machine-learning/node-embeddings/graph-sage/specific-train-configuration.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
| learningRate | Float | 0.1 | yes | The learning rate determines the step size at each iteration while moving toward a minimum of a loss function.
1515
| epochs | Integer | 1 | yes | Number of times to traverse the graph.
1616
| <<common-configuration-max-iterations,maxIterations>> | Integer | 10 | yes | Maximum number of iterations per epoch. Each iteration the weights are updated.
17+
| batchesPerIteration | Integer | `concurrency` | yes | Number of batches to consider per weight update.
1718
| searchDepth | Integer | 5 | yes | Maximum depth of the RandomWalks to sample nearby nodes for the training.
1819
| negativeSampleWeight | Integer | 20 | yes | The weight of the negative samples. Higher values increase the impact of negative samples in the loss.
1920
| <<common-configuration-relationship-weight-property,relationshipWeightProperty>> | String | null | yes | Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted.

0 commit comments

Comments
 (0)