@@ -956,7 +956,7 @@ static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size
956956 out .mapWithIndexInPlace (0 , size , (value , index ) -> weight .get (index ) * (finalss * x .getFloat (index )));
957957 }
958958
959- static FloatTensor forward (Llama model , State state , int [] tokens , int position ) {
959+ static FloatTensor forward (Llama model , State state , int [] tokens , int position , boolean computeLogits ) {
960960 // a few convenience variables
961961 Configuration config = model .configuration ();
962962 Weights weights = model .weights ();
@@ -1010,6 +1010,12 @@ static FloatTensor forward(Llama model, State state, int[] tokens, int position)
10101010 state .v [t ].copyTo (0 , state .valueCache [curLayer ], (position + t ) * kvDim , kvDim );
10111011 });
10121012
1013+ // If the logits are not required, the attention and FFN of the last layer can be skipped entirely.
1014+ if (!computeLogits && curLayer == config .numberOfLayers - 1 ) {
1015+ state .idxPrevBlock = nTokens - 1 ;
1016+ return null ;
1017+ }
1018+
10131019 // multihead attention. iterate over all heads
10141020 Parallel .parallelForLong (0 , (long ) nTokens * (long ) config .numberOfHeads , ht -> {
10151021 int token = (int ) (ht / config .numberOfHeads );
@@ -1136,7 +1142,7 @@ public static List<Integer> generateTokens(Llama model, State state, int startPo
11361142 int promptIndex = 0 ;
11371143 for (int position = startPosition ; position < maxTokens ; ++position ) {
11381144 if (promptIndex < promptTokens .size ()) {
1139- final int nTokens = Math .min (promptTokens .size () - promptIndex , state .batchsize );
1145+ final int nTokens = Math .min (maxTokens - position , Math . min ( promptTokens .size () - promptIndex , state .batchsize ) );
11401146 final int [] tokens = new int [nTokens ];
11411147 for (int i = 0 ; i < nTokens ; i ++) {
11421148 tokens [i ] = promptTokens .get (promptIndex + i );
@@ -1148,15 +1154,17 @@ public static List<Integer> generateTokens(Llama model, State state, int startPo
11481154 if (echo ) {
11491155 System .out .format ("position=%d, promptIdx=%d, promptSize=%d, tokens=%s%n" , position , promptIndex , promptTokens .size (), Arrays .toString (tokens ));
11501156 }
1151- forward (model , state , tokens , position );
1157+ // Only compute logits on the very last batch.
1158+ boolean computeLogits = promptIndex + nTokens >= promptTokens .size ();
1159+ forward (model , state , tokens , position , computeLogits );
11521160 position += nTokens - 1 ; // -1 -> incremented later in the for loop
11531161 promptIndex += nTokens ;
11541162 if (promptIndex < promptTokens .size ()) {
11551163 continue ;
11561164 }
11571165 startGen = System .nanoTime ();
11581166 } else {
1159- forward (model , state , new int []{token }, position );
1167+ forward (model , state , new int []{token }, position , true );
11601168 }
11611169 nextToken = sampler .sampleToken (state .logits );
11621170 if (echo ) {
0 commit comments