models/llama/llama/generation.py (12 additions, 8 deletions)
@@ -193,17 +193,21 @@ def generate(
         for cur_pos in range(min_prompt_len, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             if temperature > 0:
-                if self.use_triton:
-                    probs = triton_softmax(logits[:, -1])
-                else:
-                    probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
+                # if self.use_triton:
+                #     probs = triton_softmax(logits[:, -1])
+                # else:
+                #     probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
+                probs = MathOps.softmax(logits[:, -1] / temperature, dim=-1)
+
                 next_token = sample_top_p(probs, top_p)
             else:
-                if self.use_triton:
-                    next_token = self.triton.language.argmax(logits[:, -1], axis=-1)
-                else:
-                    next_token = self.Math.argmax(logits[:, -1], dim=-1)
+                # if self.use_triton:
+                #     next_token = self.triton.language.argmax(logits[:, -1], axis=-1)
+                # else:
+                #     next_token = self.Math.argmax(logits[:, -1], dim=-1)
+                next_token = MathOps.argmax(logits[:, -1], dim=-1)

             next_token = next_token.reshape(-1)
             # only replace token if prompt has already been generated
models/llama/llama/math_ops.py (4 additions, 2 deletions)
@@ -8,6 +8,7 @@
 from kernels.cross_entropy import cross_entropy
 from kernels.matmul import matmul
 from kernels.flash_attention import attention
+from kernels.fused_softmax import triton_softmax
 from benchmarking import Profiler
 import time

@@ -70,14 +71,15 @@ def attention(self, xq, keys, values, head_dim, mask):

     @Profiler.profiling_decorator("softmax")
     def softmax(self, x, dim):
-        if self.use_triton:
-            return F.softmax(x, dim=-1)
+        if self.use_triton and len(x) == 2:

Collaborator: It looks like you're trying to check the number of dimensions here, right? len(x) returns the size of the first dimension (x.shape[0]), not the number of dimensions. I think you want x.dim() or x.ndim.

Author: done
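
For reference, a minimal sketch of the distinction being discussed; the tensor shape below is made up purely for illustration:

```python
import torch

# Illustrative shape only: (batch, vocab) logits.
x = torch.randn(4, 32000)

assert len(x) == 4             # size of the first dimension, i.e. x.shape[0]
assert x.numel() == 4 * 32000  # total number of elements
assert x.dim() == 2            # number of dimensions (same as x.ndim)

# So the guard on the fused-kernel path should check dimensionality:
if x.dim() == 2:
    pass  # 2-D input: safe to hand to the 2-D fused softmax kernel
```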

+            return triton_softmax(x, dim=-1)

Collaborator: Why are we passing dim=-1 to these calls when we receive dim as an argument? Let's pass it through properly instead of overriding it. (Also, does the fused Triton kernel actually handle dim != -1 correctly?)

Author: Currently it does not handle dim != -1. Looking into how llama.cpp does this; let me know if you have any pointers.
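
One possible shape for that pass-through, sketched here as a free function and assuming the fused kernel only handles a 2-D input reduced over its last dimension (that fallback condition is an assumption for the sketch, not something the PR implements):

```python
import torch
import torch.nn.functional as F

from kernels.fused_softmax import triton_softmax


def softmax(x: torch.Tensor, dim: int, use_triton: bool) -> torch.Tensor:
    # Route to the fused Triton kernel only when it can honor `dim`:
    # a 2-D input reduced over the last dimension.
    reduces_last_dim = dim in (-1, x.dim() - 1)
    if use_triton and x.dim() == 2 and reduces_last_dim:
        return triton_softmax(x, dim=-1)
    # Every other shape/dim combination falls back to PyTorch,
    # with the caller's dim passed through unchanged.
    return F.softmax(x, dim=dim)
```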

         else:
             return F.softmax(x, dim=-1)

     @Profiler.profiling_decorator("argmax")
     def argmax(self, x, dim):
         if self.use_triton:
+            # TODO: change

Collaborator: Instead of adding a TODO to the code here, would you mind creating an issue to track it?

Author: removed

             return torch.argmax(x, dim=-1)
         else:
             return torch.argmax(x, dim=-1)