
Commit a49683c

During prompt processing, logits are only needed for the last batch, so the matrix multiplication that computes the logits can be skipped, along with the attention and FFN of the last layer.
1 parent 3c13567 commit a49683c
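
To make the change easier to follow, here is a minimal, self-contained sketch of the batching decision alone: the prompt is consumed in fixed-size batches, and only the batch containing the final prompt token asks forward() for logits. The PrefillSketch class and its stub forward() are hypothetical stand-ins for illustration; only the computeLogits expression mirrors the committed code.

import java.util.List;

public class PrefillSketch {
    // Stand-in for Llama3.java's forward(model, state, tokens, position, computeLogits);
    // here it only reports whether the expensive logits path would run.
    static void forward(int[] tokens, int position, boolean computeLogits) {
        System.out.printf("position=%d, batch=%d tokens, computeLogits=%b%n",
                position, tokens.length, computeLogits);
    }

    public static void main(String[] args) {
        List<Integer> promptTokens = List.of(10, 11, 12, 13, 14, 15, 16); // a 7-token prompt
        int batchsize = 3; // plays the role of state.batchsize in the real code
        int promptIndex = 0;
        int position = 0;
        while (promptIndex < promptTokens.size()) {
            int nTokens = Math.min(promptTokens.size() - promptIndex, batchsize);
            int[] tokens = new int[nTokens];
            for (int i = 0; i < nTokens; i++) {
                tokens[i] = promptTokens.get(promptIndex + i);
            }
            // Same test as the commit: logits are needed only when this batch
            // contains the last prompt token.
            boolean computeLogits = promptIndex + nTokens >= promptTokens.size();
            forward(tokens, position, computeLogits);
            position += nTokens;
            promptIndex += nTokens;
        }
        // Prints computeLogits=false for the batches at positions 0 and 3,
        // and computeLogits=true for the final batch at position 6.
    }
}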

File tree

1 file changed: +12 -4 lines changed


Llama3.java

Lines changed: 12 additions & 4 deletions
@@ -956,7 +956,7 @@ static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size
         out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index)));
     }
 
-    static FloatTensor forward(Llama model, State state, int[] tokens, int position) {
+    static FloatTensor forward(Llama model, State state, int[] tokens, int position, boolean computeLogits) {
         // a few convenience variables
         Configuration config = model.configuration();
         Weights weights = model.weights();
@@ -1010,6 +1010,12 @@ static FloatTensor forward(Llama model, State state, int[] tokens, int position)
                 state.v[t].copyTo(0, state.valueCache[curLayer], (position + t) * kvDim, kvDim);
             });
 
+            // If the logits are not required, the attention and FFN of the last layer can be skipped entirely.
+            if (!computeLogits && curLayer == config.numberOfLayers - 1) {
+                state.idxPrevBlock = nTokens - 1;
+                return null;
+            }
+
             // multihead attention. iterate over all heads
             Parallel.parallelForLong(0, (long) nTokens * (long) config.numberOfHeads, ht -> {
                 int token = (int) (ht / config.numberOfHeads);
@@ -1136,7 +1142,7 @@ public static List<Integer> generateTokens(Llama model, State state, int startPo
         int promptIndex = 0;
         for (int position = startPosition; position < maxTokens; ++position) {
             if (promptIndex < promptTokens.size()) {
-                final int nTokens = Math.min(promptTokens.size() - promptIndex, state.batchsize);
+                final int nTokens = Math.min(maxTokens - position, Math.min(promptTokens.size() - promptIndex, state.batchsize));
                 final int[] tokens = new int[nTokens];
                 for (int i = 0; i < nTokens; i++) {
                     tokens[i] = promptTokens.get(promptIndex + i);
@@ -1148,15 +1154,17 @@ public static List<Integer> generateTokens(Llama model, State state, int startPo
                 if (echo) {
                     System.out.format("position=%d, promptIdx=%d, promptSize=%d, tokens=%s%n", position, promptIndex, promptTokens.size(), Arrays.toString(tokens));
                 }
-                forward(model, state, tokens, position);
+                // Only compute logits on the very last batch.
+                boolean computeLogits = promptIndex + nTokens >= promptTokens.size();
+                forward(model, state, tokens, position, computeLogits);
                 position += nTokens - 1; // -1 -> incremented later in the for loop
                 promptIndex += nTokens;
                 if (promptIndex < promptTokens.size()) {
                     continue;
                 }
                 startGen = System.nanoTime();
             } else {
-                forward(model, state, new int[]{token}, position);
+                forward(model, state, new int[]{token}, position, true);
             }
             nextToken = sampler.sampleToken(state.logits);
             if (echo) {
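
The hunks above cover the last-layer skip and the caller; the matrix multiplication mentioned in the commit message, the classifier projection that turns the final hidden state into logits over the vocabulary, is guarded by the same flag at the tail of forward(). That hunk is not displayed here, but a hedged sketch of such a guard follows; the names state.x, state.idxPrevBlock, weights.rms_final_weight, weights.wcls, config.vocabularySize, and config.rmsNormEps are assumptions in the style of the visible code, not the commit's exact lines.

        // Hypothetical tail of forward(); this hunk is not shown in the diff above.
        // The classifier matmul is (vocabularySize x dim), typically the single
        // largest matrix multiplication of a forward pass, so returning early for
        // every prefill batch except the last one saves substantial work.
        if (!computeLogits) {
            return null; // no token will be sampled from this batch
        }
        rmsnorm(state.x[state.idxPrevBlock], state.x[state.idxPrevBlock], weights.rms_final_weight, dim, config.rmsNormEps);
        weights.wcls.matmul(state.x[state.idxPrevBlock], state.logits, config.vocabularySize, dim);
        return state.logits;

During generation the caller always passes true, so the single-token decode path is unchanged.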
