Skip to content

Conversation

@kyle-pena-kuzco
Copy link

Motivation

Modifications

Checklist

Kyle Pena and others added 30 commits March 27, 2025 18:55
…then doing bash toploc-scripts/analyze_activations.py
…ing's implementation of capturing hidden states
…r all the tokens instead of just the last one. will investigate / refine later.
if logits_metadata.toploc_verification:
toploc_verification_hidden_states_to_store = (
pruned_states[sample_indices] if sample_indices else pruned_states
)
Copy link
Author

@kyle-pena-kuzco kyle-pena-kuzco Apr 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logits_processor contains information at the batch level, so basically the first dimension has all the inferences in the current batch concatenated together. each inference in a batch is called a sequence, "seq".

Thus if seq 1 has k tokens, and seq 2 has m tokens, then pruned_states has the hidden states for indices (k-1) and (k+m-1), which are the last tokens for seq 1 and seq 2.

In some cases there's only 1 sequence in a batch, but not in all cases.

Slicing out the hidden states for the last token in each sequence is what pruned_states is. There's also this sample_indices but that has to do with more exotic usages of sglang that I don't currently think are relevant.

origin_input_ids: Optional[List[List[int]]] = None
# Output token ids (for return_output_ids=True)
output_token_ids: Optional[List[List[int]]] = None

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GenerateReqInput represents either a batch or a single request, depending on the context it's used in. I'm following the pattern here as is used for all the other fields.


# Ensure CAPTURE_HIDDEN_MODE is *at least* LAST if toploc verification is enabled
if self.toploc_verification and capture_hidden_mode == CaptureHiddenMode.NULL:
capture_hidden_mode = CaptureHiddenMode.LAST
Copy link
Author

@kyle-pena-kuzco kyle-pena-kuzco Apr 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CaptureHiddenMode has to be at least LAST in order for toploc verification to work, because otherwise we can't capture the last layer's activations in order to generate or validate a fingerprint. CaptureHiddenMode signals to pytorch to retain these values so they can be cloned to the CPU.

toploc_verification_hidden_state,
req.toploc_verification_fingerprint_to_validate,
)
)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll see a very similar block of code relating to hidden_states here as well. The difference is that I'm invoking the fingerprint and/or fingerprint verification methods depending on what is being requested.


if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this shouldn't have been dropped? this may have been an oversight or a side effect of syncing with the main branch.

logger.error(
f"Error processing toploc verification fingerprint validation results: {e}"
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a lot, but it's all just what's required to pass things through in the API layer.

default=128,
help="Top-k for TopLoc verification",
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where we specify the flags.

# No need to generate a fingerprint until the last decode step of the sequence
req.toploc_verification_hidden_states.append(None)
req.toploc_verification_fingerprints.append(None)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

process_batch_result_decode is where "post processing" on a "decode" happens. A "decode" is the production of a single new token. We use the req.finished flag to only do toploc stuff on the last token.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants