DO NOT MERGE - for comparison purposes #2
base: main
Conversation
…then doing bash toploc-scripts/analyze_activations.py
…lock may be unnecessary. will revisit.
…ing's implementation of capturing hidden states
…n, even though they are being generated
…r all the tokens instead of just the last one. will investigate / refine later.
…he ModelWorkerBatch
… runner .forward method
…is messy and will clean up shortly.
if logits_metadata.toploc_verification:
    toploc_verification_hidden_states_to_store = (
        pruned_states[sample_indices] if sample_indices else pruned_states
    )
logits_processor contains information at the batch level: the first dimension has all the inferences in the current batch concatenated together. Each inference in a batch is called a sequence, "seq".
Thus if seq 1 has k tokens and seq 2 has m tokens, then pruned_states holds the hidden states at indices (k-1) and (k+m-1), which correspond to the last tokens of seq 1 and seq 2.
In some cases there's only one sequence in a batch, but not in all cases.
pruned_states is the slice of hidden states for the last token in each sequence. There's also sample_indices, but that relates to more exotic usages of sglang that I don't currently think are relevant.
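The index arithmetic above can be sketched as follows. This is an illustrative stand-in, not SGLang's actual code: `seq_lens` is a hypothetical list of per-sequence token counts, and the helper just turns them into the last-token offsets within the concatenated batch dimension.

```python
def last_token_indices(seq_lens):
    """For sequences of lengths [k, m, ...] concatenated along the first
    dimension, the last token of each sequence sits at the cumulative
    offset minus one: (k-1), (k+m-1), and so on."""
    indices, offset = [], 0
    for n in seq_lens:
        offset += n
        indices.append(offset - 1)
    return indices

# Example: seq 1 has k=3 tokens, seq 2 has m=4 tokens.
# Their last tokens live at indices 2 and 6 of the concatenated batch.
print(last_token_indices([3, 4]))  # -> [2, 6]
```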
origin_input_ids: Optional[List[List[int]]] = None
# Output token ids (for return_output_ids=True)
output_token_ids: Optional[List[List[int]]] = None
GenerateReqInput represents either a batch or a single request, depending on the context in which it's used. I'm following the same pattern used for all the other fields.
# Ensure CAPTURE_HIDDEN_MODE is *at least* LAST if toploc verification is enabled
if self.toploc_verification and capture_hidden_mode == CaptureHiddenMode.NULL:
    capture_hidden_mode = CaptureHiddenMode.LAST
CaptureHiddenMode has to be at least LAST for toploc verification to work, because otherwise we can't capture the last layer's activations in order to generate or validate a fingerprint. CaptureHiddenMode signals to PyTorch to retain these values so they can be copied to the CPU.
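The "at least LAST" escalation can be sketched like this. The enum values and ordering here are an assumption for illustration (the real CaptureHiddenMode in SGLang may differ); the point is only that NULL is upgraded to LAST when verification is on, and a stronger mode like FULL is never downgraded.

```python
from enum import IntEnum

class CaptureHiddenMode(IntEnum):
    # Illustrative ordering: NULL < LAST < FULL
    NULL = 0   # capture no hidden states
    LAST = 1   # keep only the last token's hidden states
    FULL = 2   # keep hidden states for every token

def effective_capture_mode(current, toploc_verification):
    """Escalate NULL to LAST when verification needs the last-layer
    activations; leave any already-sufficient mode untouched."""
    if toploc_verification and current == CaptureHiddenMode.NULL:
        return CaptureHiddenMode.LAST
    return current

print(effective_capture_mode(CaptureHiddenMode.NULL, True).name)  # LAST
print(effective_capture_mode(CaptureHiddenMode.FULL, True).name)  # FULL
```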
        toploc_verification_hidden_state,
        req.toploc_verification_fingerprint_to_validate,
    )
)
You'll see a very similar block of code relating to hidden_states here as well. The difference is that I'm invoking the fingerprint generation and/or fingerprint verification methods depending on what is being requested.
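The generate-vs-validate split can be sketched as below. This is purely illustrative: the hash is a stand-in for the real TopLoc fingerprint (which is built from top-k activation values, not a digest), and both function names are hypothetical.

```python
import hashlib

def generate_fingerprint(hidden_state_bytes: bytes) -> str:
    # Stand-in for the real TopLoc fingerprint over the last token's
    # hidden state; a hash is used here only to keep the sketch runnable.
    return hashlib.sha256(hidden_state_bytes).hexdigest()

def validate_fingerprint(hidden_state_bytes: bytes, claimed: str) -> bool:
    # Recompute locally and compare against the claimed fingerprint.
    return generate_fingerprint(hidden_state_bytes) == claimed

state = b"last-token hidden state"
fp = generate_fingerprint(state)
print(validate_fingerprint(state, fp))           # True
print(validate_fingerprint(b"tampered state", fp))  # False
```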
if req.grammar is not None:
    req.grammar.accept_token(next_token_id)
    req.grammar.finished = req.finished()
I think this shouldn't have been dropped? It may have been an oversight or a side effect of syncing with the main branch.
logger.error(
    f"Error processing toploc verification fingerprint validation results: {e}"
)
This looks like a lot, but it's all just what's required to pass things through in the API layer.
    default=128,
    help="Top-k for TopLoc verification",
)
This is where we specify the flags.
# No need to generate a fingerprint until the last decode step of the sequence
req.toploc_verification_hidden_states.append(None)
req.toploc_verification_fingerprints.append(None)
process_batch_result_decode is where "post-processing" on a "decode" happens. A "decode" is the production of a single new token. We use the req.finished flag to only do the toploc work on the last token.
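The per-decode bookkeeping described above can be sketched as follows. `Req`, `process_decode`, and the placeholder fingerprint string are simplified stand-ins for the real SGLang request object and fingerprinting call, assumed here only to show the None-until-finished pattern.

```python
class Req:
    """Toy stand-in for SGLang's request object."""
    def __init__(self, num_tokens: int):
        self.num_tokens = num_tokens
        self.decoded = 0
        self.toploc_verification_hidden_states = []
        self.toploc_verification_fingerprints = []

    def finished(self) -> bool:
        return self.decoded >= self.num_tokens

def process_decode(req: Req, hidden_state):
    """One 'decode' produces one token; only the final decode step
    stores a real hidden state (and would compute a fingerprint)."""
    req.decoded += 1
    if req.finished():
        req.toploc_verification_hidden_states.append(hidden_state)
        req.toploc_verification_fingerprints.append("fingerprint-placeholder")
    else:
        # No need to generate a fingerprint until the last decode step
        req.toploc_verification_hidden_states.append(None)
        req.toploc_verification_fingerprints.append(None)

req = Req(num_tokens=3)
for step in range(3):
    process_decode(req, hidden_state=f"h{step}")
print(req.toploc_verification_hidden_states)  # -> [None, None, 'h2']
```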
…as another setting to test