Commit 5f00ce8

Configure batchesPerIteration relative to the nodeCount

Using a ratio instead of an absolute number.

Co-authored-by: Jacob Sznajdman <breakanalysis@gmail.com>

1 parent 33220ee

File tree

5 files changed: +43 -32 lines changed

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java
algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainConfig.java
algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java
algo/src/test/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageConfigTest.java
doc/asciidoc/machine-learning/node-embeddings/graph-sage/specific-train-configuration.adoc

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.java

Lines changed: 20 additions & 19 deletions

@@ -130,21 +130,23 @@ public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
             config.batchSize(),
             batch -> createBatchTask(graph, features, layers, weights, batch)
         );
+        var random = new Random(randomSeed);
+        Supplier<List<BatchTask>> batchTaskSampler = () -> IntStream.range(0, config.batchesPerIteration(graph.nodeCount()))
+            .mapToObj(__ -> batchTasks.get(random.nextInt(batchTasks.size())))
+            .collect(Collectors.toList());

         progressTracker.endSubTask("Prepare batches");

+        progressTracker.beginSubTask("Train model");
+
         boolean converged = false;
         var iterationLossesPerEpoch = new ArrayList<List<Double>>();
-
         var prevEpochLoss = Double.NaN;
-        var random = new Random(randomSeed);
-
-        progressTracker.beginSubTask("Train model");
-
         int epochs = config.epochs();
+
         for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
             progressTracker.beginSubTask("Epoch");
-            var epochResult = trainEpoch(() -> batchTasks.get(random.nextInt(batchTasks.size())), weights, prevEpochLoss);
+            var epochResult = trainEpoch(batchTaskSampler, weights, prevEpochLoss);
             List<Double> epochLosses = epochResult.losses();
             iterationLossesPerEpoch.add(epochLosses);
             prevEpochLoss = epochLosses.get(epochLosses.size() - 1);

@@ -188,7 +190,11 @@ private BatchTask createBatchTask(
         return new BatchTask(lossFunction, weights, progressTracker);
     }

-    private EpochResult trainEpoch(Supplier<BatchTask> batchTaskSupplier, List<Weights<? extends Tensor<?>>> weights, double prevEpochLoss) {
+    private EpochResult trainEpoch(
+        Supplier<List<BatchTask>> sampledBatchTaskSupplier,
+        List<Weights<? extends Tensor<?>>> weights,
+        double prevEpochLoss
+    ) {
         var updater = new AdamOptimizer(weights, config.learningRate());

         int iteration = 1;

@@ -200,14 +206,11 @@ private EpochResult trainEpoch(Supplier<BatchTask> batchTaskSupplier, List<Weigh
         for (; iteration <= maxIterations; iteration++) {
             progressTracker.beginSubTask("Iteration");

-            var batchTasks = IntStream
-                .range(0, config.batchesPerIteration())
-                .mapToObj(__ -> batchTaskSupplier.get())
-                .collect(Collectors.toList());
+            var sampledBatchTasks = sampledBatchTaskSupplier.get();

             // run forward + maybe backward for each Batch
-            ParallelUtil.runWithConcurrency(config.concurrency(), batchTasks, executor);
-            var avgLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
+            ParallelUtil.runWithConcurrency(config.concurrency(), sampledBatchTasks, executor);
+            var avgLoss = sampledBatchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
             iterationLosses.add(avgLoss);
             progressTracker.logMessage(formatWithLocale("LOSS: %.10f", avgLoss));

@@ -219,7 +222,7 @@ private EpochResult trainEpoch(Supplier<BatchTask> batchTaskSupplier, List<Weigh

             prevLoss = avgLoss;

-            var batchedGradients = batchTasks
+            var batchedGradients = sampledBatchTasks
                 .stream()
                 .map(BatchTask::weightGradients)
                 .collect(Collectors.toList());

@@ -246,7 +249,7 @@ static class BatchTask implements Runnable {
         private final List<Weights<? extends Tensor<?>>> weightVariables;
         private List<? extends Tensor<?>> weightGradients;
         private final ProgressTracker progressTracker;
-        private double prevLoss;
+        private double loss;

         BatchTask(
             Variable<Scalar> lossFunction,

@@ -261,9 +264,7 @@ static class BatchTask implements Runnable {
         @Override
         public void run() {
             var localCtx = new ComputationContext();
-            var loss = localCtx.forward(lossFunction).value();
-
-            prevLoss = loss;
+            loss = localCtx.forward(lossFunction).value();

             localCtx.backward(lossFunction);
             weightGradients = weightVariables.stream().map(localCtx::gradient).collect(Collectors.toList());

@@ -272,7 +273,7 @@ public void run() {
         }

         public double loss() {
-            return prevLoss;
+            return loss;
         }

         List<? extends Tensor<?>> weightGradients() {
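For context, the sampler introduced above draws batchesPerIteration tasks uniformly at random, with replacement, from the precomputed batch tasks, so a batch may appear more than once within a single iteration. Below is a minimal self-contained sketch of that sampling pattern; plain strings stand in for the GDS BatchTask objects, and the class name and all values are hypothetical.

import java.util.List;
import java.util.Random;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class BatchSamplerSketch {
    public static void main(String[] args) {
        // Stand-ins for the precomputed per-batch training tasks.
        List<String> batchTasks = List.of("batch-0", "batch-1", "batch-2", "batch-3");
        int batchesPerIteration = 2;

        // A seeded RNG makes the sampling reproducible, as with randomSeed in the trainer.
        var random = new Random(42L);

        // Uniform sampling WITH replacement: the same batch may be drawn twice in one
        // iteration, and not every batch is necessarily seen in every iteration.
        Supplier<List<String>> batchTaskSampler = () -> IntStream
            .range(0, batchesPerIteration)
            .mapToObj(__ -> batchTasks.get(random.nextInt(batchTasks.size())))
            .collect(Collectors.toList());

        // Each training iteration asks the sampler for a fresh selection.
        System.out.println(batchTaskSampler.get());
        System.out.println(batchTaskSampler.get());
    }
}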

algo/src/main/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainConfig.java

Lines changed: 7 additions & 4 deletions

@@ -120,13 +120,16 @@ default int maxIterations() {
         return 10;
     }

-    @Configuration.Key("batchesPerIteration")
-    Optional<Integer> maybeBatchesPerIteration();
+    @Configuration.Key("batchSamplingRatio")
+    @Configuration.DoubleRange(min = 0, max = 1, minInclusive = false)
+    Optional<Double> maybeBatchSamplingRatio();

     @Configuration.Ignore
     @Value.Derived
-    default int batchesPerIteration() {
-        return maybeBatchesPerIteration().orElse(concurrency());
+    default int batchesPerIteration(long nodeCount) {
+        var samplingRatio = maybeBatchSamplingRatio().orElse(Math.min(1.0, batchSize() * concurrency() / (double) nodeCount));
+        var totalNumberOfBatches = Math.ceil(nodeCount / (double) batchSize());
+        return (int) Math.ceil(samplingRatio * totalNumberOfBatches);
     }

     @Value.Default
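To make the derived value concrete, here is a minimal sketch of the same computation with hypothetical values (batchSize = 100, concurrency = 4, nodeCount = 10000), showing that the default ratio roughly recovers the previous default of concurrency batches per iteration:

public class BatchesPerIterationSketch {
    static int batchesPerIteration(double samplingRatio, long nodeCount, int batchSize) {
        var totalNumberOfBatches = Math.ceil(nodeCount / (double) batchSize);
        return (int) Math.ceil(samplingRatio * totalNumberOfBatches);
    }

    public static void main(String[] args) {
        int batchSize = 100;
        int concurrency = 4;
        long nodeCount = 10_000;

        // Default ratio: min(1, batchSize * concurrency / nodeCount) = min(1, 0.04) = 0.04.
        double defaultRatio = Math.min(1.0, batchSize * concurrency / (double) nodeCount);

        // totalNumberOfBatches = ceil(10000 / 100) = 100, so
        // ceil(0.04 * 100) = 4 == concurrency: roughly one batch per thread, as before.
        System.out.println(batchesPerIteration(defaultRatio, nodeCount, batchSize)); // 4

        // An explicit ratio of 0.2 samples 20% of the 100 batches each iteration.
        System.out.println(batchesPerIteration(0.2, nodeCount, batchSize)); // 20
    }
}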

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainerTest.java

Lines changed: 4 additions & 4 deletions

@@ -310,13 +310,13 @@ void testConvergence() {

     @ParameterizedTest
     @CsvSource({
-        "1, true, 8",
-        "5, false, 10"
+        "0.01, true, 8",
+        "1.0, false, 10"
     })
-    void batchesPerIteration(int batchesPerIteration, boolean expectedConvergence, int expectedRanEpochs) {
+    void batchesPerIteration(double batchSamplingRatio, boolean expectedConvergence, int expectedRanEpochs) {
         var trainer = new GraphSageModelTrainer(
             configBuilder.modelName("convergingModel:)")
-                .maybeBatchesPerIteration(batchesPerIteration)
+                .maybeBatchSamplingRatio(batchSamplingRatio)
                 .embeddingDimension(12)
                 .aggregator(AggregatorType.POOL)
                 .epochs(10)
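Note that because the derived batchesPerIteration is rounded up with Math.ceil, even the very small ratio 0.01 in the first test case still samples at least one batch per iteration.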

algo/src/test/java/org/neo4j/gds/embeddings/graphsage/algo/GraphSageConfigTest.java

Lines changed: 11 additions & 4 deletions

@@ -22,6 +22,7 @@
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.CsvSource;
 import org.junit.jupiter.params.provider.MethodSource;
 import org.junit.jupiter.params.provider.ValueSource;
 import org.neo4j.gds.core.CypherMapWrapper;

@@ -51,15 +52,21 @@ private static Stream<Arguments> invalidAggregator() {
         );
     }

-    @Test
-    void specifyBatchesPerIteration() {
+    @ParameterizedTest
+    @CsvSource({
+        "0.5, 100, 1",
+        "0.2, 1000, 2",
+        "0.99, 1000, 10",
+    })
+    void specifyBatchesPerIteration(double samplingRatio, long nodeCount, int expectedSampledBatches) {
         var mapWrapper = CypherMapWrapper.create(Map.of(
             "modelName", "foo",
             "featureProperties", List.of("a"),
-            "batchesPerIteration", 42
+            "batchSamplingRatio", samplingRatio,
+            "batchSize", 100
         ));

-        assertThat(GraphSageTrainConfig.of("user", mapWrapper).batchesPerIteration()).isEqualTo(42);
+        assertThat(GraphSageTrainConfig.of("user", mapWrapper).batchesPerIteration(nodeCount)).isEqualTo(expectedSampledBatches);
     }

     @Test
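These expected values follow directly from the derivation in GraphSageTrainConfig: with batchSize = 100, a 100-node graph has ceil(100 / 100) = 1 batch, so ceil(0.5 * 1) = 1; a 1000-node graph has ceil(1000 / 100) = 10 batches, giving ceil(0.2 * 10) = 2 and ceil(0.99 * 10) = 10.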

doc/asciidoc/machine-learning/node-embeddings/graph-sage/specific-train-configuration.adoc

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 | learningRate | Float | 0.1 | yes | The learning rate determines the step size at each iteration while moving toward a minimum of a loss function.
 | epochs | Integer | 1 | yes | Number of times to traverse the graph.
 | <<common-configuration-max-iterations,maxIterations>> | Integer | 10 | yes | Maximum number of iterations per epoch. Each iteration the weights are updated.
-| <<common-configuration-max-iterations,batchesPerIteration>> | Integer | `concurrency` | yes | Number of batches to consider per weight updates.
+| batchSamplingRatio | Float | `concurrency * batchSize / nodeCount` | yes | Sampling ratio of batches to consider per weight update. By default, each thread evaluates a single batch. The gradients per batch are averaged to update the weights.
 | searchDepth | Integer | 5 | yes | Maximum depth of the RandomWalks to sample nearby nodes for the training.
 | negativeSampleWeight | Integer | 20 | yes | The weight of the negative samples. Higher values increase the impact of negative samples in the loss.
 | <<common-configuration-relationship-weight-property,relationshipWeightProperty>> | String | null | yes | Name of the relationship property to use as weights. If unspecified, the algorithm runs unweighted.
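Since the parameter is now a ratio, the old semantics of an absolute count of k batches per iteration can still be approximated by passing batchSamplingRatio = k / ceil(nodeCount / batchSize), provided the resulting value lies in the valid range (greater than 0 and at most 1).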
