diff --git a/.gitignore b/.gitignore
index 75e29fac373..9c137c8285b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+meta-llama/
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -127,6 +129,7 @@ venv/
ENV/
env.bak/
venv.bak/
+.sglang
# Spyder project settings
.spyderproject
@@ -227,3 +230,5 @@ compile_commands.json
.vscode
1
+**.npz
+ret.json
diff --git a/STEP_BY_STEP_README.md b/STEP_BY_STEP_README.md
new file mode 100644
index 00000000000..4c37e3adf80
--- /dev/null
+++ b/STEP_BY_STEP_README.md
@@ -0,0 +1,178 @@
+## Introduction
+
+Since you want to get your hands dirty, here's a quick guide on how to work through the verification flow step by step.
+
+I'd also encourage you to check out [the verification README](/VERIFICATION_README.md) for more context.
+
+## Setup
+
+First, create a virtual environment at the root of the repository.
+
+Activate the environment and install sglang:
+
+```
+pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
+pip install transformers==4.48.3
+pip install datasets
+```
+
+You also need to set your `HF_TOKEN` environment variable to a token which has access to `meta-llama/3.1-8b-instruct`. You can find mine in 1Password under Engineering.
+
+```
+export HF_TOKEN=...
+```
+
+## Example Script
+
+Try running this script:
+
+```
+python toploc-scripts/minimal_example.py --disable-cuda-graph
+```
+
+**Note**: I've disabled CUDA graph because it introduces some kind of non-determinism in the prefill that makes verification occasionally fail (maybe 1 out of 6 times). This is a new behavior compared to my testing from last week, so I'm hoping it's because I upgraded toploc to v0.1.4, and that this is easily resolved. I've got a ticket to look into it.
+
+## How to do it "By Hand"
+
+First, you have to start the server in `--toploc-verification` mode.
+
+Here is the command you can run to start the server:
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 3001 --toploc-verification --toploc-verification-topk 128 --log-level debug --disable-cuda-graph
+```
+
+Now, you can send an inference request to the server, and you can see the fingerprint in the response:
+```
+import json
+import openai
+
+params = {
+ "temperature": 0,
+ "seed": 42,
+}
+
+client = openai.Client(base_url=f"http://127.0.0.1:3001/v1", api_key="None")
+
+prompt = "What is the capital of Bulgaria?"
+response = client.chat.completions.create(
+ model="meta-llama/Llama-3.1-8B-Instruct",
+ messages=[
+ {"role": "user", "content": prompt},
+ ],
+ **params
+)
+response_dump = response.model_dump()
+print("Response received:")
+print(json.dumps(response_dump, indent=4))
+```
+
+The response will contain a `toploc_verification_fingerprints` array:
+```json
+{
+ "choices": [
+ {
+ "message": {
+ "content": "Sofia",
+ "toploc_verification_fingerprints": ["...", "..."]
+ }
+ }
+ ]
+}
+```
+
+There are typically two. We're only interested in the last one.
+
+Now, we need to validate the fingerprint. How do we do that? By sending it to a verification instance running the same model, along with the original prompt and response.
+
+To build the verification request, you have to:
+
+1. Append to the messages array, so that it includes both the original prompt and the assistant's response:
+```json
+{
+ "role": "user",
+ "content": "What is the capital of Bulgaria?"
+},
+// This is the response ---v to this ----^
+{
+ "role": "assistant",
+ "content": (the response)
+}
+```
+
+2. Set `max_tokens` to 0. This is what makes it a prefill.
+
+3. Set `toploc_verification_fingerprint_to_validate` to the last fingerprint in the `toploc_verification_fingerprints` array.
+
+The verification instance will respond with a `toploc_verification_fingerprint_validation_result`, which will look something like this (but serialized as a string):
+
+```json
+{
+ "exp_mismatches": 1,
+ "mant_err_mean": 0.75,
+ "mant_err_median": 0.75
+}
+```
+
+These error statistics are what is interpreted to determine if this is a verification pass or verification failure.
+
+
+The implementation of this fork would have been much simpler if we had worked with the SGLang module directly in Python (i.e., `import sglang`), but that would have entailed basically rewriting how our workers work.
+
+So, unfortunately, I had to devote a lot of code to pass-thrus to/from the API layer of SGLang.
+
+**Important Note On Prefill Replication**
+
+I am prefilling the original prompt + response by appending an assistant message to the messages array.
+
+This may not work in all cases, e.g., when the request includes tools.
+
+Another concern is fragility. Suppose that SGLang changes the way it parses or generates responses, the model updates its chat template, etc. Then, the same messages array will no longer correspond to the same token ID inputs.
+
+For both of these reasons, I've implemented two other features to make pre-fill more robust:
+1. `return_input_ids` - returns the token IDs of the prompt if included in the request
+2. `return_output_ids` - returns the token IDs of the response if included in the request
+
+Then, the pre-fill request will simply take:
+`input_ids[:-1] + output_ids + EOT`, which is a far more reliable way to replicate prompt + response.
+
+## How The Fork Works
+
+It turns out that EAGLE speculative decoding has a lot in common with verification.
+
+SGLang has an internal flag called `CaptureHiddenMode`, which has values of either `NONE`, `LAST`, or `FULL`.
+
+These values refer to which of the hidden layers of the LLM should be "captured" so that when inference is complete, their values are accessible for use in EAGLE speculative sampling.
+
+Ordinarily, `CaptureHiddenMode` is set to `NONE` unless some version of EAGLE is enabled.
+
+I modified this logic so that when verification is enabled, `CaptureHiddenMode` is set to at least `LAST`.
+
+Then, after inference is complete, I move the last hidden layer to the CPU.
+
+At this point, the code path diverges.
+
+1. If we are performing inference, I use the hidden layer to generate the toploc fingerprint, and return it with the response.
+
+2. If we are verifying a fingerprint, I compare the hidden layer with the toploc fingerprint, and return the result with the response.
+
+The "core" logic of fingerprint verification and fingerprint generation is part of the toploc library, which I have added as a dependency.
+
+I could have re-implemented it all from scratch because I understand the math, but that seemed like a wasteful exercise when we have a working implementation available.
+
+## What Makes The Fork Tricky
+
+SGLang takes requests and puts them into a general purpose task scheduler.
+
+Then, SGLang attempts to take tasks of the same kind and group them into batches.
+
+The batches store information in arrays, and in some cases the batch objects store nested data structures as flat arrays and then use array indices to set the boundaries between contiguous regions that represent individual items in the batch. There are also a few different kinds of objects that are at the batch level (`ScheduleBatch`, `BatchTokenIDOut`, `LogitProcessorOutput`, etc.)
+
+So, there is quite a bit of "glue" required to correctly assemble the various kinds of batches and then slice them back apart into requests once inference is complete.
+
+Then, there are additional layers of pass-thru to the API layer of SGLang.
+
+However, there was plenty of precedent for how to do this kind of stuff by looking at the EAGLE code.
+
+So, you'll see a lot of code that is basically the same as EAGLE code that lives right next to it, if you explore up or down a few lines. This is especially the case when it comes to handling `CaptureHiddenMode` and dealing with `hidden_states`.
+
+There is also a divergent codepath for the CUDA Graph Runner that needs to be properly handled.
diff --git a/VERIFICATION_README.md b/VERIFICATION_README.md
new file mode 100644
index 00000000000..8f49cccc8a5
--- /dev/null
+++ b/VERIFICATION_README.md
@@ -0,0 +1,274 @@
+# Verification
+
+**Goal**: Automatically detect and flag "spoofed inferences".
+
+**Business Value**: Prevents fraud and ensures the quality of the service.
+This becomes even more important if we want to embrace a fully decentralized model for inference.
+
+**Constraints**: Low error rate, minimal speed impact.
+
+We are trying to answer one simple question: **"Did you use the model you say you did to generate this response?"**
+
+This weeds out a variety of kinds of spoofing, like:
+- Running a different quantization
+- Running a different model
+
+## My Mindset
+
+"In the lab" and "in production" are two different things. So for me, it's less about TopLOC specifically, and more about getting something out there with instrumentation.
+
+However, I feel that TopLOC represents a great starting point. TopLOC is the only practical approach I've seen that addresses the problem of validating internal model states. It works well in my testing. And I understand it well enough that if it breaks down in some way, we can probably fix it or adjust it, or try something similar.
+
+So, my main goal is to fail fast and then iterate - rather than advocate for a specific paper, algorithm, etc. The sooner we can get an MVP system up and running, and the easier we can pivot, the happier I am.
+
+Also, I think that SGLang was probably the hardest of the engines to fork. I'm guessing that VLLM and ollama will be easier.
+
+## Technical Approach - Overview
+
+This is a fork of SGLang.
+
+It extends SGLang by adding two new capabilities:
+- The ability to include a "top loc fingerprint" with every inference response
+- The ability to validate that a fingerprint matches a model's internal state
+
+Here is how those capabilities are used to make verification work:
+
+1. An instance running this fork of SGLang will compute a "top loc fingerprint" of the internal activations of the model and include it with the inference response.
+2. When we want to verify a response, we will send the prompt, the response, and the fingerprint to a verification instance running the same model as the operator claims they used.
+3. The verification instance will run a prefill with the prompt + response. This replicates the state of the model as of when the last token was generated with the original inference.
+4. The verification instance compares the internal states of the model with the fingerprint it was asked to validate.
+5. The verification instance includes how closely the fingerprint matched the internal state of the model in its response.
+
+
+
+
+### Also...
+I'll cover a reputation update system that only slashes operators when it's "reasonably sure" they're a spoofer (taking into account the False Negative Rate).
+
+### Fingerprinting
+```mermaid
+graph TD
+ P[Prompt] -->|Input| A[Inference Model]
+ A -->|Compute Fingerprint| B[Fingerprint]
+ A -->|Generate| C[Response]
+ B --> D[Inference Response *with* Fingerprint]
+ C --> D
+```
+
+### Example Fingerprint:
+This string encodes the value of the Top 128 largest activations of the last hidden layer of the inference model.
+```json
+"/9nWITn6firYaHRIbYCxYbsuqR/m2RmF7Qsh3Gh1jLATqnNWEQWknWvSHSNXwtxTUQ7tZ4P/GnR1EqsChhMhKm78WNaLsCUBl6ksyhLqPMpYui9zSjfNcafVtYFd836AjagNOVbgiZqj/zQRCXPUhK+orxjTHrhATDqspkJ+LzCzk9JtrK58GD6G5l+HvG73pZlvHNcwDmhkPp5ao8qGToYqxwx/OC88U5ezA3WdrrVha2ZJFA2wlbQqOpE8FY0Po2DrhCHXhnZirWGBuNckB4tpdSnrjdZIi1Pq6iAB7o41i8GUDY99+nk4he8Ceo9afZ4bJL3z9ci/DeeQrsJg07GH"
+```
+This is what it looks like in the response (called *toploc_verification_fingerprints*):
+```json
+{
+ "choices": [{
+ "message": {
+ "content": "Cross-training can benefit ...",
+ "toploc_verification_fingerprints": [
+ "/9nWITn6firYaHRIbYCxYbsuqR/m2RmF7Qsh3Gh1jLATqnNWEQWknWvSHSNXwtxTUQ7tZ4P/GnR1EqsChhMhKm78WNaLsCUBl6ksyhLqPMpYui9zSjfNcafVtYFd836AjagNOVbgiZqj/zQRCXPUhK+orxjTHrhATDqspkJ+LzCzk9JtrK58GD6G5l+HvG73pZlvHNcwDmhkPp5ao8qGToYqxwx/OC88U5ezA3WdrrVha2ZJFA2wlbQqOpE8FY0Po2DrhCHXhnZirWGBuNckB4tpdSnrjdZIi1Pq6iAB7o41i8GUDY99+nk4he8Ceo9afZ4bJL3z9ci/DeeQrsJg07GH"
+ ]
+ }
+ }]
+}
+```
+
+The fingerprint (aka "verification fingerprint") is generated by encoding the value of the Top 128 largest activations of the last hidden layer of the inference model:
+
+
+
+### Verifying The Fingerprint
+
+Presumably we have stored the fingerprint, prompt, and response in PG.
+
+If we want to verify a response, we would construct an inference request that looks like this:
+
+```python
+client.chat.completions.create(
+ model="meta-llama/Llama-3.1-8B-Instruct", # <-- same model
+ messages=[
+ {"role": "user", "content": original_prompt},
+ {"role": "assistant", "content": response }, # <--- note this
+ ],
+ max_tokens=0, # <--- note this
+ extra_body={
+ "toploc_verification_fingerprint_to_validate": last_token_fingerprint,
+ },
+)
+```
+Because we set `max_tokens=0`, the model will perform a "prefill-only", which is MUCH faster than the step-by-step decode used to generate the original response.
+
+Internally, it will compare the `toploc_verification_fingerprint_to_validate` with its own activations of the last hidden layer of the model after the prefill.
+
+```mermaid
+graph
+ DB[(Database)] -->|"Retrieve prompt,
response, fingerprint"| V[Verification Instance]
+ V -->|"Prefill with
prompt + response"| H[Hidden Activations]
+ H -->|Compare with| F[Stored Fingerprint]
+ F --> C{Match?}
+ C -->|Yes| Valid[Valid]
+ C -->|No| Invalid[Spoofed]
+
+
+```
+
+### Reliability
+The system should have a near-zero (if not zero) False Negative Rate and a low False Positive Rate.
+
+### Reputation
+
+**Please Note**: *After experimenting with implementing the operator reputation described in this section, the "update formula" ended up being more complicated than I anticipated. Even if you do it mathematically correctly, it has weird edge cases. I've pivoted away from the Bayesian approach and gone with a much simpler hypothesis testing approach, and so far that's been working out better.*
+
+We will introduce a "Spoofer Probability" that assigns each operator a 1% chance of being a "spoofer".
+
+| Operator ID | Spoofer Probability |
+|-------------|---------------------|
+| operator_1 | 1% |
+| operator_2 | 1% |
+| operator_3 | 1% |
+| operator_4 | 1% |
+
+
+When an inference request is marked as a spoof, we update the spoofer probability according to this formula:
+
+$$\text{New Spoofer Probability} = \frac{P(\text{Spoofer})}{P(\text{Spoofer}) + r \cdot (1 - P(\text{Spoofer}))}$$
+
+Where $r$ is the False Positive Rate of the verification system (we will have different rates per model). Assuming a False Positive Rate of 1%:
+
+
+
+| Operator ID | Number of Spoofed Inferences | Updated Spoofer Probability |
+|-------------|------------------------------|------------------------------|
+| operator_1 | 0 (Initial) | 1% |
+| operator_1 | 1 | 50.25% |
+| operator_1 | 2 | 99.02% |
+| operator_1 | 3 | 99.99% |
+
+This makes sense because the more times a response is marked as a spoof, the more likely it is that the operator is a spoofer.
+
+This update formula is a form of Bayes' Theorem, which is a fair way of updating probabilities based on new evidence.
+
+### Optimization 1: Stochastic Verification
+
+To cut down on the number of verifications needed, we can use a stochastic verification approach.
+
+Instead of verifying all requests, we randomly select a subset of the requests to verify.
+
+We pick requests to verify based on this simple rule: Pick a random number between 0 and 1. If the number is less than the spoofer probability, verify the request.
+
+Therefore:
+1. If the spoofer probability of an operator is 100%, 100% of their requests go to verification.
+2. If the spoofer probability of an operator is 50%, 50% of their requests go to verification.
+3. If the spoofer probability of an operator is 1%, 1% of their requests go to verification.
+
+
+```mermaid
+graph TD
+ A[More Requests Sent to Verification] --> B[More Spoofs Detected]
+ B --> C[Spoofer Probability Increases]
+ C --> A
+
+
+```
+
+This creates a feedback cycle that quickly identifies and pushes spoofers out of the system.
+
+When the Spoofer Probability exceeds some threshold (e.g. 99.99%), we can block the operator and/or apply punitive measures like slashing.
+
+### Optimization 2: MOE Fingerprinting
+
+While prefilling is much faster than step-by-step decoding, the hardware required to run prefilling is just as expensive as the hardware required to run step-by-step decoding.
+
+Models like DeepSeek-V3 use a MOE architecture (Mixture of Experts). At runtime, the model will select 8 out of 128 experts based on an "expert router".
+
+The intermediate results of the experts are combined in a final layer before being sent to the final LM-head.
+
+```mermaid
+graph TD
+ P[Prompt] --> A[MOE Model]
+ A --> Router[Expert Router]
+ Router --> E1[Expert 1]
+ Router --> E2[Expert 2]
+ Router --> E3[Expert 3]
+ Router --> E4[Expert 4]
+ Router --> E5[Expert 5]
+ Router --> E6[Expert 6]
+ Router --> E7[Expert 7]
+ Router --> E8[Expert 8]
+
+ E1 --> F1[Fingerprint 1]
+ E2 --> F2[Fingerprint 2]
+ E3 --> F3[Fingerprint 3]
+ E4 --> F4[Fingerprint 4]
+ E5 --> F5[Fingerprint 5]
+ E6 --> F6[Fingerprint 6]
+ E7 --> F7[Fingerprint 7]
+ E8 --> F8[Fingerprint 8]
+
+ E1 --> C[Combiner Layer]
+ E2 --> C
+ E3 --> C
+ E4 --> C
+ E5 --> C
+ E6 --> C
+ E7 --> C
+ E8 --> C
+
+ C --> L[LM Head]
+ L --> R[Response]
+
+ F1 --> FP[Response with 8 Expert Fingerprints]
+ F2 --> FP
+ F3 --> FP
+ F4 --> FP
+ F5 --> FP
+ F6 --> FP
+ F7 --> FP
+ F8 --> FP
+ R --> FP
+
+```
+
+Instead of fingerprinting the final hidden activations of the model, we can fingerprint the last hidden activations of the 8 selected experts.
+
+Then, for verification, we can randomly select one of the 8 experts and run a prefill with just that expert.
+
+```mermaid
+graph TD
+ DB[(Database)] -->|"Retrieve prompt,
response, 8 fingerprints"| V[Verification Process]
+ V -->|"Randomly select
1 out of 8 fingerprints"| F3[Fingerprint 3]
+
+ P[Prompt + Response] --> E3[Expert 3 Only]
+ E3 -->|"Prefill with
single expert"| H[Expert 3 Hidden Activations]
+
+ F3 -->|Compare with| C{Match?}
+ H --> C
+ C -->|Yes| Valid[Valid]
+ C -->|No| Invalid[Spoofed]
+```
+
+This reduces the memory required for prefilling from x8 A100s to x1.
+
+### Phased Rollout
+
+For each model:
+1. Experimentally measure the False Positive Rate and False Negative Rate, store it in metadata.
+2. Flip the "perform verification" flag on
+3. Monitor the spoofer probabilities / results
+4. Turn on the "punitive" aspects of verification for that model
+
+### Potential Griefing Vector
+
+Theoretically, a spoofer could:
+1. Come up with whatever inference response they want: `prompt = "What are the benefits of cross-training?"`, `response = "blah blah blah"`
+2. Run a prefill on their fake response and collect the genuine fingerprint of the fake response: `fingerprint("What are the benefits of cross-training?" + "blah blah blah")`
+3. Return this fingerprint and their fake response, and not generate any new tokens.
+
+I'm aware that this is a griefing vector. It's a weakness present in the original TopLOC paper, and I'm not sure even its authors considered it.
+
+I've been thinking about this and how it might be patched up. Nothing has occurred to me yet, but I think there may be some kind of solution involving cryptographic commitments (something including an extra sha256 hash that proves that sampling actually occurred - I'm not quite sure yet).
+
+We can do some basic things to not make it easy. And the attacker would basically need to understand the TopLOC paper to even think of doing this.
+
+I'm not super worried because it's a clever workaround and wouldn't be discovered immediately, but I also wouldn't publicize our approach until we've had a chance to patch this up.
diff --git a/VERIFIER_PHASES.md b/VERIFIER_PHASES.md
new file mode 100644
index 00000000000..a4d9801d7cf
--- /dev/null
+++ b/VERIFIER_PHASES.md
@@ -0,0 +1,35 @@
+# Verification Rollout Phases
+
+## Phase 1 (MVP)
+Stand up a simple end-to-end verification system with a single verifier for `3.1-8b-instruct`.
+
+Components of System:
+* a single verification instance running fork of sglang inference engine with `3.1-8b-instruct`
+* apps/relay processing / handling of verification implementation up and running
+* dashboarding / health monitoring up and running
+* automatic calculation of operator "probability you're a spoofer" score based on verification results
+* Have a Zeet-hosted instance sending a mixture of spoofs and genuine requests.
+
+I'm fully anticipating we'll uncover problems along the way and will need to make refinements as we go.
+
+## Phase 2 (MVP -> P)
+
+Address shortcomings of the MVP implementation. These would include:
+* Switching to Token-ID based verification instead of using request and response prompt messages
+* Rolling out verification on a few more models on a trial basis
+* Attempting the MOE fingerprinting optimization to cut down on the cost of verifying models like Deepseek-v3
+* Fork VLLM and/or ollama
+* Switch testnet operators over to forked inference engines
+
+## Phase 3 (Decentralization)
+
+* Allow highly trusted operators to run verification instances
+* Create an "admin slash dashboard" (slashboard?) that lists operators, the likelihood they are a spoofer, and then allows admins to confirm and apply punitive measures (like delisting and/or slashing)
+
+## Other Tasks
+Unsure which phases these should be in:
+* Forking VLLM and/or ollama?
+* Collect error threshold measurements over batches, implement tooling to automate the process.
+* Run some additional experiments with EAGLE speculative decoding
+* Investigate and resolve a problem I noticed with CUDA graph
+* Gradual deployment of more models
diff --git a/docs/_static/image/toploc-diagram.png b/docs/_static/image/toploc-diagram.png
new file mode 100644
index 00000000000..50c79b82274
Binary files /dev/null and b/docs/_static/image/toploc-diagram.png differ
diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index c2e81eafe62..d4103b4bc47 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -153,7 +153,6 @@ Please consult the documentation below to learn more about the parameters you ma
* `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1).
* `speculative_token_map`: Optional, the path to the high frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used for accelerating [Eagle](https://arxiv.org/html/2406.16858v1).
-
## Double Sparsity
* `enable_double_sparsity`: Enables [double sparsity](https://arxiv.org/html/2408.07092v2) which increases throughput.
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 99511cf6509..8db36476201 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -39,6 +39,7 @@ runtime_common = [
"uvicorn",
"uvloop",
"xgrammar==0.1.16",
+ "toploc==0.1.4"
]
srt = [
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 981040d0dfd..c9787c8d939 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -54,6 +54,9 @@ class LogitsProcessorOutput:
# The last hidden layers
hidden_states: Optional[torch.Tensor] = None
+ # Captured hidden states used to generate and validate TopLOC "fingerprints", aka "verification fingerprints"
+ toploc_verification_hidden_states: Optional[torch.Tensor] = None
+
## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
# The logprobs of the next tokens. shape: [#seq]
next_token_logprobs: Optional[torch.Tensor] = None
@@ -79,7 +82,7 @@ class LogitsProcessorOutput:
class LogitsMetadata:
forward_mode: ForwardMode
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
-
+ toploc_verification: bool = False
extend_return_logprob: bool = False
extend_return_top_logprob: bool = False
extend_token_ids_logprob: bool = False
@@ -143,6 +146,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
return cls(
forward_mode=forward_batch.forward_mode,
capture_hidden_mode=forward_batch.capture_hidden_mode,
+ toploc_verification=forward_batch.toploc_verification,
extend_return_logprob=extend_return_logprob,
extend_return_top_logprob=extend_return_top_logprob,
extend_token_ids_logprob=extend_token_ids_logprob,
@@ -347,11 +351,20 @@ def forward(
else:
assert False, "Should never reach"
+ # If toploc is enabled, capture pruned hidden states
+ toploc_verification_hidden_states_to_store: Optional[torch.Tensor] = None
+ if logits_metadata.toploc_verification:
+ toploc_verification_hidden_states_to_store = (
+ pruned_states[sample_indices] if sample_indices else pruned_states
+ )
+
if not logits_metadata.extend_return_logprob:
# Decode mode or extend mode without return_logprob.
+
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
+ toploc_verification_hidden_states=toploc_verification_hidden_states_to_store,
)
else:
input_logprobs = logits[input_logprob_indices]
@@ -405,6 +418,7 @@ def forward(
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
hidden_states=hidden_states_to_store,
+ toploc_verification_hidden_states=toploc_verification_hidden_states_to_store,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index ed73263609b..50733489d21 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -230,6 +230,10 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOut):
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
output_hidden_states=recv_obj.output_hidden_states,
+ toploc_verification_fingerprints=recv_obj.toploc_verification_fingerprints,
+ toploc_verification_fingerprint_validation_results=recv_obj.toploc_verification_fingerprint_validation_results,
+ origin_input_ids=recv_obj.origin_input_ids,
+ output_token_ids=recv_obj.output_token_ids, # Copy output_token_ids from BatchTokenIDOut
)
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 0e1d5016524..6084103c7d4 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -73,6 +73,9 @@ class GenerateReqInput:
# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
+ # TopLOC Verification fingerprints to validate against activations
+ toploc_verification_fingerprint_to_validate: Optional[Union[List[str], str]] = None
+
# Custom logit processor for advanced sampling control. Must be a serialized instance
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
# Use the processor's `to_str()` method to generate the serialized string.
@@ -248,6 +251,11 @@ def __getitem__(self, i):
else None
),
return_hidden_states=self.return_hidden_states,
+ toploc_verification_fingerprint_to_validate=(
+ self.toploc_verification_fingerprint_to_validate[i]
+ if self.toploc_verification_fingerprint_to_validate is not None
+ else None
+ ),
)
@@ -290,6 +298,9 @@ class TokenizedGenerateReqInput:
# Whether to return hidden states
return_hidden_states: bool = False
+ # TopLOC Verification fingerprints to validate
+ toploc_verification_fingerprint_to_validate: Optional[str] = None
+
@dataclass
class EmbeddingReqInput:
@@ -425,9 +436,18 @@ class BatchTokenIDOut:
output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List]
+ # Input IDs / Output IDs
+ origin_input_ids: List[List[int]]
+ # The full sequence of output token IDs for each request
+ output_token_ids: List[List[int]]
+
# Hidden states
output_hidden_states: List[List[float]]
+ # TopLOC Verification fingerprints
+ toploc_verification_fingerprints: List[List]
+ toploc_verification_fingerprint_validation_results: List[Optional[str]] = None
+
@dataclass
class BatchMultimodalDecodeReq:
@@ -475,6 +495,15 @@ class BatchStrOut:
# Hidden states
output_hidden_states: List[List[float]]
+ # TopLOC Verification fingerprints
+ toploc_verification_fingerprints: List[List]
+ toploc_verification_fingerprint_validation_results: List[Optional[str]] = None
+
+ # Origin input ids (for return_input_ids=True)
+ origin_input_ids: Optional[List[List[int]]] = None
+ # Output token ids (for return_output_ids=True)
+ output_token_ids: Optional[List[List[int]]] = None
+
@dataclass
class BatchMultimodalOut:
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 3b8259bfc59..882509ead1e 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -75,6 +75,8 @@
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+ "toploc_verification_topk": ServerArgs.toploc_verification_topk,
+ "toploc_verification": ServerArgs.toploc_verification,
}
logger = logging.getLogger(__name__)
@@ -268,6 +270,7 @@ def __init__(
custom_logit_processor: Optional[str] = None,
return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
+ toploc_verification_fingerprint_to_validate: Optional[str] = None,
):
# Input and output info
self.rid = rid
@@ -309,6 +312,16 @@ def __init__(
self.stream = stream
self.eos_token_ids = eos_token_ids
+ # TopLOC generation (input)
+ self.toploc_verification_hidden_states = []
+ self.toploc_verification_fingerprints = []
+
+ # TopLOC verification (output)
+ self.toploc_verification_fingerprint_to_validate = (
+ toploc_verification_fingerprint_to_validate
+ )
+ self.toploc_verification_fingerprint_validation_result = None
+
# For incremental decoding
# ----- | --------- read_ids -------|
# ----- | surr_ids |
@@ -607,6 +620,9 @@ class ScheduleBatch:
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
+ # Verification
+ toploc_verification: bool = False
+
# Enable custom logit processor
enable_custom_logit_processor: bool = False
@@ -623,7 +639,8 @@ def init_new(
model_config: ModelConfig,
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
- enable_custom_logit_processor: bool,
+ toploc_verification: bool,
+ enable_custom_logit_processor: bool = False,
):
return_logprob = any(req.return_logprob for req in reqs)
@@ -639,6 +656,7 @@ def init_new(
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
+ toploc_verification=toploc_verification,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=any(req.return_hidden_states for req in reqs),
)
@@ -1313,6 +1331,25 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
global bid
bid += 1
+
+ capture_hidden_mode = (
+ CaptureHiddenMode.FULL
+ if self.return_hidden_states
+ else (
+ getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
+ if self.spec_info
+ else CaptureHiddenMode.NULL
+ )
+ )
+
+ # Ensure CAPTURE_HIDDEN_MODE is *at least* LAST if toploc verification is enabled
+ if self.toploc_verification and capture_hidden_mode == CaptureHiddenMode.NULL:
+ capture_hidden_mode = CaptureHiddenMode.LAST
+
+ toploc_verification_fingerprints_to_validate = [
+ r.toploc_verification_fingerprint_to_validate for r in self.reqs
+ ]
+
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
@@ -1342,18 +1379,10 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
input_embeds=self.input_embeds,
spec_algorithm=self.spec_algorithm,
spec_info=self.spec_info,
- capture_hidden_mode=(
- CaptureHiddenMode.FULL
- if self.return_hidden_states
- else (
- getattr(
- self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
- )
- if self.spec_info
- else CaptureHiddenMode.NULL
- )
- ),
+ toploc_verification=self.toploc_verification,
+ capture_hidden_mode=capture_hidden_mode,
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+ toploc_verification_fingerprints_to_validate=toploc_verification_fingerprints_to_validate,
)
def copy(self):
@@ -1366,6 +1395,7 @@ def copy(self):
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
+ toploc_verification=self.toploc_verification,
enable_custom_logit_processor=self.enable_custom_logit_processor,
)
@@ -1435,8 +1465,11 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+ toploc_verification: bool = False
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
+ # TopLOC Verification fingerprints to validate
+ toploc_verification_fingerprints_to_validate: Optional[List[str]] = None
@triton.jit
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index ef4ecaf90c6..4998dbb3780 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -162,6 +162,7 @@ def __init__(
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
+ self.toploc_verification = server_args.toploc_verification
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.page_size = server_args.page_size
@@ -667,6 +668,7 @@ def handle_generate_request(
custom_logit_processor=custom_logit_processor,
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
+ toploc_verification_fingerprint_to_validate=recv_req.toploc_verification_fingerprint_to_validate,
)
req.tokenizer = self.tokenizer
@@ -1135,6 +1137,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.model_config,
self.enable_overlap,
self.spec_algorithm,
+ self.toploc_verification,
self.server_args.enable_custom_logit_processor,
)
new_batch.prepare_for_extend()
@@ -1366,6 +1369,7 @@ def get_idle_batch(self):
self.model_config,
self.enable_overlap,
self.spec_algorithm,
+ self.toploc_verification,
self.server_args.enable_custom_logit_processor,
)
idle_batch.prepare_for_idle()
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 13158d93726..fc25576484e 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -1,10 +1,15 @@
from __future__ import annotations
+import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
+from sglang.srt.verification.toploc_verification_utils import (
+ create_toploc_fingerprint,
+ verify_toploc_fingerprint,
+)
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import (
@@ -13,6 +18,8 @@
ScheduleBatch,
)
+logger = logging.getLogger(__name__)
+
class SchedulerOutputProcessorMixin:
"""
@@ -58,6 +65,7 @@ def process_batch_result_prefill(
)
hidden_state_offset = 0
+ toploc_verification_hidden_state_offset = 0
# Check finish conditions
logprob_pt = 0
@@ -114,9 +122,31 @@ def process_batch_result_prefill(
.tolist()
)
- if req.grammar is not None:
- req.grammar.accept_token(next_token_id)
- req.grammar.finished = req.finished()
+ if logits_output.toploc_verification_hidden_states is not None:
+
+ # each item in toploc_verification_hidden_states can contain a mixture of sequences for prefill
+ # this fetches the last token in the sequence
+ toploc_verification_hidden_state = (
+ logits_output.toploc_verification_hidden_states[i]
+ .cpu()
+ .clone()
+ )
+ toploc_verification_hidden_state_offset += len(
+ req.origin_input_ids
+ )
+ req.toploc_verification_hidden_states.append(
+ toploc_verification_hidden_state
+ )
+ req.toploc_verification_fingerprints.append(
+ create_toploc_fingerprint(toploc_verification_hidden_state)
+ )
+ if req.toploc_verification_fingerprint_to_validate is not None:
+ req.toploc_verification_fingerprint_validation_result = (
+ verify_toploc_fingerprint(
+ toploc_verification_hidden_state,
+ req.toploc_verification_fingerprint_to_validate,
+ )
+ )
else:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
@@ -250,6 +280,30 @@ def process_batch_result_decode(
logits_output.hidden_states[i].cpu().clone().tolist()
)
+ if logits_output.toploc_verification_hidden_states is not None:
+ # each item in toploc_verification_hidden_states is [N,D] when N is number of tokens
+ if req.finished():
+ toploc_verification_hidden_state = (
+ logits_output.toploc_verification_hidden_states[i].cpu().clone()
+ )
+ req.toploc_verification_hidden_states.append(
+ toploc_verification_hidden_state
+ )
+ req.toploc_verification_fingerprints.append(
+ create_toploc_fingerprint(toploc_verification_hidden_state)
+ )
+ if req.toploc_verification_fingerprint_to_validate is not None:
+ req.toploc_verification_fingerprint_validation_result = (
+ verify_toploc_fingerprint(
+ toploc_verification_hidden_state[-1, ...], # last token
+ req.toploc_verification_fingerprint_to_validate,
+ )
+ )
+ else:
+ # No need to generate a fingerprint until the last decode step of the sequence
+ req.toploc_verification_hidden_states.append(None)
+ req.toploc_verification_fingerprints.append(None)
+
if req.grammar is not None and batch.spec_algorithm.is_none():
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
@@ -463,6 +517,10 @@ def stream_output_generation(
cached_tokens = []
spec_verify_ct = []
output_hidden_states = None
+ toploc_verification_fingerprints = None
+ toploc_verification_fingerprint_validation_results = None
+ origin_input_ids = []
+ output_token_ids = []
if return_logprob:
input_token_logprobs_val = []
@@ -558,10 +616,42 @@ def stream_output_generation(
output_hidden_states = []
output_hidden_states.append(req.hidden_states)
+ if toploc_verification_fingerprints is None:
+ toploc_verification_fingerprints = []
+ if (
+ hasattr(req, "toploc_verification_fingerprints")
+ and req.toploc_verification_fingerprints is not None
+ ):
+ toploc_verification_fingerprints.append(
+ req.toploc_verification_fingerprints
+ )
+
+ # Collect TopLOC verification fingerprint validation results
+ if toploc_verification_fingerprint_validation_results is None:
+ toploc_verification_fingerprint_validation_results = []
+ toploc_verification_fingerprint_validation_results.append(
+ req.toploc_verification_fingerprint_validation_result
+ )
+
+ if (
+ hasattr(req, "origin_input_ids")
+ and req.origin_input_ids is not None
+ ):
+ origin_input_ids.append(list(req.origin_input_ids))
+ else:
+ origin_input_ids.append([])
+
+ if hasattr(req, "output_ids") and req.output_ids is not None:
+ output_token_ids.append(list(req.output_ids))
+ else:
+ output_token_ids.append([])
+
# Send to detokenizer
if rids:
if self.model_config.is_multimodal_gen:
return
+
+ # Send to detokenizer
self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut(
rids,
@@ -589,7 +679,11 @@ def stream_output_generation(
input_token_ids_logprobs_idx,
output_token_ids_logprobs_val,
output_token_ids_logprobs_idx,
+ origin_input_ids,
+ output_token_ids,
output_hidden_states,
+ toploc_verification_fingerprints,
+ toploc_verification_fingerprint_validation_results,
)
)
diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py
index 9aa6e4c59df..62ceee2cdff 100644
--- a/python/sglang/srt/managers/session_controller.py
+++ b/python/sglang/srt/managers/session_controller.py
@@ -136,6 +136,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
return_logprob=req.return_logprob,
top_logprobs_num=req.top_logprobs_num,
token_ids_logprob=req.token_ids_logprob,
+ toploc_verification_fingerprint_to_validate=req.toploc_verification_fingerprint_to_validate,
)
if last_req is not None:
new_req.image_inputs = last_req.image_inputs
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index c211d76ff57..2f0b5857edd 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -357,6 +357,7 @@ async def _tokenize_one_request(
# Tokenize
input_embeds = None
input_text = obj.text
+
if obj.input_embeds is not None:
if not self.server_args.disable_radix_cache:
raise ValueError(
@@ -436,6 +437,7 @@ async def _tokenize_one_request(
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
return_hidden_states=obj.return_hidden_states,
+ toploc_verification_fingerprint_to_validate=obj.toploc_verification_fingerprint_to_validate,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
@@ -886,6 +888,10 @@ def _handle_batch_output(
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
],
):
+ # Check if recv_obj has origin_input_ids and output_token_ids
+ has_origin_input_ids = hasattr(recv_obj, "origin_input_ids")
+ has_output_token_ids = hasattr(recv_obj, "output_token_ids")
+
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
@@ -898,6 +904,22 @@ def _handle_batch_output(
"prompt_tokens": recv_obj.prompt_tokens[i],
}
+ # Add origin_input_ids to meta_info if available
+ if has_origin_input_ids and recv_obj.origin_input_ids is not None:
+ if (
+ i < len(recv_obj.origin_input_ids)
+ and recv_obj.origin_input_ids[i] is not None
+ ):
+ meta_info["origin_input_ids"] = recv_obj.origin_input_ids[i]
+
+ # Add output_token_ids to meta_info if available
+ if has_output_token_ids and recv_obj.output_token_ids is not None:
+ if (
+ i < len(recv_obj.output_token_ids)
+ and recv_obj.output_token_ids[i] is not None
+ ):
+ meta_info["output_token_ids"] = recv_obj.output_token_ids[i]
+
if getattr(state.obj, "return_logprob", False):
self.convert_logprob_style(
meta_info,
@@ -919,6 +941,47 @@ def _handle_batch_output(
if getattr(recv_obj, "output_hidden_states", None):
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
+ # Add toploc verification fingerprints to meta_info if they exist
+ if getattr(recv_obj, "toploc_verification_fingerprints", None) is not None:
+ try:
+ # Make sure toploc_verification_fingerprints is properly extracted and formatted
+ if i < len(recv_obj.toploc_verification_fingerprints):
+ fingerprints = recv_obj.toploc_verification_fingerprints[i]
+ meta_info["toploc_verification_fingerprints"] = fingerprints
+ else:
+ logger.warning(
+ f"toploc_verification_fingerprints index {i} out of range (len={len(recv_obj.toploc_verification_fingerprints)})"
+ )
+ except Exception as e:
+ logger.error(
+ f"Error processing toploc verification fingerprints: {e}"
+ )
+
+ # Add toploc verification fingerprint validation results to meta_info if they exist
+ if (
+ getattr(
+ recv_obj, "toploc_verification_fingerprint_validation_results", None
+ )
+ is not None
+ ):
+ try:
+ # Make sure toploc_verification_fingerprint_validation_results is properly extracted and formatted
+ if i < len(
+ recv_obj.toploc_verification_fingerprint_validation_results
+ ):
+ validation_result = (
+ recv_obj.toploc_verification_fingerprint_validation_results[
+ i
+ ]
+ )
+ meta_info[
+ "toploc_verification_fingerprint_validation_result"
+ ] = validation_result
+ except Exception as e:
+ logger.error(
+ f"Error processing toploc verification fingerprint validation results: {e}"
+ )
+
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 95a4dd6af69..3412aad278b 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -29,6 +29,7 @@
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.torchao_utils import save_gemlite_cache
+from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
@@ -41,6 +42,10 @@
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
+import logging
+
+logger = logging.getLogger(__name__)
+
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
@@ -237,6 +242,12 @@ def __init__(self, model_runner: ModelRunner):
(self.max_num_token, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
+ # TopLOC verification
+ elif self.model_runner.server_args.toploc_verification:
+ self.hidden_states = torch.zeros(
+ (self.max_num_token, self.model_runner.model_config.hidden_size),
+ dtype=self.model_runner.dtype,
+ )
if self.is_encoder_decoder:
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
@@ -392,6 +403,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
)
+ # During CUDA graph capture, ensure capture_hidden_mode is *at least* LAST if toploc verification is enabled
+ if (
+ self.capture_hidden_mode == CaptureHiddenMode.NULL
+ and self.model_runner.server_args.toploc_verification
+ ):
+ self.capture_hidden_mode = CaptureHiddenMode.LAST
+
forward_batch = ForwardBatch(
forward_mode=self.capture_forward_mode,
batch_size=bs,
@@ -447,21 +465,29 @@ def run_once():
return graph, out
def recapture_if_needed(self, forward_batch: ForwardBatch):
+
# If the capture_hidden_mode changes, we need to recapture the graph
hidden_mode_from_spec_info = getattr(
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
- if (
- forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
- and self.capture_hidden_mode != CaptureHiddenMode.FULL
- ):
- self.capture_hidden_mode = CaptureHiddenMode.FULL
- self.capture()
- elif (
- forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL
- and self.capture_hidden_mode != hidden_mode_from_spec_info
- ):
- self.capture_hidden_mode = hidden_mode_from_spec_info
+ capture_hidden_mode_priority = {
+ CaptureHiddenMode.NULL: 1,
+ CaptureHiddenMode.LAST: 2,
+ CaptureHiddenMode.FULL: 3,
+ }
+ max_priority_level = max(
+ capture_hidden_mode_priority[mode]
+ for mode in [forward_batch.capture_hidden_mode, hidden_mode_from_spec_info]
+ )
+ capture_hidden_mode_by_priority = {
+ v: k for k, v in capture_hidden_mode_priority.items()
+ }
+ highest_needed_capture_mode = capture_hidden_mode_by_priority[
+ max_priority_level
+ ]
+
+ if self.capture_hidden_mode != highest_needed_capture_mode:
+ self.capture_hidden_mode = highest_needed_capture_mode
self.capture()
def replay_prepare(self, forward_batch: ForwardBatch):
@@ -534,13 +560,15 @@ def replay(
self.graphs[self.bs].replay()
next_token_logits, hidden_states = self.output_buffers[self.bs]
+ hidden_states = (
+ hidden_states[: self.raw_num_token] if hidden_states is not None else None
+ )
+
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits[: self.raw_num_token],
- hidden_states=(
- hidden_states[: self.raw_num_token]
- if hidden_states is not None
- else None
- ),
+ hidden_states=hidden_states,
+ # Because CUDA_GRAPH only runs in DECODE mode, every n in N for [N,hidden_dimension] is a "last token"
+ toploc_verification_hidden_states=hidden_states,
)
return logits_output
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 80d1a447bbf..045d4820678 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -217,12 +217,18 @@ class ForwardBatch:
spec_algorithm: SpeculativeAlgorithm = None
capture_hidden_mode: CaptureHiddenMode = None
+ # Verification algorithm
+ toploc_verification: bool = False
+
# For padding
padded_static_len: int = -1 # -1 if not padded
# For Qwen2-VL
mrope_positions: torch.Tensor = None
+ # Verification fingerprint to validate
+ toploc_verification_fingerprints_to_validate: Optional[List[str]] = None
+
@classmethod
def init_new(
cls,
@@ -262,6 +268,8 @@ def init_new(
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
+ toploc_verification=batch.toploc_verification,
+ toploc_verification_fingerprints_to_validate=batch.toploc_verification_fingerprints_to_validate,
)
# For DP attention
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index eaaf2637ff3..24ea4c1b84e 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -969,7 +969,6 @@ def forward(
return self.cuda_graph_runner.replay(
forward_batch, skip_attn_backend_init=skip_attn_backend_init
)
-
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index a9f1124ac95..9b4f5bb666e 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -704,7 +704,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
total_tokens=prompt_tokens + completion_tokens,
),
)
- return response
+ return response
async def v1_completions(tokenizer_manager, raw_request: Request):
@@ -876,6 +876,7 @@ def v1_chat_generate_request(
top_logprobs_nums = []
modalities_list = []
lora_paths = []
+ toploc_verification_fingerprints_to_validate = []
# NOTE: with openai API, the prompt's logprobs are always not computed
@@ -972,6 +973,9 @@ def v1_chat_generate_request(
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs or 0)
lora_paths.append(request.lora_path)
+ toploc_verification_fingerprints_to_validate.append(
+ request.toploc_verification_fingerprint_to_validate
+ )
sampling_params = {
"temperature": request.temperature,
@@ -1019,6 +1023,9 @@ def v1_chat_generate_request(
top_logprobs_nums = top_logprobs_nums[0]
modalities_list = modalities_list[0]
lora_paths = lora_paths[0]
+ toploc_verification_fingerprints_to_validate = (
+ toploc_verification_fingerprints_to_validate[0]
+ )
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
@@ -1037,6 +1044,7 @@ def v1_chat_generate_request(
rid=request_ids,
modalities=modalities_list,
lora_path=lora_paths,
+ toploc_verification_fingerprint_to_validate=toploc_verification_fingerprints_to_validate,
)
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1147,6 +1155,30 @@ def v1_chat_generate_response(
"Failed to parse fc related info to json format!",
)
+ toploc_verification_fingerprints = ret_item["meta_info"].get(
+ "toploc_verification_fingerprints", None
+ )
+ if toploc_verification_fingerprints:
+ toploc_verification_fingerprints = (
+ [toploc_verification_fingerprints]
+ if not isinstance(toploc_verification_fingerprints, list)
+ else toploc_verification_fingerprints
+ )
+
+ # Extract verification fingerprints if available
+ if toploc_verification_fingerprints:
+ toploc_verification_fingerprints = filter(
+ lambda x: x is not None, toploc_verification_fingerprints
+ )
+ else:
+ toploc_verification_fingerprints = None
+
+ toploc_verification_fingerprint_validation_result = ret_item["meta_info"].get(
+ "toploc_verification_fingerprint_validation_result", None
+ )
+ if not toploc_verification_fingerprint_validation_result:
+ toploc_verification_fingerprint_validation_result = None
+
if to_file:
# to make the choice data json serializable
choice_data = {
@@ -1156,6 +1188,8 @@ def v1_chat_generate_response(
"content": text if text else None,
"tool_calls": tool_calls,
"reasoning_content": reasoning_text if reasoning_text else None,
+ "toploc_verification_fingerprints": toploc_verification_fingerprints,
+ "toploc_verification_fingerprint_validation_result": toploc_verification_fingerprint_validation_result,
},
"logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
"finish_reason": (finish_reason["type"] if finish_reason else ""),
@@ -1173,6 +1207,8 @@ def v1_chat_generate_response(
content=text if text else None,
tool_calls=tool_calls,
reasoning_content=reasoning_text if reasoning_text else None,
+ toploc_verification_fingerprints=toploc_verification_fingerprints,
+ toploc_verification_fingerprint_validation_result=toploc_verification_fingerprint_validation_result,
),
logprobs=choice_logprobs,
finish_reason=(finish_reason["type"] if finish_reason else ""),
@@ -1216,6 +1252,21 @@ def v1_chat_generate_response(
)
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
cached_tokens = sum(item["meta_info"].get("cached_tokens", 0) for item in ret)
+ input_ids = None
+ output_ids = None
+
+ if ret and "origin_input_ids" in ret[0]["meta_info"]:
+ # Set input_ids from the origin_input_ids in meta_info
+ if hasattr(request, "return_input_ids") and request.return_input_ids:
+ input_ids = ret[0]["meta_info"]["origin_input_ids"]
+ else:
+ pass
+
+ if ret and "output_token_ids" in ret[0]["meta_info"]:
+ # Set output_ids from the output_token_ids in meta_info
+ if hasattr(request, "return_output_ids") and request.return_output_ids:
+ output_ids = ret[0]["meta_info"]["output_token_ids"]
+
response = ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
@@ -1228,6 +1279,8 @@ def v1_chat_generate_response(
{"cached_tokens": cached_tokens} if cache_report else None
),
),
+ input_ids=input_ids,
+ output_ids=output_ids, # Add output_ids to the response
)
return response
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 767a77abc5c..fe8dce22507 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -348,6 +348,11 @@ def set_tool_choice_default(cls, values):
separate_reasoning: bool = True
stream_reasoning: bool = True
+ # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+ toploc_verification_fingerprint_to_validate: Optional[str] = None
+ return_input_ids: bool = False
+ return_output_ids: bool = False
+
class FunctionResponse(BaseModel):
"""Function response."""
@@ -369,6 +374,8 @@ class ChatMessage(BaseModel):
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+ toploc_verification_fingerprints: Optional[List] = None
+ toploc_verification_fingerprint_validation_result: Optional[str] = None
class ChatCompletionResponseChoice(BaseModel):
@@ -386,6 +393,8 @@ class ChatCompletionResponse(BaseModel):
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
+ input_ids: Optional[List[int]] = None
+ output_ids: Optional[List[int]] = None
class DeltaMessage(BaseModel):
@@ -393,6 +402,8 @@ class DeltaMessage(BaseModel):
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+ toploc_verification_fingerprints: Optional[List] = None
+ toploc_verification_fingerprint_validation_result: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5c1584c8cd7..ab7b566a51b 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -185,6 +185,9 @@ class ServerArgs:
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
+ toploc_verification: bool = False
+ toploc_verification_topk: Optional[int] = 128
+
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
@@ -1063,6 +1066,18 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Inject the outputs from jax as the input of every layer.",
)
+ parser.add_argument(
+ "--toploc-verification",
+ action="store_true",
+ help="Enable features relating to toploc verification",
+ )
+ parser.add_argument(
+ "--toploc-verification-topk",
+ type=int,
+ default=128,
+ help="Top-k for TopLoc verification",
+ )
+
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
@@ -1118,6 +1133,7 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv)
+ print(raw_args)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args
diff --git a/python/sglang/srt/verification/toploc_verification_utils.py b/python/sglang/srt/verification/toploc_verification_utils.py
new file mode 100644
index 00000000000..58f0d53d2fb
--- /dev/null
+++ b/python/sglang/srt/verification/toploc_verification_utils.py
@@ -0,0 +1,92 @@
+import dataclasses
+import json
+import logging
+from typing import List, Optional
+
+import torch
+from toploc import build_proofs_base64, verify_proofs_base64
+
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+logger = logging.getLogger(__name__)
+
+
+def verify_toploc_fingerprint(
+ verification_hidden_state: torch.Tensor, verification_fingerprint: str
+) -> Optional[str]:
+ """
+ Verify a TopLoc fingerprint against the provided hidden state.
+
+ Args:
+ verification_hidden_state: Hidden state tensor to verify against
+ verification_fingerprint: Base64 encoded verification fingerprint string
+
+ Returns:
+ JSON string with exp_mismatches, mant_err_mean and mant_err_median of the first result, or None if any step raised (the error is logged)
+ """
+ try:
+ # topk falls back to 128 when the server args dict has no "toploc_verification_topk" entry
+ topk = global_server_args_dict.get("toploc_verification_topk", 128)
+
+ results = verify_proofs_base64(
+ [verification_hidden_state],
+ [verification_fingerprint],
+ decode_batching_size=1,
+ topk=topk,
+ skip_prefill=False,
+ )
+
+ if not results or len(results) == 0:
+ raise Exception(
+ "No verification results returned from verify_fingerprints_base64"
+ )
+
+ validation_result = results[0]
+
+ return json.dumps(
+ {
+ "exp_mismatches": validation_result.exp_mismatches,
+ "mant_err_mean": validation_result.mant_err_mean,
+ "mant_err_median": validation_result.mant_err_median,
+ }
+ )
+ except Exception as e:
+ error_msg = f"Error verifying TopLoc fingerprint: {str(e)}"
+ logger.error(error_msg)
+ return None
+
+
+def create_toploc_fingerprint(
+ verification_hidden_state: Optional[torch.Tensor],
+) -> Optional[str]:
+ """
+ Build a TopLoc fingerprint string from a hidden state tensor via build_proofs_base64.
+
+ Args:
+ verification_hidden_state: Hidden state tensor to fingerprint; None is treated as an error
+
+ Returns:
+ Element [0] of the build_proofs_base64 result, or None if the input was None or any step raised (the error is logged)
+ """
+
+ # Only element [0] of the build_proofs_base64 result is returned to the caller
+ try:
+
+ if verification_hidden_state is None:
+ raise Exception(
+ "Attempted to create TopLoc fingerprints with None verification_hidden_state"
+ )
+
+ topk = global_server_args_dict.get("toploc_verification_topk", 128)
+
+ fingerprint = build_proofs_base64(
+ [verification_hidden_state],
+ decode_batching_size=1,
+ topk=topk,
+ skip_prefill=False,
+ )[0]
+
+ return fingerprint
+ except Exception as e:
+ logger.error(f"Error generating TopLoc fingerprints: {str(e)}")
+ return None
diff --git a/toploc-scripts/.gitignore b/toploc-scripts/.gitignore
new file mode 100644
index 00000000000..8d47e27eed1
--- /dev/null
+++ b/toploc-scripts/.gitignore
@@ -0,0 +1,17 @@
+.env
+ultrachat/**/**.*
+fingerprints/**/**.*
+verifications/**/**.*
+replications/**/**.*
+replications/**/**.*
+inferences_to_replicate/**/**.*
+inferences_to_replicate_temp_1/**/**.*
+inferences_to_replicate_temp_0_bak/**/**.*
+inferences_to_replicate_temp_0/**/**.*
+
+*.png
+*.html
+classifier_analysis_results/
+response_logprobs/
+**/**.ipynb
+**/cache.json
diff --git a/toploc-scripts/classifier_analysis_scripts/compute_replications_results_table.py b/toploc-scripts/classifier_analysis_scripts/compute_replications_results_table.py
new file mode 100644
index 00000000000..5350aa2a73c
--- /dev/null
+++ b/toploc-scripts/classifier_analysis_scripts/compute_replications_results_table.py
@@ -0,0 +1,159 @@
+import glob
+import json
+import os
+
+import pandas as pd
+from tabulate import tabulate
+
+
+def extract_model_name(full_name):
+ """Return the text after the last "/" in full_name, or full_name unchanged when it contains no "/"."""
+ if "/" in full_name:
+ return full_name.split("/")[-1]
+ return full_name
+
+
+def parse_replication_file(filepath):
+ """Parse one replication JSON file into dicts with replication_key, inference_key, passed and count; returns [] if any exception occurs."""
+ try:
+ with open(filepath, "r") as f:
+ replications = json.load(f)
+
+ results = []
+ print(f"Processing file: {os.path.basename(filepath)}")
+
+ for item in replications:
+ # Skip items with errors
+ if "error" in item:
+ continue
+
+ # Get replication machine
+ replication_machine = item["replication_machine"]
+
+ # Get inference machine
+ inference_machine = item["inference_machine"]
+
+ # Extract model names
+ replication_model = extract_model_name(item["replication_request"]["model"])
+ inference_model = extract_model_name(item["original_request"]["model"])
+
+ # Compare responses - exact match on the extracted text content (not on token ids)
+ original_response = item["original_response"]
+ replication_response = item["replication_response"]
+
+ # Extract the actual content from responses
+ original_content = ""
+ replication_content = ""
+
+ # Handle different response formats
+ if isinstance(original_response, dict) and "choices" in original_response:
+ if len(original_response["choices"]) > 0:
+ choice = original_response["choices"][0]
+ if "message" in choice and "content" in choice["message"]:
+ original_content = choice["message"]["content"]
+ elif "text" in choice:
+ original_content = choice["text"]
+
+ if (
+ isinstance(replication_response, dict)
+ and "choices" in replication_response
+ ):
+ if len(replication_response["choices"]) > 0:
+ choice = replication_response["choices"][0]
+ if "message" in choice and "content" in choice["message"]:
+ replication_content = choice["message"]["content"]
+ elif "text" in choice:
+ replication_content = choice["text"]
+
+ # Check if responses match exactly (two responses with no extractable content both stay "" and count as a match)
+ passed = original_content == replication_content
+
+ # Create keys that uniquely identify replication and original setups
+ replication_key = f"{replication_model}_{replication_machine}"
+ inference_key = f"{inference_model}_{inference_machine}"
+
+ results.append(
+ {
+ "replication_key": replication_key,
+ "inference_key": inference_key,
+ "passed": 1 if passed else 0,
+ "count": 1,
+ }
+ )
+
+ print(replication_key, inference_key, "PASS" if passed else "FAIL")
+
+ return results
+ except Exception as e:
+ print(f"Error parsing file {filepath}: {e}")
+ return []
+
+
+def compute_replication_matrix():
+ """Return (raw results DataFrame, pass-rate matrix as "%" strings, test-count matrix), or (None, None, None) when no results are found."""
+ # Get all replication files: every entry in the "replications" directory beside this script's directory
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ root_dir = os.path.dirname(script_dir)
+ replication_dir = os.path.join(root_dir, "replications")
+ replication_files = glob.glob(os.path.join(replication_dir, "*"))
+
+ if not replication_files:
+ print(f"No replication files found in {replication_dir}")
+ return None, None, None
+
+ print(f"Found {len(replication_files)} replication files")
+
+ # Parse all replication files
+ all_results = []
+ for file in replication_files:
+ results = parse_replication_file(file)
+ all_results.extend(results)
+
+ if not all_results:
+ print("No valid replication results found")
+ return None, None, None
+
+ df = pd.DataFrame(all_results)
+
+ # Create grouped dataframes for the matrix view
+ # Compute the percentage of successful replications; cells become strings ending in "%" (pairs with no data render as "nan%")
+ df_grouped = (
+ df.pivot_table(
+ index="replication_key",
+ columns="inference_key",
+ values="passed",
+ aggfunc="mean",
+ )
+ * 100
+ ).astype(str) + "%"
+
+ # Count total number of tests per configuration pair
+ df_count = df.pivot_table(
+ index="replication_key", columns="inference_key", values="count", aggfunc="sum"
+ )
+
+ return df, df_grouped, df_count
+
+
+def main():
+ # Compute the replication matrix (all three values are None when no results were found)
+ df_raw, df_grouped, df_count = compute_replication_matrix()
+
+ if df_raw is not None:
+ # Print success rates. NOTE(review): cells look like "nan%", so replace("nan", "--") likely never matches - confirm
+ print("\nReplication Success Rates (% passed):")
+ print(
+ tabulate(df_grouped.replace("nan", "--"), headers="keys", tablefmt="grid")
+ )
+
+ print("\nNumber of tests per configuration pair:")
+ print(tabulate(df_count, headers="keys", tablefmt="grid"))
+
+ # Compute summary statistics. NOTE(review): these values are computed but never printed or returned
+ total_tests = df_raw["count"].sum()
+ total_passed = df_raw["passed"].sum()
+ overall_pass_rate = (total_passed / total_tests) * 100 if total_tests > 0 else 0
+
+
+if __name__ == "__main__":
+ main()
diff --git a/toploc-scripts/classifier_analysis_scripts/compute_verification_results_table.py b/toploc-scripts/classifier_analysis_scripts/compute_verification_results_table.py
new file mode 100644
index 00000000000..185186276c8
--- /dev/null
+++ b/toploc-scripts/classifier_analysis_scripts/compute_verification_results_table.py
@@ -0,0 +1,173 @@
+import glob
+import json
+import os
+import re
+
+import numpy as np
+import pandas as pd
+from tabulate import tabulate
+
+# Define error thresholds for considering verification successful
+# These can be adjusted as needed
+# Each value is an inclusive upper bound: parse_verification_file() marks an
+# item as passed only when all three metrics are <= their threshold.
+ERROR_THRESHOLDS = {
+ "exp_mismatches": 90, # Maximum number of exponent mismatches allowed
+ "mant_err_mean": 10, # Maximum mean mantissa error allowed
+ "mant_err_median": 8, # Maximum median mantissa error allowed
+}
+
+
+def extract_model_name(full_name):
+    """Return the last "/"-separated component of a model identifier.
+
+    A name without any "/" is returned unchanged.
+    """
+    # Splitting once from the right keeps only the trailing path component;
+    # when no separator exists the single resulting element is the input itself.
+    return full_name.rsplit("/", 1)[-1]
+
+
+def parse_verification_file(filepath):
+    """Parse a verification file into one pass/fail record per usable item.
+
+    The file is expected to hold a JSON list of items. Items carrying an
+    "error" key are skipped. For every other item the "verification_result"
+    value (a JSON string, or an already-decoded dict) is compared against
+    ERROR_THRESHOLDS; the item passes only if all three metrics are within
+    their limits.
+
+    Args:
+        filepath: Path of the verification file to read.
+
+    Returns:
+        A list of dicts with keys "verification_key", "inference_key",
+        "passed" (1 or 0), "count" (always 1), "exp_check", "mean_check" and
+        "median_check". An unreadable or non-list file yields ``[]``; a single
+        malformed item is reported and skipped without discarding the rest.
+    """
+    try:
+        with open(filepath, "r") as f:
+            verifications = json.load(f)
+    except (OSError, ValueError) as e:
+        # json.JSONDecodeError is a ValueError subclass.
+        print(f"Error parsing file {filepath}: {e}")
+        return []
+
+    if not isinstance(verifications, list):
+        print(f"Error parsing file {filepath}: expected a JSON list")
+        return []
+
+    results = []
+    print(f"Processing file: {os.path.basename(filepath)}")
+
+    # For debug, print first verification result. The first item may be an
+    # error record without that key, which must not abort the whole file.
+    if verifications:
+        first = verifications[0]
+        if isinstance(first, dict) and "verification_result" in first:
+            print(f"Sample verification result: {first['verification_result']}")
+
+    for item in verifications:
+        try:
+            # Skip items with errors
+            if "error" in item:
+                continue
+
+            # Get verification model and machine
+            verification_model = item["verification_model"]
+            verification_machine = item["verification_machine"]
+
+            # Get inference model and machine
+            inference_machine = item["original_machine"]
+            inference_model = item["original_model"]
+
+            # Parse verification result
+            verification_result = item["verification_result"]
+            if isinstance(verification_result, str):
+                verification_result = json.loads(verification_result)
+
+            exp_check = (
+                verification_result["exp_mismatches"]
+                <= ERROR_THRESHOLDS["exp_mismatches"]
+            )
+            mean_check = (
+                verification_result["mant_err_mean"]
+                <= ERROR_THRESHOLDS["mant_err_mean"]
+            )
+            median_check = (
+                verification_result["mant_err_median"]
+                <= ERROR_THRESHOLDS["mant_err_median"]
+            )
+        except (KeyError, TypeError, ValueError) as e:
+            print(f"Skipping malformed item in {filepath}: {e!r}")
+            continue
+
+        passed = exp_check and mean_check and median_check
+
+        # Create keys that uniquely identify verification and inference setups
+        verification_key = (
+            f"{extract_model_name(verification_model)}_{verification_machine}"
+        )
+        inference_key = f"{extract_model_name(inference_model)}_{inference_machine}"
+
+        results.append(
+            {
+                "verification_key": verification_key,
+                "inference_key": inference_key,
+                "passed": 1 if passed else 0,
+                "count": 1,
+                "exp_check": exp_check,
+                "mean_check": mean_check,
+                "median_check": median_check,
+            }
+        )
+
+        print(verification_key, inference_key, passed)
+
+    return results
+
+
+def compute_verification_matrix():
+    """Compute the verification success matrix from the measured results.
+
+    Reads every ``*.verification`` file in the ``verifications`` directory
+    that sits next to this script's parent directory, parses them with
+    parse_verification_file(), and pivots the rows by verification setup
+    (index) and inference setup (columns).
+
+    Returns:
+        A tuple ``(df, df_grouped, df_count)``:
+        ``df`` is the raw per-item DataFrame, ``df_grouped`` holds the pass
+        percentage per pair as strings ending in "%", and ``df_count`` holds
+        the number of items per pair. When there are no files or no valid
+        results, ``(None, None, None)`` is returned so callers can always
+        unpack three values.
+    """
+    # Get all verification files
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    root_dir = os.path.dirname(script_dir)
+    verification_dir = os.path.join(root_dir, "verifications")
+    verification_files = glob.glob(os.path.join(verification_dir, "*.verification"))
+
+    if not verification_files:
+        print(f"No verification files found in {verification_dir}")
+        return None, None, None
+
+    print(f"Found {len(verification_files)} verification files")
+
+    # Parse all verification files
+    all_results = []
+    for file in verification_files:
+        all_results.extend(parse_verification_file(file))
+
+    if not all_results:
+        print("No valid verification results found")
+        return None, None, None
+
+    df = pd.DataFrame(all_results)
+
+    # Debug: Print raw results for each verification-inference pair
+    print("\n=== DEBUG INFO ===")
+    for (vkey, ikey), group in df.groupby(["verification_key", "inference_key"]):
+        pass_count = group["passed"].sum()
+        total_count = len(group)
+        pass_rate = pass_count / total_count if total_count > 0 else 0
+        print(f"{vkey} -> {ikey}: {pass_count}/{total_count} passed ({pass_rate:.2%})")
+
+        # Pairs whose key prefixes (text before the first "_") match are
+        # expected to pass always, all others never. The measured values are
+        # deliberately left untouched: deviations are only reported, so the
+        # matrix below shows what was actually observed.
+        expected_pass_rate = 1.0 if vkey.split("_")[0] == ikey.split("_")[0] else 0.0
+        if pass_rate != expected_pass_rate:
+            print(
+                f"  NOTE: expected {expected_pass_rate:.0%} for this pair, "
+                f"measured {pass_rate:.2%}"
+            )
+    print("=== END DEBUG ===\n")
+
+    # Create the matrix from the measured values
+    df_grouped = (
+        df.pivot_table(
+            index="verification_key",
+            columns="inference_key",
+            values="passed",
+            aggfunc="mean",
+        )
+        * 100
+    ).astype(str) + "%"
+    df_count = df.pivot_table(
+        index="verification_key", columns="inference_key", values="count", aggfunc="sum"
+    )
+
+    return df, df_grouped, df_count
+
+
+def main():
+    """Print the active thresholds, then the verification success matrix and counts."""
+    # Print the thresholds being used
+    print("Using error thresholds:")
+    for metric, threshold in ERROR_THRESHOLDS.items():
+        print(f" {metric}: {threshold}")
+
+    # Compute the verification matrix. Tolerate both a bare None and a tuple
+    # of None placeholders so that "no data" never crashes the unpacking.
+    result = compute_verification_matrix()
+    if result is None:
+        return
+    df_raw, df_grouped, df_count = result
+    if df_raw is None:
+        return
+
+    print("\nVerification Success Rates (% passed):")
+    print(tabulate(df_grouped, headers="keys", tablefmt="grid"))
+
+    print("\nNumber of tests per configuration pair:")
+    print(tabulate(df_count, headers="keys", tablefmt="grid"))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/toploc-scripts/classifier_analysis_scripts/embedding_tests.ipynb b/toploc-scripts/classifier_analysis_scripts/embedding_tests.ipynb
new file mode 100644
index 00000000000..eb525b7b900
--- /dev/null
+++ b/toploc-scripts/classifier_analysis_scripts/embedding_tests.ipynb
@@ -0,0 +1,149 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ALGO = \"cosine/all-MiniLM-L6-v2\"\n",
+ "# ALGO = \"cosine/all-mpnet-base-v2\"\n",
+ "# ALGO = \"cosine/distiluse-base-multilingual-cased-v1\"\n",
+ "# ALGO = \"cross-encoder/ms-marco-MiniLM-L12-v2\"\n",
+ "# ALGO = \"cross-encoder/stsb-roberta-large\"\n",
+ "# ALGO = \"dot/msmarco-bert-base-dot-v5\"\n",
+ "ALGO = \"nomic-ai/passage;passage;nomic-embed-text-v2-moe\"\n",
+ "# ALGO = \"gemini/gemini-embedding-exp-03-07\"\n",
+ "# ALGO = \"openai/text-embedding-3-large\"\n",
+ "\n",
+ "# os.system(\"pip install --upgrade google-genai\")\n",
+ "# os.system(\"pip install -U cohere\")\n",
+ "\n",
+ "TEST = \"DETECT_QUANTIZATION_SPOOF\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os, json\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "from IPython.display import Markdown, display_markdown\n",
+ "from perform_embedding_based_replication_v1 import (\n",
+ " compute_scores,\n",
+ " make_sim_callback,\n",
+ " create_data_subsets,\n",
+ " plot_roc_curve,\n",
+ " plot_score_histogram,\n",
+ " calculate_roc_metrics,\n",
+ " write_summary,\n",
+ " make_batch_sim_callback,\n",
+ " load_cache,\n",
+ " dump_cache,\n",
+ ")\n",
+ "import numpy as np\n",
+ "\n",
+ "from argparse import Namespace\n",
+ "\n",
+ "load_cache()\n",
+ "N = None\n",
+ "sim_callback = make_sim_callback(ALGO)\n",
+ "batch_sim_callback = make_batch_sim_callback(ALGO)\n",
+ "\n",
+ "try:\n",
+ " df = []\n",
+ " replications_dir = os.path.join(\"..\", \"replications\")\n",
+ " for filename in os.listdir(replications_dir):\n",
+ " filepath = os.path.join(replications_dir, filename)\n",
+ " data = json.load(open(filepath, \"r\"))\n",
+ " compute_scores(N, sim_callback, batch_sim_callback, data, df)\n",
+ " df = pd.DataFrame(df)\n",
+ " df[\"count\"] = 1\n",
+ "except Exception as e:\n",
+ " print(e)\n",
+ " pass\n",
+ "dump_cache()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# \"meta-llama/Llama-3.1-8B-Instruct\" \"context-labs/neuralmagic-llama-3.1-8b-instruct-FP8\" \"meta-llama/Llama-3.2-3B-Instruct\"\n",
+ "\n",
+ "_3_1_8B_FP16 = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
+ "_3_1_8B_FP8 = \"context-labs/neuralmagic-llama-3.1-8b-instruct-FP8\"\n",
+ "_3_2_3B_MODEL = \"meta-llama/Llama-3.2-3B-Instruct\"\n",
+ "\n",
+ "DIFFERENT_MACHINE_TEST = (df[\"inference_machine\"] == \"4090\") & (\n",
+ " df[\"replication_machine\"] == \"3090\"\n",
+ ") #\n",
+ "DIFFERENT_QUANTIZATION_TEST = df[\"original_model\"].isin(\n",
+ " [_3_1_8B_FP16, _3_1_8B_FP8]\n",
+ ") & df[\"replication_model\"].isin([_3_1_8B_FP16])\n",
+ "DIFFERENT_MODEL_TEST = (df[\"original_model\"].isin([_3_1_8B_FP16, _3_2_3B_MODEL])) & (\n",
+ " df[\"replication_model\"].isin([_3_1_8B_FP16])\n",
+ ")\n",
+ "\n",
+ "if TEST == \"DETECT_QUANTIZATION_SPOOF\":\n",
+ " name = \"Detect Quantization Spoofing\"\n",
+ " subset_df = df[DIFFERENT_MACHINE_TEST & DIFFERENT_QUANTIZATION_TEST]\n",
+ "elif TEST == \"DETECT_MODEL_SPOOF\":\n",
+ " name = \"Detect Model Spoofing\"\n",
+ " subset_df = df[DIFFERENT_MACHINE_TEST & DIFFERENT_MODEL_TEST]\n",
+ "else:\n",
+ " raise Exception(f\"Unknown TEST: {TEST}\")\n",
+ "\n",
+ "\n",
+ "import importlib\n",
+ "import perform_embedding_based_replication_v1 as mylib\n",
+ "\n",
+ "importlib.reload(mylib)\n",
+ "\n",
+ "display(Markdown(f\"{name} / {len(subset_df)}\"))\n",
+ "score_column = \"similarity\"\n",
+ "\n",
+ "selected_threshold = None\n",
+ "\n",
+ "fig, ax = plt.subplots()\n",
+ "mylib.plot_score_histogram(\n",
+ " subset_df, score_column, name, ax=ax, selected_threshold=selected_threshold\n",
+ ")\n",
+ "plt.show()\n",
+ "\n",
+ "# mylib.plot_roc_curve(df, score_column)\n",
+ "# plt.show()\n",
+ "\n",
+ "summary = mylib.write_summary(subset_df, score_column)\n",
+ "print(summary)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/toploc-scripts/classifier_analysis_scripts/perform_embedding_based_replication_v1.py b/toploc-scripts/classifier_analysis_scripts/perform_embedding_based_replication_v1.py
new file mode 100644
index 00000000000..4c6df1e9140
--- /dev/null
+++ b/toploc-scripts/classifier_analysis_scripts/perform_embedding_based_replication_v1.py
@@ -0,0 +1,561 @@
+import argparse
+import hashlib
+import json
+import os
+import time
+from typing import Any, Dict, List
+
+import numpy as np
+import openai
+import pandas as pd
+import seaborn as sns
+import torch
+import torch.nn as nn
+from dotenv import load_dotenv
+from google import genai
+from google.genai import types
+from matplotlib import pyplot as plt
+from sentence_transformers import CrossEncoder, SentenceTransformer
+
+# from sklearn.metrics.pairwise import cosine_similarity
+from sentence_transformers.util import dot_score, pairwise_cos_sim
+from sklearn.metrics import (
+ ConfusionMatrixDisplay,
+ RocCurveDisplay,
+ confusion_matrix,
+ roc_auc_score,
+ roc_curve,
+)
+from tabulate import tabulate
+from tqdm import tqdm
+
+# Load variables from a .env file (if any) before the API keys are read below.
+load_dotenv()
+
+# Number of response pairs handed to a batch similarity callback per call.
+BATCH_SIZE = 5
+# API keys are read from the environment; either may be None when unset.
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+
+# ROOT_DIR is the parent of the directory containing this script; analysis
+# artefacts are written below ROOT_DIR/classifier_analysis_results/embeddings_v1.
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ROOT_DIR = os.path.dirname(SCRIPT_DIR)
+OUTPUT_DIR = os.path.abspath(
+ os.path.join(ROOT_DIR, "classifier_analysis_results", "embeddings_v1")
+)
+
+
+def parse_args():
+    """Parse the command line.
+
+    Options:
+        --N: optional integer, defaults to None.
+        --sim-method: required string, exposed as ``sim_method``.
+    """
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--N", type=int, required=False, default=None)
+    cli.add_argument("--sim-method", type=str, required=True)
+    namespace = cli.parse_args()
+    return namespace
+
+
+def hashlib_hash(args):
+    """Return the hexadecimal SHA-256 digest of ``str(args)``."""
+    hasher = hashlib.sha256()
+    hasher.update(str(args).encode())
+    return hasher.hexdigest()
+
+
+def _first_message_content(response):
+    """Return the message text of the first choice in a chat-completion style dict."""
+    return response["choices"][0]["message"]["content"]
+
+
+def _score_row(item, orig_response, repl_response, similarity):
+    """Build the result row for one replication item and its similarity score."""
+    orig_model = item["original_request"]["model"]
+    repl_model = item["replication_request"]["model"]
+    # A pair counts as genuine when both requests named the same model.
+    genuine = orig_model == repl_model
+    return {
+        "prompt": item["prompt"],
+        "original_response": orig_response,
+        "replication_response": repl_response,
+        "similarity": similarity,
+        "genuine": genuine,
+        "label": "Genuine" if genuine else "Spoof",
+        "inference_machine": item["inference_machine"],
+        "replication_machine": item["replication_machine"],
+        "original_model": orig_model,
+        "replication_model": repl_model,
+    }
+
+
+def compute_scores(N, sim_callback, batch_sim_callback, data, df):
+    """Score original/replication response pairs and append one row per pair to ``df``.
+
+    Args:
+        N: Maximum number of items of ``data`` to score; a falsy value
+            (``None`` or 0) means all of them.
+        sim_callback: ``f(orig_text, repl_text) -> similarity``; used when no
+            batch callback is given.
+        batch_sim_callback: Optional ``f(orig_texts, repl_texts) -> similarities``
+            called with at most BATCH_SIZE pairs at a time. Its results are
+            matched to the pairs by position.
+        data: List of replication items (dicts) as loaded from a replication file.
+        df: List that receives the result rows; it is mutated in place.
+
+    Returns:
+        None.
+    """
+    N = N or len(data)
+    items = data[:N]
+
+    if batch_sim_callback is None:
+        for item in tqdm(items):
+            orig_response = _first_message_content(item["original_response"])
+            repl_response = _first_message_content(item["replication_response"])
+            similarity = sim_callback(orig_response, repl_response)
+            df.append(_score_row(item, orig_response, repl_response, similarity))
+        return
+
+    # Walk the items in consecutive chunks of BATCH_SIZE (the last chunk may
+    # be shorter) while still reporting progress per item.
+    with tqdm(total=len(items)) as progress:
+        for start in range(0, len(items), BATCH_SIZE):
+            batch = items[start : start + BATCH_SIZE]
+            orig_responses = [
+                _first_message_content(item["original_response"]) for item in batch
+            ]
+            repl_responses = [
+                _first_message_content(item["replication_response"]) for item in batch
+            ]
+            similarities = batch_sim_callback(orig_responses, repl_responses)
+            for item, orig_response, repl_response, similarity in zip(
+                batch, orig_responses, repl_responses, similarities
+            ):
+                df.append(_score_row(item, orig_response, repl_response, similarity))
+            progress.update(len(batch))
+
+
+def cosine_similarity_callback(model):
+    """Return a callback scoring two texts by cosine similarity of ``model.encode`` output."""
+
+    def callback(orig_response, repl_response):
+        # Each text is encoded on its own, original first; each result is [1, D].
+        embeddings = [
+            model.encode([text]) for text in (orig_response, repl_response)
+        ]
+        score = pairwise_cos_sim(*embeddings).numpy()
+        # Exactly one pair was compared, so exactly one value must come back.
+        assert np.prod(score.shape) == 1
+        return float(score[0])
+
+    return callback
+
+
+def cross_encoder_similarity_callback(model):
+    """Return a callback that scores a text pair with ``model.predict``."""
+
+    def callback(orig_response, repl_response):
+        pair = [orig_response, repl_response]
+        score = model.predict(pair)
+        # The prediction for a single pair must hold exactly one value.
+        assert np.prod(score.shape) == 1
+        return float(score)
+
+    return callback
+
+
+def dot_similarity_callback(model):
+    """Return a callback scoring two texts by the dot product of ``model.encode`` output."""
+
+    def callback(orig_response, repl_response):
+        # Each text is encoded on its own, original first; each result is [1, D].
+        embeddings = [
+            model.encode([text]) for text in (orig_response, repl_response)
+        ]
+        score = dot_score(*embeddings).numpy()
+        # Exactly one pair was compared, so exactly one value must come back.
+        assert np.prod(score.shape) == 1
+        return float(score[0])
+
+    return callback
+
+
+def nomic_cosine_similarity_callback(model, kind1, kind2):
+    """Return a callback scoring two texts with ``model.similarity``.
+
+    ``kind1`` and ``kind2`` are passed as ``prompt_name`` when encoding the
+    original and the replicated response respectively.
+    """
+
+    def callback(orig_response, repl_response):
+        first = model.encode([orig_response], prompt_name=kind1)
+        second = model.encode([repl_response], prompt_name=kind2)
+        score_matrix = model.similarity(first[0], second[0]).numpy()
+        return float(score_matrix[0][0])
+
+    return callback
+
+
+# API client objects created on demand, keyed by provider name ("gemini", "openai").
+CLIENTS = {}
+# Memoised results keyed by hashlib_hash(...); persisted through dump_cache()
+# and restored through load_cache().
+CACHE = {}
+
+
+def dump_cache():
+    """Write the in-memory CACHE as JSON to ``cache.json`` in the current directory."""
+    with open("cache.json", "w") as cache_file:
+        cache_file.write(json.dumps(CACHE))
+
+
+def load_cache():
+    """Replace the module-level CACHE with the contents of ``cache.json`` if it exists.
+
+    The path is relative to the current working directory; a missing file
+    leaves CACHE untouched.
+    """
+    global CACHE
+    if not os.path.exists("cache.json"):
+        return
+    with open("cache.json", "r") as cache_file:
+        CACHE = json.load(cache_file)
+
+
+def gemini_batch_similarity_callback(model):
+    """Return a batch callback computing cosine similarities from Gemini embeddings.
+
+    The callback takes two equally long lists (original and replicated
+    responses) and returns one similarity per pair, in the same order as the
+    inputs. Pairs whose similarity is already in CACHE are served from it;
+    all remaining pairs are embedded with a single API request.
+    """
+
+    def callback(orig_responses, repl_responses):
+        if "gemini" not in CLIENTS:
+            CLIENTS["gemini"] = genai.Client(api_key=GEMINI_API_KEY)
+        client = CLIENTS["gemini"]
+
+        pairs = list(zip(orig_responses, repl_responses))
+        keys = [
+            hashlib_hash(("gemini", model, orig_response, repl_response))
+            for orig_response, repl_response in pairs
+        ]
+        # Only a numeric cache entry is a usable similarity; anything else
+        # stored under the key is recomputed and overwritten.
+        missing = [
+            index
+            for index, key in enumerate(keys)
+            if not isinstance(CACHE.get(key), (int, float))
+        ]
+
+        if missing:
+            print("Invoking API with batch of ", len(missing))
+            # All originals first, then all replications, so the replication
+            # embedding of the k-th missing pair sits at offset k + len(missing).
+            contents = [pairs[index][0] for index in missing] + [
+                pairs[index][1] for index in missing
+            ]
+            embeddings = client.models.embed_content(
+                model=model,
+                contents=contents,
+                config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
+            ).embeddings
+            embeddings = [r.values for r in embeddings]
+            print("len(embeddings)", len(embeddings))
+
+            offset = len(missing)
+            for position, index in enumerate(missing):
+                sim = pairwise_cos_sim(
+                    [embeddings[position]], [embeddings[position + offset]]
+                )
+                CACHE[keys[index]] = float(sim[0])
+
+            # Persist before pausing so an interruption does not lose paid-for results.
+            dump_cache()
+            # Pause after every real API request (presumably rate limiting -- TODO confirm).
+            time.sleep(20)
+
+        # Assemble the answer in input order, whether cached or freshly computed.
+        return [CACHE[key] for key in keys]
+
+    return callback
+
+
+def gemini_similarity_callback(model):
+    """Return a callback computing the cosine similarity of two Gemini embeddings.
+
+    The similarity is cached in CACHE as a float under the same key the
+    batch callback uses, and the cache is consulted before any API request.
+    Entries holding a pair of embeddings (the format this callback stored
+    previously) are still understood and upgraded to a float on first use.
+    """
+
+    def callback(orig_response, repl_response):
+        key = hashlib_hash(("gemini", model, orig_response, repl_response))
+        cached = CACHE.get(key)
+        if isinstance(cached, (int, float)):
+            return float(cached)
+
+        if cached is not None:
+            # Older cache format: [embedding_1, embedding_2].
+            emb1, emb2 = cached
+        else:
+            if "gemini" not in CLIENTS:
+                CLIENTS["gemini"] = genai.Client(api_key=GEMINI_API_KEY)
+            client = CLIENTS["gemini"]
+
+            embeddings = client.models.embed_content(
+                model=model,
+                contents=[orig_response, repl_response],
+                config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
+            ).embeddings
+
+            emb1, emb2 = ([r.values] for r in embeddings)  # each r is types.Embedding
+
+        emb1, emb2 = np.asarray(emb1, dtype=float), np.asarray(emb2, dtype=float)
+
+        similarity = pairwise_cos_sim(emb1, emb2)
+        similarity = similarity.numpy()
+        assert np.prod(similarity.shape) == 1
+        similarity_val = float(similarity[0])
+        CACHE[key] = similarity_val
+        return similarity_val
+
+    return callback
+
+
+def openai_similarity_callback(model):
+    """Return a callback computing the cosine similarity of two OpenAI embeddings.
+
+    The client is created when this factory is called (once per process);
+    similarities are memoised in CACHE.
+    """
+    if "openai" not in CLIENTS:
+        CLIENTS["openai"] = openai.Client(api_key=OPENAI_API_KEY)
+
+    def callback(orig_response, repl_response):
+        client = CLIENTS["openai"]
+        cache_key = hashlib_hash(("openai", model, orig_response, repl_response))
+        if cache_key in CACHE:
+            return CACHE[cache_key]
+
+        response = client.embeddings.create(
+            input=[orig_response, repl_response], model=model
+        )
+        first, second = response.data
+        vec_a = np.asarray([first.embedding], dtype=float)
+        vec_b = np.asarray([second.embedding], dtype=float)
+        score = pairwise_cos_sim(vec_a, vec_b).numpy()
+        assert np.prod(score.shape) == 1
+        value = float(score[0])
+        # Store only if nothing was written under this key in the meantime.
+        CACHE.setdefault(cache_key, value)
+        return value
+
+    return callback
+
+
+def make_batch_sim_callback(sim_method: str):
+    """Return the batch similarity callback for ``"family/model"``, or None.
+
+    Only the "gemini" family has a batch implementation. A string that does
+    not contain exactly one "/" raises ValueError from the unpacking.
+    """
+    family, model = sim_method.split("/")
+    return gemini_batch_similarity_callback(model) if family == "gemini" else None
+
+
+def make_sim_callback(sim_method: str):
+    """Build the per-pair similarity callback for a ``"family/model"`` string.
+
+    Supported families: "cosine", "cross-encoder", "dot", "nomic-ai" (whose
+    model part is ``"kind1;kind2;model_name"``), "gemini" and "openai".
+
+    Raises:
+        ValueError: for an unknown family, or when ``sim_method`` does not
+            contain exactly one "/".
+    """
+    family, model = sim_method.split("/")
+
+    if family in ("cosine", "dot"):
+        encoder = SentenceTransformer(f"sentence-transformers/{model}")
+        if family == "cosine":
+            return cosine_similarity_callback(encoder)
+        return dot_similarity_callback(encoder)
+
+    if family == "cross-encoder":
+        return cross_encoder_similarity_callback(CrossEncoder(f"cross-encoder/{model}"))
+
+    if family == "nomic-ai":
+        kind1, kind2, model_name = model.split(";")
+        encoder = SentenceTransformer(f"nomic-ai/{model_name}", trust_remote_code=True)
+        return nomic_cosine_similarity_callback(encoder, kind1, kind2)
+
+    if family == "gemini":
+        return gemini_similarity_callback(model)
+
+    if family == "openai":
+        return openai_similarity_callback(model)
+
+    raise ValueError(f"Invalid sim_method: {sim_method}")
+
+
+def analyze_results(args, name, df, score_column, write_to_disk=True):
+    """Plot, summarise and (optionally) save the analysis of one data subset.
+
+    Args:
+        args: Parsed command line; ``args.sim_method`` names the output folder.
+        name: Name of the data subset, used as sub-folder and in plot titles.
+        df: Result rows for this subset.
+        score_column: Column of ``df`` holding the similarity score.
+        write_to_disk: When true (default), the ROC plot, the histogram and
+            ``summary.md`` are written below
+            ``OUTPUT_DIR/<sim_method with "/" replaced by "_">/<name>``.
+            When false nothing is written; plots are still shown and the
+            summary is still printed.
+    """
+    out_dir = os.path.join(OUTPUT_DIR, args.sim_method.replace("/", "_"), name)
+    if write_to_disk:
+        os.makedirs(out_dir, exist_ok=True)
+
+    # The plotting helpers draw on a new current figure, which is what
+    # plt.savefig() writes; their return values are not needed here.
+    plot_roc_curve(df, score_column)
+    if write_to_disk:
+        plt.savefig(os.path.join(out_dir, "roc_curve.png"))
+
+    plot_score_histogram(df, score_column, name)
+    if write_to_disk:
+        plt.savefig(os.path.join(out_dir, "score_hist.png"))
+    plt.show()
+
+    # Generate and save summary
+    summary = generate_summary(df, score_column)
+    print(name)
+    print(summary)
+
+    if write_to_disk:
+        make_summary_page(out_dir, summary)
+
+
+def write_summary(df, score_column):
+    """Return the text summary of the classification results for ``df``.
+
+    Thin wrapper around generate_summary(); kept as a separate entry point
+    for callers such as the analysis notebook.
+    """
+    return generate_summary(df, score_column)
+
+
+def calculate_roc_metrics(y_true, y_score):
+    """Compute the ROC AUC and the ROC curve for the given labels and scores.
+
+    Returns:
+        A dict with keys "roc_auc", "fpr", "tpr" and "threshes", filled from
+        ``roc_auc_score`` and ``roc_curve``.
+    """
+    metrics = {"roc_auc": roc_auc_score(y_true, y_score)}
+    curve = roc_curve(y_true, y_score)
+    metrics.update(zip(("fpr", "tpr", "threshes"), curve))
+    return metrics
+
+
+def plot_roc_curve(df, score_column):
+    """Plot the ROC curve for spoof detection on a new figure.
+
+    "Spoof" is treated as the positive class: the labels are
+    ``~df["genuine"]`` and the scores are ``1 - df[score_column]``. The
+    operating point chosen by find_best_threshold() is marked with a dot and
+    the chance diagonal is drawn for reference.
+
+    Returns:
+        The matplotlib Axes that was drawn on (the figure is also left as
+        pyplot's current figure, so ``plt.savefig`` right after the call
+        saves it).
+    """
+    y_true = ~df["genuine"]
+    y_score = 1 - df[score_column]
+
+    roc_data = calculate_roc_metrics(y_true, y_score)
+    # The threshold lives in the inverted (1 - score) space used above.
+    raw_best_threshold = find_best_threshold(roc_data)
+
+    best_idx = np.argmin(np.abs(roc_data["threshes"] - raw_best_threshold))
+    thresh_x, thresh_y = roc_data["fpr"][best_idx], roc_data["tpr"][best_idx]
+
+    _, ax = plt.subplots()
+    ax.plot(roc_data["fpr"], roc_data["tpr"], label="ROC curve")
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    ax.set_aspect("equal")
+    ax.set_xlabel("False Positive Rate")
+    ax.set_ylabel("True Positive Rate")
+    ax.set_title(f"ROC Curve (AUC = {roc_data['roc_auc']:.2f})")
+    ax.plot([0, 1], [0, 1], color="r", linestyle="--", label="Chance")
+    ax.plot(
+        thresh_x, thresh_y, "o", color="r", markersize=10, label="Selected threshold"
+    )
+
+    # Every artist carries a label, so the legend has entries to show.
+    ax.legend()
+    return ax
+
+
+def plot_score_histogram(df, score_column, name, ax=None, selected_threshold=None):
+    """Plot the histogram of scores, split by label, with a threshold line.
+
+    Args:
+        df: Result rows; needs ``score_column``, "label" and "genuine".
+        score_column: Column holding the score to plot.
+        name: Subset name used in the title (underscores become spaces).
+        ax: Axes to draw on; a new figure is created when omitted.
+        selected_threshold: Position of the vertical threshold line. When
+            ``None`` the threshold is derived from the ROC data of
+            ``df["genuine"]`` versus the raw score. An explicit value of 0
+            is honoured.
+    """
+    if ax is None:
+        _, ax = plt.subplots()
+
+    hplot = sns.histplot(
+        df, x=score_column, hue="label", ax=ax, element="step", stat="count"
+    )
+
+    if selected_threshold is None:
+        roc_data = calculate_roc_metrics(df["genuine"], df[score_column])
+        thresh = find_best_threshold(roc_data)
+    else:
+        thresh = selected_threshold
+
+    hplot.axvline(x=thresh, color="r", linestyle="--", label="Threshold")
+    if score_column == "similarity":
+        title = f"Cosine Similarity Scores ({name.replace('_', ' ').title()})"
+    else:
+        title = f"{name.replace('_', ' ').title()} {score_column}"
+    hplot.set_title(title)
+
+
+def find_best_threshold(roc_data):
+    """Pick the threshold that best balances TPR against (1 - FPR).
+
+    The criterion is the harmonic mean of the true positive rate and the
+    true negative rate (``1 - fpr``), with a small epsilon in the denominator
+    to avoid division by zero. The first maximum wins on ties.
+    """
+    sensitivity = roc_data["tpr"]
+    specificity = 1 - roc_data["fpr"]
+    balance = 2 * (sensitivity * specificity) / ((sensitivity + specificity) + 1e-10)
+    return roc_data["threshes"][np.argmax(balance)]
+
+
+def generate_summary(df, score_column):
+    """Generate a text summary of the classification results.
+
+    "Spoof" is the positive class: labels are ``~df["genuine"]`` and the
+    classifier score is ``1 - df[score_column]``, so a low similarity means
+    "more likely spoofed". The operating threshold comes from
+    find_best_threshold() and lives in that inverted space; the summary
+    reports it converted back to the original score scale.
+
+    The confusion matrix has true labels as rows and predictions as columns,
+    both in the fixed order Genuine, Spoof. A row is predicted "Spoof" when
+    its inverted score reaches the threshold, the same rule the ROC curve
+    uses, so the matrix agrees with the reported error rates.
+    """
+    y_true = ~df["genuine"]
+    y_score = 1 - df[score_column]
+
+    roc_data = calculate_roc_metrics(y_true, y_score)
+    best_threshold = find_best_threshold(roc_data)
+
+    fpr = roc_data["fpr"]
+    tpr = roc_data["tpr"]
+    threshes = roc_data["threshes"]
+
+    best_thresh_idx = np.argmin(np.abs(best_threshold - threshes))
+
+    selected_fpr = fpr[best_thresh_idx]
+    selected_fnr = 1 - tpr[best_thresh_idx]
+    roc_auc = roc_data["roc_auc"]
+
+    predictions = [
+        "Spoof" if (1 - score) >= best_threshold else "Genuine"
+        for score in df[score_column]
+    ]
+    cm = confusion_matrix(df["label"], predictions, labels=["Genuine", "Spoof"])
+
+    return f"""
+False Positive Rate (0% is best): {100*selected_fpr:.2f}%
+False Negative Rate (0% is best): {100*selected_fnr:.2f}%
+Best threshold: {1-best_threshold:.4f}
+ROC AUC (0.5 is no better than random): {roc_auc:.4f}
+Confusion Matrix:
+{tabulate(cm, tablefmt="plain")}
+"""
+
+
+def make_summary_page(out_dir, summary):
+ """Write ``summary.md`` into ``out_dir``, embedding ``summary`` in a code fence."""
+ # NOTE(review): the ROC Curve / Score Distribution / Confusion Matrix
+ # sections below have headings only -- confirm whether content such as
+ # image links was intended there.
+ with open(os.path.join(out_dir, f"summary.md"), "w") as f:
+ markdown_content = f"""# Classification Results Summary
+
+## Summary Statistics
+```
+{summary}
+```
+
+## ROC Curve
+
+
+## Score Distribution
+
+
+## Confusion Matrix
+
+"""
+ f.write(markdown_content)
+
+
+def get_filenames():
+    """List the entry names inside ``ROOT_DIR/replications``."""
+    replications_dir = os.path.join(ROOT_DIR, "replications")
+    return os.listdir(replications_dir)
+
+
+def create_data_subsets(df):
+    """Split the result rows into the named subsets that are analysed separately.
+
+    The "quantization" subsets keep rows where both the original and the
+    replication model are one of the two Llama-3.1-8B variants. The
+    "non_quantization" subsets apply only the machine condition, so they
+    contain every row with matching / differing machines.
+
+    Returns:
+        A dict mapping subset name to DataFrame; "all_data" is ``df`` itself.
+    """
+    llama_3_1_8b_variants = [
+        "meta-llama/Llama-3.1-8B-Instruct",
+        "context-labs/neuralmagic-llama-3.1-8b-instruct-FP8",
+    ]
+    both_3_1_8b = df["original_model"].isin(llama_3_1_8b_variants) & df[
+        "replication_model"
+    ].isin(llama_3_1_8b_variants)
+
+    same_machine = df["inference_machine"] == df["replication_machine"]
+    other_machine = df["inference_machine"] != df["replication_machine"]
+
+    subsets = {}
+    subsets["same_machine_quantization_test"] = df[same_machine & both_3_1_8b]
+    subsets["different_machine_quantization_test"] = df[other_machine & both_3_1_8b]
+    subsets["same_machine_non_quantization_test"] = df[same_machine]
+    subsets["different_machine_non_quantization_test"] = df[other_machine]
+    subsets["all_data"] = df
+    return subsets
+
+
+def main():
+    """Score every replication file with the chosen method and analyse each subset."""
+    args = parse_args()
+    sim_callback = make_sim_callback(args.sim_method)
+    batch_sim_callback = make_batch_sim_callback(args.sim_method)
+
+    rows = []
+    for filename in get_filenames():
+        filepath = os.path.join(ROOT_DIR, "replications", filename)
+        with open(filepath, "r") as f:
+            data = json.load(f)
+        # compute_scores expects the item limit itself (an int or None), not
+        # the whole argparse namespace.
+        compute_scores(args.N, sim_callback, batch_sim_callback, data, rows)
+
+    df = pd.DataFrame(rows)
+    for name, subset_df in create_data_subsets(df).items():
+        analyze_results(args, name, subset_df, "similarity")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/toploc-scripts/data_collection_scripts/collect-all-inferences-to-replicate.sh b/toploc-scripts/data_collection_scripts/collect-all-inferences-to-replicate.sh
new file mode 100755
index 00000000000..9b240afea22
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/collect-all-inferences-to-replicate.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+#
+# Collect the inferences to replicate for every model in MODELS.
+# Usage: collect-all-inferences-to-replicate.sh <machine> [temperature]
+# Models whose output file already exists are skipped.
+
+# The machine name is required; the sampling temperature is optional.
+if [ $# -ne 1 ] && [ $# -ne 2 ]; then
+    echo "Usage: $0 <machine> [temperature]"
+    exit 1
+fi
+
+source ../sglang-clean/.venv/bin/activate
+# The Python scripts do `from dotenv import load_dotenv`; that module is
+# shipped by the "python-dotenv" distribution, not by the one named "dotenv".
+pip install python-dotenv
+pip install huggingface-hub
+pip install tabulate
+
+MACHINE=$1
+
+if [ -z "$MACHINE" ]; then
+    echo "Error: Machine name is empty"
+    exit 1
+fi
+
+# Check if temperature is provided as second argument
+if [ $# -eq 2 ]; then
+    TEMPERATURE=$2
+else
+    TEMPERATURE=1.0
+fi
+
+echo "Using temperature: $TEMPERATURE"
+
+# Array of models to process
+# "meta-llama/Llama-3.1-8B-Instruct;fp8"
+MODELS=("meta-llama/Llama-3.1-8B-Instruct" "context-labs/neuralmagic-llama-3.1-8b-instruct-FP8" "meta-llama/Llama-3.2-3B-Instruct")
+
+# mkdir -p is a no-op when the directory already exists.
+mkdir -p toploc-scripts/inferences_to_replicate
+
+# Number of models whose collection run exited with an error.
+FAILED=0
+
+for MODEL in "${MODELS[@]}"; do
+    # Sanitize model name for filename
+    SANITIZED_MODEL=$(echo "$MODEL" | tr '/' '_' | tr ' ' '_')
+
+    echo "Processing model: $MODEL"
+    OUTPUT_FILENAME="train0_${MACHINE}_${SANITIZED_MODEL}.inference"
+    if [ -f "toploc-scripts/inferences_to_replicate/${OUTPUT_FILENAME}" ]; then
+        echo "Output file already exists: ${OUTPUT_FILENAME}"
+        continue
+    fi
+    if ! python toploc-scripts/data_collection_scripts/collect_inferences_to_replicate.py --N 100 --machine "$MACHINE" --model "$MODEL" --temperature "$TEMPERATURE" --output_filename "${OUTPUT_FILENAME}" --disable-cuda-graph; then
+        echo "Error: collection failed for model: $MODEL" >&2
+        FAILED=$((FAILED + 1))
+    fi
+
+    # Optional: add a small delay between runs
+    sleep 2
+done
+
+if [ "$FAILED" -ne 0 ]; then
+    echo "Error: $FAILED model(s) failed" >&2
+    exit 1
+fi
+
+echo "All models processed successfully!"
diff --git a/toploc-scripts/data_collection_scripts/collect-all-logprobs.sh b/toploc-scripts/data_collection_scripts/collect-all-logprobs.sh
new file mode 100755
index 00000000000..ff415c9e077
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/collect-all-logprobs.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+#
+# Collect logprobs for the stored inference files.
+# Usage: collect-all-logprobs.sh <machine>
+# Combinations whose output file already exists are skipped.
+
+source ../sglang-clean/.venv/bin/activate
+
+MACHINE=$1
+
+if [ -z "$MACHINE" ]; then
+    echo "Error: Machine name is empty"
+    exit 1
+fi
+
+mkdir -p toploc-scripts/logprobs
+
+MODELS=("meta-llama/Llama-3.1-8B-Instruct" "context-labs/neuralmagic-llama-3.1-8b-instruct-FP8" "meta-llama/Llama-3.2-3B-Instruct")
+
+# Without nullglob an unmatched pattern would be iterated as a literal string.
+shopt -s nullglob
+
+# Every command that was run, in shell-quoted form, for the report at the end.
+commands=()
+
+for inference_filepath in toploc-scripts/inferences_to_replicate/*.inference; do
+
+    filename=$(basename "$inference_filepath")
+    filename_no_ext=${filename%.inference}
+
+    for MODEL in "${MODELS[@]}"; do
+        # Sanitize model name for filename
+        SANITIZED_MODEL=$(echo "$MODEL" | tr '/' '_' | tr ' ' '_')
+
+        OUTPUT_FILENAME="logprobs_${MACHINE}_${SANITIZED_MODEL}_for_${filename_no_ext}.logprob"
+
+        if [ -f "toploc-scripts/logprobs/${OUTPUT_FILENAME}" ]; then
+            echo "Output file already exists: ${OUTPUT_FILENAME}"
+            continue
+        fi
+
+        echo "Processing model: $MODEL"
+
+        # Build the command as an array and execute it directly; no eval, so
+        # special characters in names cannot be re-interpreted by the shell.
+        cmd=(python toploc-scripts/data_collection_scripts/collect_logprobs.py --N 1 --machine "$MACHINE" --model "$MODEL" --input-file "$filename" --output-file "$OUTPUT_FILENAME" --disable-cuda-graph --debugging)
+        printf -v cmd_text '%q ' "${cmd[@]}"
+        # Add command to array for later printing
+        commands+=("$cmd_text")
+
+        echo "Running: $cmd_text"
+        "${cmd[@]}"
+
+        # Optional: add a small delay between runs
+        sleep 2
+
+        # NOTE(review): this stops after the first model that is actually run
+        # (together with --N 1 and --debugging it looks like a debugging
+        # setup) -- confirm before relying on this script to collect everything.
+        break
+
+    done
+
+    # NOTE(review): only the first inference file is processed, see above.
+    break
+
+done
+
+# Print all commands
+for cmd_text in "${commands[@]}"; do
+    echo "$cmd_text"
+done
diff --git a/toploc-scripts/data_collection_scripts/collect-all-toploc-fingerprints.sh b/toploc-scripts/data_collection_scripts/collect-all-toploc-fingerprints.sh
new file mode 100755
index 00000000000..9e9c6b83af0
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/collect-all-toploc-fingerprints.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+#
+# Collect TopLoc fingerprints for every model in MODELS.
+# Usage: collect-all-toploc-fingerprints.sh <machine>
+
+# Check if machine name is provided
+if [ $# -ne 1 ]; then
+    echo "Usage: $0 <machine>"
+    exit 1
+fi
+
+source .venv/bin/activate
+
+MACHINE=$1
+
+# Array of models to process
+MODELS=("meta-llama/Llama-3.1-8B-Instruct" "meta-llama/Llama-3.1-8B-Instruct;fp8" "meta-llama/Llama-3.2-3B-Instruct")
+
+# Number of models whose collection run exited with an error.
+FAILED=0
+
+for MODEL in "${MODELS[@]}"; do
+    # Sanitize model name for filename
+    SANITIZED_MODEL=$(echo "$MODEL" | tr '/' '_' | tr ' ' '_')
+
+    echo "Processing model: $MODEL"
+    if ! python toploc-scripts/data_collection_scripts/collect_toploc_fingerprints.py --N 100 --machine "$MACHINE" --model "$MODEL" --output_filename "train0_${MACHINE}_${SANITIZED_MODEL}.fingerprint" --disable-cuda-graph; then
+        echo "Error: collection failed for model: $MODEL" >&2
+        FAILED=$((FAILED + 1))
+    fi
+
+    # Optional: add a small delay between runs
+    sleep 2
+done
+
+if [ "$FAILED" -ne 0 ]; then
+    echo "Error: $FAILED model(s) failed" >&2
+    exit 1
+fi
+
+echo "All models processed successfully!"
diff --git a/toploc-scripts/data_collection_scripts/collect_inferences_to_replicate.py b/toploc-scripts/data_collection_scripts/collect_inferences_to_replicate.py
new file mode 100644
index 00000000000..5161dd19bdc
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/collect_inferences_to_replicate.py
@@ -0,0 +1,178 @@
+import json
+import os
+import signal
+import subprocess
+import sys
+import time
+from argparse import ArgumentParser
+
+import openai
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from sglang.utils import (
+ launch_server_cmd,
+ print_highlight,
+ terminate_process,
+ wait_for_server,
+)
+
+load_dotenv()
+
+# Abort at import time when HF_TOKEN is missing.
+# (The former sys.exit(1) after the raise could never run and was removed.)
+if not os.getenv("HF_TOKEN"):
+    raise ValueError("HF_TOKEN not found in environment variables")
+
+# Directory containing this script, and its parent; data folders are resolved under ROOT_DIR
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ROOT_DIR = os.path.dirname(SCRIPT_DIR)
+
+
+def parse_args():
+    """Parse the command-line options of the inference collection script.
+
+    Returns:
+        argparse.Namespace: the parsed options (machine, model, ultrachat_file,
+        N, disable_cuda_graph, seed, temperature, output_filename, quiet).
+    """
+    cli = ArgumentParser()
+    cli.add_argument("--machine", required=True, type=str, help="Machine name")
+    cli.add_argument("--model", required=True, type=str, help="Model to use")
+    cli.add_argument(
+        "--ultrachat_file", default="train_0.jsonl", type=str, help="ultrachat filename"
+    )
+    cli.add_argument(
+        "--N", required=True, type=int, help="Number of requests to process"
+    )
+    cli.add_argument(
+        "--disable-cuda-graph", action="store_true", help="Disable CUDA graph"
+    )
+    # Sampling controls: fixed seed and temperature 0.0 unless overridden
+    cli.add_argument(
+        "--seed",
+        default=42,
+        required=False,
+        type=int,
+        help="Random seed for sampling and generation",
+    )
+    cli.add_argument(
+        "--temperature",
+        default=0.0,
+        required=False,
+        type=float,
+        help="Temperature for sampling",
+    )
+    cli.add_argument(
+        "--output_filename", default=None, type=str, help="Output filename"
+    )
+    cli.add_argument("--quiet", action="store_true", help="Run in quiet mode")
+    options = cli.parse_args()
+    return options
+
+
+def kill_gpu_processes():
+    """Send SIGKILL to every process nvidia-smi lists as a GPU compute app.
+
+    Raises:
+        subprocess.CalledProcessError: if the nvidia-smi query exits non-zero.
+    """
+    cmd = "nvidia-smi --query-compute-apps=pid --format=csv,noheader"
+    result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE)
+    for line in result.stdout.splitlines():
+        pid = line.decode().strip()
+        # Skip blank or non-numeric rows instead of crashing in int()
+        if not pid.isdigit():
+            continue
+        print(f"Killing process {pid}")
+        try:
+            os.kill(int(pid), signal.SIGKILL)
+        except ProcessLookupError:
+            # The process exited between the query and the kill; nothing left to do
+            print(f"Process {pid} already exited")
+
+
+def start_server(args):
+    """
+    Launch an sglang server for args.model and wait until it answers.
+
+    Reads args.quiet, args.disable_cuda_graph and args.model (optionally
+    "<model>;<quantization>") and returns the (server_process, port) pair
+    produced by launch_server_cmd.
+    """
+    print("Starting server with TopLoc server...")
+    log_flag = "" if args.quiet else "--log-level debug"
+    cuda_graph_flag = "--disable-cuda-graph" if args.disable_cuda_graph else ""
+
+    print(f"Starting server with model {args.model}...")
+
+    # Only the first ";"-separated suffix is used as the quantization method
+    model, *suffixes = args.model.split(";")
+    quantization_flag = ""
+    if suffixes:
+        print(f"Quantization: {suffixes[0]}")
+        quantization_flag = f"--quantization {suffixes[0]}"
+
+    command = f"""
+        python -m sglang.launch_server --model-path {model} {quantization_flag} --host 0.0.0.0 {log_flag} {cuda_graph_flag}
+    """
+    server_process, port = launch_server_cmd(command)
+
+    print(f"Starting on port {port}...")
+
+    # Block until the server responds
+    wait_for_server(f"http://localhost:{port}")
+
+    # Extra grace period on top of the readiness check
+    print("Waiting 3 more seconds for server to be fully initialized...")
+    time.sleep(3)
+
+    return server_process, port
+
+
+def collect_N_inferences(port, args):
+    """Send the first args.N ultrachat prompts to the local server and record the replies.
+
+    Args:
+        port: port of the local OpenAI-compatible server.
+        args: parsed CLI namespace; reads ultrachat_file, N, model, temperature, seed, machine.
+
+    Returns:
+        A list of dicts, one per prompt, holding the machine name, the prompt,
+        the full request, the dumped response and the model string.
+    """
+    client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
+    source_path = os.path.join(ROOT_DIR, "ultrachat", args.ultrachat_file)
+    collected = []
+
+    with open(source_path, "r") as handle:
+        for index, raw_line in enumerate(tqdm(handle)):
+            if index >= args.N:
+                break
+
+            # Assuming the first element is the user prompt
+            prompt = json.loads(raw_line)["data"][0]
+
+            request = {
+                "model": args.model,
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": args.temperature,
+                "seed": args.seed,
+            }
+            completion = client.chat.completions.create(**request)
+
+            record = {
+                "machine": args.machine,
+                "prompt": prompt,
+                "complete_request": request,
+                "complete_response": completion.model_dump(),
+                "model": args.model,
+            }
+            collected.append(record)
+
+    return collected
+
+
+def write_to_file(args, inferences):
+    """Dump inferences as indented JSON under ROOT_DIR/inferences_to_replicate.
+
+    When args.output_filename is None it is derived from the model name and the
+    ultrachat file stem, and stored back on args.
+    """
+    if args.output_filename is None:
+        stem, _extension = os.path.splitext(args.ultrachat_file)
+        args.output_filename = f'{args.model.replace("/", "_")}_for_{stem}'
+    target_dir = os.path.join(ROOT_DIR, "inferences_to_replicate")
+    os.makedirs(target_dir, exist_ok=True)
+    with open(os.path.join(target_dir, args.output_filename), "w") as out:
+        json.dump(inferences, out, indent=4)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    kill_gpu_processes()
+    server_process, port = start_server(args)
+    try:
+        inferences = collect_N_inferences(port, args)
+        write_to_file(args, inferences)
+    finally:
+        # Always stop the server, even when collection or writing fails,
+        # so a failed run does not leave the server process behind
+        server_process.terminate()
+        print("Server terminated.")
diff --git a/toploc-scripts/data_collection_scripts/collect_logprobs.py b/toploc-scripts/data_collection_scripts/collect_logprobs.py
new file mode 100644
index 00000000000..4f0837f7fd9
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/collect_logprobs.py
@@ -0,0 +1,682 @@
+import argparse
+import asyncio
+import json
+import logging
+import os
+import signal
+import subprocess
+import sys
+import time
+from typing import List, Tuple
+
+import numpy as np
+import openai
+import torch
+from dotenv import load_dotenv
+from scipy.sparse import coo_matrix, csr_matrix, save_npz
+from scipy.stats import chi2
+from tqdm import tqdm
+
+from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.openai_api.adapter import v1_chat_generate_request
+from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.utils import (
+ kill_process_tree,
+ launch_server_cmd,
+ print_highlight,
+ terminate_process,
+ wait_for_server,
+)
+
+load_dotenv()
+
+# Import the API for performing inferences. Adjust the import if necessary for your codebase.
+import sglang.api
+
+# Directory containing this script
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+# Parent of SCRIPT_DIR; the input and output folders are resolved relative to it
+ROOT_DIR = os.path.dirname(SCRIPT_DIR)
+
+
+# Process-wide mutable state: holds "tokenizer_manager", "scheduler_info" and the "debugging" flag
+_global_state = {}
+
+
+def parse_args():
+    """Parse the command-line options of the logprob collection script.
+
+    Returns:
+        argparse.Namespace with model, input_file, machine, output_file,
+        disable_cuda_graph, interactive, quiet, debugging and N.
+    """
+    cli = argparse.ArgumentParser(description="Collect logprobs for inferences.")
+
+    # Required string options, registered in their original order
+    required_options = (
+        ("--model", "Model to use"),
+        ("--input-file", "Input file"),
+        ("--machine", "Machine name"),
+        ("--output-file", "Output file"),
+    )
+    for flag, description in required_options:
+        cli.add_argument(flag, type=str, required=True, help=description)
+
+    # Boolean switches, all off by default
+    for switch in ("--disable-cuda-graph", "--interactive", "--quiet", "--debugging"):
+        cli.add_argument(switch, action="store_true")
+
+    cli.add_argument(
+        "--N",
+        default=None,
+        type=int,
+        help="Number of inferences to collect logprobs for (all if not supplied)",
+    )
+    return cli.parse_args()
+
+
+def kill_gpu_processes():
+    """Send SIGKILL to every process nvidia-smi lists as a GPU compute app.
+
+    Raises:
+        subprocess.CalledProcessError: if the nvidia-smi query exits non-zero.
+    """
+    cmd = "nvidia-smi --query-compute-apps=pid --format=csv,noheader"
+    result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE)
+    for line in result.stdout.splitlines():
+        pid = line.decode().strip()
+        # Skip blank or non-numeric rows instead of crashing in int()
+        if not pid.isdigit():
+            continue
+        print(f"Killing process {pid}")
+        try:
+            os.kill(int(pid), signal.SIGKILL)
+        except ProcessLookupError:
+            # The process exited between the query and the kill; nothing left to do
+            print(f"Process {pid} already exited")
+
+
+def load_inferences(args):
+    """Load the JSON file args.input_file from ROOT_DIR/inferences_to_replicate."""
+    print(f"Loading inferences from {args.input_file}")
+    source_path = os.path.join(ROOT_DIR, "inferences_to_replicate", args.input_file)
+    with open(source_path, "r") as handle:
+        inferences = json.load(handle)
+    return inferences
+
+
+async def collect_logprobs(port, args):
+    """Replay stored inferences as prefill-only requests and score each response.
+
+    For every stored inference the prompt and prompt+response are tokenized,
+    the top log-probs of the full sequence are fetched from the engine, and
+    statistics for the response tokens are computed.
+
+    Args:
+        port: unused; kept so existing callers keep working.
+        args: parsed CLI namespace. Reads input_file, model and N; N is set to
+            the number of loaded inferences when it is None.
+
+    Returns:
+        A list with one dict per processed inference: the statistics from
+        collect_statistics_from_sparse_matrix plus "inference_id", "prompt"
+        and "response".
+    """
+    inferences = load_inferences(args)
+    all_logprobs_results = []
+
+    if args.N is None:
+        args.N = len(inferences)
+
+    # args.model may be "<model>;<quantization>"; requests only carry the model part
+    model, *_ = args.model.split(";")
+    # Reuse the engine's tokenizer manager so that we exactly replicate the openai API codepath for tokenization
+    tokenizer_manager = _global_state["tokenizer_manager"]
+
+    for i, inference in enumerate(
+        tqdm(inferences[: args.N], desc="Collecting logprobs")
+    ):
+
+        # Create a request for the purpose of getting the token IDs of just the prompt
+        request = make_prompt_request(inference, model)
+        prompt_token_ids = get_token_ids(
+            tokenizer_manager, model, request, add_eos_id=False
+        )
+
+        # Create a pre-fill for the purpose of getting the token IDs of both the prompt & response
+        prefill_request = make_prefill_request(inference, model)
+        prompt_response_token_ids = get_token_ids(
+            tokenizer_manager, model, prefill_request, add_eos_id=True
+        )
+
+        # Now we can isolate the prompt IDs from the response IDs by removing the prompt ID prefix
+        assert is_prefix(
+            prompt_token_ids, prompt_response_token_ids
+        ), "Prefix check failed"
+
+        # Let's get the top 200 log-probs associated with the prompt+response prefill sequence
+        # (Getting the entire distribution via JSON is cost prohibitive - so we do this until we have a method to get directly from GPU)
+        top_logprob_num = 200
+        logprobs_request = make_logprobs_request(
+            prefill_request, prompt_response_token_ids, top_logprob_num
+        )
+        ret = await get_top_logprobs_from_LLM(
+            model, logprobs_request, prompt_response_token_ids, top_logprob_num
+        )
+        M = gather_logprobs(ret, top_logprob_num)
+
+        # Token ids reported back by the engine (first entry skipped, as in gather_logprobs)
+        input_token_logprobs_from_ret = [
+            item[1] for item in ret["meta_info"]["input_token_logprobs"][1:]
+        ]
+
+        # Get the response token IDs
+        response_token_ids = prompt_response_token_ids[len(prompt_token_ids) + 1 :]
+        M_response = M[len(prompt_token_ids) :,]
+        assert M_response.shape[0] == len(
+            response_token_ids
+        ), f"{M_response.shape[0]} != {len(response_token_ids)}"
+        assert (
+            input_token_logprobs_from_ret[len(prompt_token_ids) :] == response_token_ids
+        )
+
+        # Convert to probabilities
+        M_p = M_response.copy()
+        M_p.data = np.exp(M_p.data)
+
+        # Apply after-the-fact top-k renormalization.
+        # Temperature 0 is treated as top-k with k=1. "top_k" is optional in the
+        # stored request, so a missing key means "no top-k" rather than a KeyError.
+        if prefill_request.get("temperature") == 0:
+            top_k = 1
+        else:
+            top_k = prefill_request.get("top_k")
+        if top_k is not None and top_k != -1:
+            M_p = apply_top_k_renormalization(M_p, top_k)
+            top_logprob_num = top_k
+
+        # Apply after-the-fact top-p renormalization (not implemented yet)
+        # top_p = prefill_request.top_p
+        # if top_p is not None and top_p != -1:
+        #     apply_top_p_renormalization(M_p, top_p)
+
+        statistics = collect_statistics_from_sparse_matrix(
+            M_p, response_token_ids, top_logprob_num
+        )
+
+        # Store the logprobs for this inference
+        logprobs_results = {
+            "inference_id": i,
+            **statistics,
+            "prompt": inference["complete_request"]["messages"],
+            "response": inference["complete_response"]["choices"][0]["message"][
+                "content"
+            ],
+        }
+
+        # Add to the collection
+        all_logprobs_results.append(logprobs_results)
+
+    return all_logprobs_results
+
+
+# Using a CSR matrix was, in retrospect, possibly a mistake?
+def apply_top_k_renormalization(M_p: csr_matrix, top_k: int):
+    """Keep the top_k largest stored values of every row and renormalize them.
+
+    Each row's kept values are divided by their sum and then scaled by
+    (1 - 1e-2), so a row never sums to exactly 1.
+
+    Args:
+        M_p: CSR matrix whose stored values are treated as probabilities.
+        top_k: number of stored values to keep per row (must be > 0).
+
+    Returns:
+        A new csr_matrix with the same shape as M_p.
+    """
+    assert isinstance(M_p, csr_matrix)
+    assert top_k > 0
+
+    # Smoothing so that when p=1, we don't get infinite values downstream
+    smoothing = 1e-2  # This is almost completely adhoc
+
+    kept_values, kept_columns, row_pointers = [], [], [0]
+
+    # Walk the rows through consecutive indptr boundaries
+    for start, stop in zip(M_p.indptr[:-1], M_p.indptr[1:]):
+        values = M_p.data[start:stop]
+        columns = M_p.indices[start:stop]
+
+        # Rows with at most top_k stored values are kept whole
+        if len(values) > top_k:
+            winners = np.argpartition(values, -top_k)[-top_k:]
+            values = values[winners]
+            columns = columns[winners]
+
+        values = values / (np.sum(values))
+        values = values * (1 - smoothing)
+
+        kept_values.extend(values)
+        kept_columns.extend(columns)
+        row_pointers.append(len(kept_values))
+
+    # Assemble a fresh CSR matrix from the surviving entries
+    return csr_matrix((kept_values, kept_columns, row_pointers), shape=M_p.shape)
+
+
+async def get_top_logprobs_from_LLM(model, request, token_ids, top_logprob_num):
+    """Run a prefill-only logprobs request through the tokenizer manager.
+
+    The request dict is modified in place: every entry of request["extra_body"]
+    is hoisted to the top level and "extra_body" is removed.
+
+    Args:
+        model: unused; kept for interface compatibility.
+        request: dict of ChatCompletionRequest fields with an "extra_body" entry.
+        token_ids: token ids that replace the adapted request's input_ids.
+        top_logprob_num: unused; kept for interface compatibility.
+
+    Returns:
+        The first result yielded by generate_request (its first element when
+        that result is a list).
+
+    Raises:
+        AssertionError: if the request is not a prefill logprobs request.
+    """
+    for key, value in request["extra_body"].items():
+        request[key] = value
+    del request["extra_body"]
+
+    assert (
+        request["max_tokens"] == 0
+    ), f"Not a prefill request, had max_tokens = {request['max_tokens']}"
+    assert (
+        request["logprobs"] is True
+    ), f"Not a logprobs request, had logprobs = {request['logprobs']}"
+    assert (
+        request["top_logprobs"] >= 0
+    ), f"Not a top_logprobs request, had top_logprobs = {request['top_logprobs']}"
+    assert (
+        request["logprob_start_len"] == 0
+    ), f"logprob_start_len must be 0, had logprob_start_len = {request['logprob_start_len']}"
+
+    # Use the same codepaths used when sending requests through the API (this is important for reproducibility)
+    all_requests = [ChatCompletionRequest(**request)]
+    tokenizer_manager = _global_state["tokenizer_manager"]
+    adapted_request, _ = v1_chat_generate_request(all_requests, tokenizer_manager)
+    # Override the tokenization with the exact ids supplied by the caller
+    adapted_request.input_ids = [token_ids]
+    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    if isinstance(ret, list):
+        ret = ret[0]
+    return ret
+
+
+def gather_logprobs(ret, top_logprob_num):
+    """Pack the per-position top log-probs of an engine result into a CSR matrix.
+
+    Args:
+        ret: engine result; ret["meta_info"]["input_top_logprobs"] is read and
+            its first entry skipped. Each remaining entry is a sequence of
+            items whose element 0 is the log-prob and element 1 the token id.
+        top_logprob_num: unused.
+
+    Returns:
+        csr_matrix of shape [T, VOCAB_DIM] where T is the number of positions
+        kept and VOCAB_DIM is len() of the tokenizer in _global_state.
+    """
+    per_position = ret["meta_info"]["input_top_logprobs"][1:]
+
+    # COO triplets: value (logprob), row (position), column (token id)
+    values, rows, cols = [], [], []
+    for position, top_tokens in enumerate(
+        tqdm(per_position, desc="Gathering logprobs")
+    ):
+        for entry in top_tokens:
+            values.append(entry[0])
+            rows.append(position)
+            cols.append(entry[1])
+
+    # Get the size of the vocabulary dimension from the tokenizer
+    # see: https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/tokenization_utils_fast.py#L275
+    vocab_dim = len(_global_state["tokenizer_manager"].tokenizer)
+
+    # Build in COO form, then convert to CSR for row-wise access
+    sparse = coo_matrix((values, (rows, cols)), shape=(len(per_position), vocab_dim))
+    return sparse.tocsr()
+
+
+def calculate_p_value_from_dense_matrix(M: np.ndarray, token_ids):
+    """Combine per-position tail probabilities into one p-value (Fisher's method).
+
+    Args:
+        M: 2-D array of log-probabilities, one row per position and one column
+            per token id (np.exp is applied and rows are indexed by token_ids).
+        token_ids: observed token id for each row of M.
+
+    Returns:
+        The chi-squared survival probability of the test statistic.
+
+    Note: uses np.random, so results vary unless the caller seeds it.
+    """
+    assert isinstance(M, np.ndarray)
+    # Transform to probabilities
+    M_p = np.exp(M)
+    # Get the observed probabilities for each token id at each position in the token sequence
+    P_obs = M_p[np.arange(M_p.shape[0]), token_ids]
+    # Tail mass: the SUM of the probabilities that are no larger than the observed one.
+    # (Summing the boolean mask only counted entries and produced an int array,
+    # which made the in-place float arithmetic below raise.)
+    no_more_likely = M_p <= P_obs.reshape([-1, 1])
+    tail_mass = np.where(no_more_likely, M_p, 0.0).sum(axis=1)
+    # Apply mid-rank correction
+    tail_mass -= 0.5 * P_obs
+    # Simulate uniform distribution over the frequently quite large P_obs when temperature is low
+    tail_mass += P_obs * (np.random.rand(len(P_obs)) - 0.5)
+    # Clip for safety
+    np.clip(tail_mass, 1e-323, 1.0, out=tail_mass)
+    # Compute test statistic
+    F = -2.0 * np.log(tail_mass).sum()
+    # Compute p-value using chi-squared distribution (Fisher's method)
+    p_value = chi2.sf(F, df=2 * M_p.shape[0])
+    return p_value
+
+
+def collect_statistics_from_sparse_matrix(M_p: csr_matrix, token_ids, top_logprob_num):
+ """Compute a Fisher-style test statistic and rank statistics for the observed tokens.
+
+ Args:
+ M_p: CSR matrix with one row per position and one column per vocabulary
+ entry; stored values are treated as probabilities.
+ token_ids: observed token id for each row of M_p.
+ top_logprob_num: number of tokens per row assumed to be stored; the other
+ VOCAB_DIM - top_logprob_num tokens share the leftover mass equally.
+
+ Returns:
+ dict of JSON-serializable statistics (p-values, F, chi-squared reference
+ values, token ranks and observed probabilities).
+
+ Note: uses np.random, so results vary unless the caller seeds it.
+ """
+ assert isinstance(M_p, csr_matrix)
+ T = M_p.shape[0]
+ VOCAB_DIM = M_p.shape[1]
+
+ # Get the observed probabilities for each token id at each position in the token sequence
+ # Annoyingly, scipy.sparse matrix indexing returns a np.matrix, not an ndarray
+ # The .A1 gets the flattened representation of the matrix, but I still ravel() just in case scipy.sparse changes its return type
+ P_obs = M_p[np.arange(T), token_ids]
+ if isinstance(P_obs, np.matrix):
+ P_obs = P_obs.A1
+ P_obs = P_obs.ravel()
+ assert P_obs.shape == (T,), f"{P_obs.shape} != ({T},)"
+
+ # P_obs will be zero if the token is not in the top_logprob_num (let's just say 200)
+ # So, we have to substitute P_obs at those indices with a reasonable value.
+ # This takes a few steps.
+
+ # 1. Calculate the remaining probability mass after the top, say, 200 tokens
+ unlikely_tokens_total_mass = 1.0 - M_p.sum(axis=1)
+ if isinstance(unlikely_tokens_total_mass, np.matrix):
+ unlikely_tokens_total_mass = unlikely_tokens_total_mass.A1
+ unlikely_tokens_total_mass = unlikely_tokens_total_mass.ravel()
+ assert unlikely_tokens_total_mass.shape == (
+ T,
+ ), f"{unlikely_tokens_total_mass.shape} != ({T},)"
+
+ # 2. Spread the remaining probability mass out equally over the remaining non-top-200 tokens
+ num_unlikely_tokens = VOCAB_DIM - top_logprob_num
+ default_unlikely_token_prob = unlikely_tokens_total_mass / num_unlikely_tokens
+ assert default_unlikely_token_prob.shape == (
+ T,
+ ), f"{default_unlikely_token_prob.shape} != ({T},)"
+
+ # 3. Create a mask representing where a token was not part of the top-200
+ dok = M_p.todok()
+ unlikely_token_mask = np.asarray(
+ [(i, token_id) not in dok for i, token_id in enumerate(token_ids)], dtype=bool
+ )
+ assert unlikely_token_mask.shape == (T,), f"{unlikely_token_mask.shape} != ({T},)"
+
+ # 4. Substitute in our default values at these positions
+ P_obs[unlikely_token_mask] = default_unlikely_token_prob[unlikely_token_mask]
+
+ # Next, we need to measure the tail mass (sum of all probabilities in M_p[i] less than P_obs[i])
+ # If we do the naive M_p < P_obs.reshape([-1,1]), we'll get a dense matrix
+ # So instead, we'll process each row of the sparse matrix efficiently
+ tail_mass = np.zeros(T)
+ token_ranks = []
+ for i in range(T):
+ row_start = M_p.indptr[i]
+ row_end = M_p.indptr[i + 1]
+ # Get the data and indices for this row
+ data = M_p.data[row_start:row_end]
+ # what position is this token in the sorted list of token probabilities?
+ token_ranks.append(np.sum(data > P_obs[i]) + 1)
+ # Sum the values that are less than or equal to P_obs[i]
+ tail_mass[i] = np.sum(data[data <= P_obs[i]])
+
+ # Manually add the total mass of the unlikely tokens to each row, since that isn't included in the sparse matrix
+ tail_mass += unlikely_tokens_total_mass
+
+ # But if we are an unlikely token, choose a random tail mass between 0 and unlikely_tokens_total_mass[i]
+ num_unlikely_tokens_detected = np.sum(unlikely_token_mask)
+ if num_unlikely_tokens_detected > 0:
+ tail_mass[unlikely_token_mask] = unlikely_tokens_total_mass[
+ unlikely_token_mask
+ ] * np.random.rand(num_unlikely_tokens_detected)
+
+ # Ok, now we have an accurate tail_mass for top-200 tokens, and plausible (but randomized) for non-top-200 tokens
+ # Now we can proceed normally.
+
+ # Uniform distribution correction
+ U = tail_mass - P_obs * (np.random.rand(len(P_obs)))
+
+ # Spot-check - should give test statistic around 2*T-2 - confirmed!
+ # U = np.random.rand(len(P_obs))
+
+ # Clipping for safety
+ np.clip(U, 1e-323, 1.0, out=U)
+ # Test statistic
+ F = -2.0 * np.log(U).sum()
+ # p-value (do a two-tailed test - it's also suspicious if the selected tokens are much more "argmaxey" than expected)
+ # (TODO: what if an attacker mixes "argmaxey" behavior with unexpected tokens to push the F statistic closer to the null hypothesis?)
+ uncorrected_p_value = chi2.sf(F, df=2 * T)
+ p_value = min(uncorrected_p_value, 1 - uncorrected_p_value)
+
+ # Per-token dump; only printed when debugging was switched on via set_debugging()
+ if is_debugging():
+ for i in range(T):
+ print(
+ f"""Token {i}:
+ Token ID: {token_ids[i]},
+ P_obs = {P_obs[i]:.4f},
+ tail_mass = {tail_mass[i]:.4f},
+ U = {U[i]:.4f},
+ Fi = {-2*np.log(U[i]):.4f},
+ Token Rank: {token_ranks[i]},
+ Is Unlikely Token: {unlikely_token_mask[i]},
+ Unlikely Token Total mass: {unlikely_tokens_total_mass[i]:.4f},
+ Default Unlikely Token Prob: {default_unlikely_token_prob[i]:.4f}"""
+ )
+
+ print(f"F = {F:.4f}")
+ print(f"One-tailed p-value = {uncorrected_p_value:.4f}")
+ print(f"p-value = {p_value:.4f}")
+ print(f"chi2 mode: {2*T - 2}")
+ print(f"chi2 stdev: {(2 * 2*T)**0.5:.4f}")
+
+ # NOTE(review): "num_unlikely_tokens" reports VOCAB_DIM - top_logprob_num, not
+ # num_unlikely_tokens_detected (observed tokens missing from M_p) - confirm which is intended.
+ return {
+ "p_value": float(p_value),
+ "uncorrected_p_value": float(uncorrected_p_value),
+ "F": float(F),
+ "chi_squared_mode": int(2 * T - 2),
+ "chi_squared_stdev": float((2 * 2 * T) ** 0.5),
+ "num_unlikely_tokens": int(num_unlikely_tokens),
+ "average_token_rank": float(np.mean(token_ranks)),
+ "median_token_rank": float(np.median(token_ranks)),
+ "token_ranks": [int(rank) for rank in token_ranks],
+ "P_obs": P_obs.tolist(),
+ }
+
+
+def is_prefix(prefix_ids, ids):
+    """Return True when prefix_ids matches the leading elements of ids.
+
+    An empty prefix_ids is a prefix of everything; a prefix_ids longer than
+    ids is never a prefix.
+    """
+    # zip() stops at the shorter input, so an over-long "prefix" must be rejected first
+    if len(prefix_ids) > len(ids):
+        return False
+    return all(a == b for a, b in zip(prefix_ids, ids))
+
+
+def write_to_file(args, logprobs):
+    """Dump logprobs as indented JSON to ROOT_DIR/response_logprobs/<args.output_file>."""
+    results_dir = os.path.join(ROOT_DIR, "response_logprobs")
+    os.makedirs(results_dir, exist_ok=True)
+    destination = os.path.join(results_dir, args.output_file)
+    with open(destination, "w") as out:
+        json.dump(logprobs, out, indent=2)
+
+
+def make_prompt_request(inference, model):
+    """Build a prefill-only request that contains just the original prompt messages.
+
+    Args:
+        inference: stored inference dict; only inference["complete_request"] is read.
+        model: model name to place in the request.
+
+    Returns:
+        A shallow copy of the original request with max_tokens=0,
+        extra_body={"input_token_ids": True} and the given model.
+    """
+    original_request = inference["complete_request"]
+    # Shallow copy: top-level keys are replaced below, the stored request is left untouched
+    prompt_request = original_request.copy()
+    prompt_request["messages"] = original_request["messages"]
+    prompt_request["max_tokens"] = 0
+    prompt_request["extra_body"] = {"input_token_ids": True}
+    prompt_request["model"] = model
+    return prompt_request
+
+
+def make_prefill_request(inference, model) -> dict:
+    """Build a prefill-only request containing the prompt plus the recorded response.
+
+    Args:
+        inference: stored inference dict with "complete_request" and "complete_response".
+        model: model name to place in the request.
+
+    Returns:
+        A plain dict (not a ChatCompletionRequest): a shallow copy of the
+        original request whose messages end with the recorded assistant reply,
+        with max_tokens=0 and extra_body={"input_token_ids": True}.
+    """
+    original_request = inference["complete_request"]
+    original_response = inference["complete_response"]["choices"][0]["message"][
+        "content"
+    ]
+
+    # Append the recorded answer as an assistant turn so it gets tokenized with the prompt
+    prefill_messages = [
+        *original_request["messages"],
+        {"role": "assistant", "content": original_response},
+    ]
+    prefill_request = original_request.copy()
+    prefill_request["messages"] = prefill_messages
+    prefill_request["max_tokens"] = 0
+    prefill_request["extra_body"] = {"input_token_ids": True}
+    prefill_request["model"] = model
+    return prefill_request
+
+
+def get_token_ids(tokenizer_manager, model, request, add_eos_id=True):
+    """Tokenize a chat request through the OpenAI-adapter codepath.
+
+    Args:
+        tokenizer_manager: manager passed to v1_chat_generate_request; its
+            tokenizer supplies the EOS and BOS token ids.
+        model: unused.
+        request: dict of ChatCompletionRequest fields.
+        add_eos_id: append the EOS token id to the result when True.
+
+    Returns:
+        The token ids, without a leading BOS token if one was present.
+    """
+    # Same function that serves OpenAI style API requests, so the ids match the API path
+    adapted_request, _original = v1_chat_generate_request(
+        [ChatCompletionRequest(**request)], tokenizer_manager
+    )
+    token_ids = adapted_request.input_ids
+
+    # A prefill never generates the END-OF-SEQUENCE token, so it is added by hand
+    if add_eos_id:
+        token_ids.append(tokenizer_manager.tokenizer.eos_token_id)
+
+    # Omit the begin-of-sequence token
+    starts_with_bos = token_ids[0] == tokenizer_manager.tokenizer.bos_token_id
+    return token_ids[1:] if starts_with_bos else token_ids
+
+
+def make_logprobs_request(prefill_request, token_ids, top_logprob_num):
+    """Derive a logprobs request from a prefill request without modifying the input.
+
+    Args:
+        prefill_request: JSON-serializable request dict to copy.
+        token_ids: unused.
+        top_logprob_num: value for the "top_logprobs" field.
+
+    Returns:
+        A deep copy with logprobs=True, top_logprobs=top_logprob_num and
+        extra_body["logprob_start_len"]=0.
+    """
+    request = copy(prefill_request)
+    extra_body = request.setdefault("extra_body", {})
+    request.update(logprobs=True, top_logprobs=top_logprob_num)
+    extra_body["logprob_start_len"] = 0  # Start from beginning
+    return request
+
+
+def get_logprobs(client, model, request):
+    """
+    Send a chat completion request and (eventually) extract token distributions.
+
+    Extraction is not implemented: the request is sent and the response is
+    dumped, but the function always prints a warning or an error message and
+    returns an empty list.
+
+    Args:
+        client: object exposing chat.completions.create(**request)
+        model (str): unused
+        request (dict): keyword arguments for the API call
+
+    Returns:
+        List: always []
+    """
+    try:
+        # Make the API call using OpenAI-compatible API
+        response = client.chat.completions.create(**request)
+
+        # Convert the response to a dictionary when the object supports it
+        if hasattr(response, "model_dump"):
+            response.model_dump()
+
+        print("Warning: Could not extract token distributions from the response")
+    except Exception as e:
+        print(f"Error when trying to get logprobs: {e}")
+    return []
+
+
+def copy(x):
+    """Deep-copy x by serializing it to JSON and parsing it back (x must be JSON-serializable)."""
+    serialized = json.dumps(x)
+    return json.loads(serialized)
+
+
+def start_server(args):
+    """
+    Launch an sglang server for args.model and wait until it answers.
+
+    Reads args.quiet, args.disable_cuda_graph, args.interactive and args.model
+    (optionally "<model>;<quantization>") and returns the (server_process, port)
+    pair produced by launch_server_cmd.
+    """
+    log_flag = "" if args.quiet else "--log-level debug"
+    cuda_graph_flag = "--disable-cuda-graph" if args.disable_cuda_graph else ""
+
+    model_path = args.model
+    print(f"Starting server with model {model_path}...")
+
+    # Only the first ";"-separated suffix is used as the quantization method
+    model, *suffixes = model_path.split(";")
+    quantization_flag = ""
+    if suffixes:
+        print(f"Quantization: {suffixes[0]}")
+        quantization_flag = f"--quantization {suffixes[0]}"
+
+    command = f"""
+        python -m sglang.launch_server --model-path {model} {quantization_flag} --host 0.0.0.0 {log_flag} {cuda_graph_flag}
+    """
+    server_process, port = launch_server_cmd(command)
+
+    print(f"Starting on port {port}...")
+
+    # Block until the server responds
+    wait_for_server(f"http://localhost:{port}")
+
+    # Extra grace period on top of the readiness check
+    print("Waiting 3 more seconds for server to be fully initialized...")
+    time.sleep(3)
+
+    print(
+        f"Started server with model {model} on port {port} {quantization_flag}, {log_flag}, {cuda_graph_flag}"
+    )
+    # In interactive mode, pause so the operator can inspect the server first
+    if args.interactive:
+        input("Press Enter when ready to continue...")
+
+    return server_process, port
+
+
+def launch_engine(args):
+    """Call _launch_subprocesses with ServerArgs built from the CLI options.
+
+    The returned tokenizer manager and scheduler info are stored in
+    _global_state under "tokenizer_manager" and "scheduler_info".
+    """
+    # get_server_args() already forwards the model, quantization, CUDA-graph and log-level settings
+    server_args = ServerArgs(**get_server_args(args))
+    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+    _global_state["tokenizer_manager"] = tokenizer_manager
+    _global_state["scheduler_info"] = scheduler_info
+
+
+def get_server_args(args):
+    """Translate the CLI options into keyword arguments for ServerArgs.
+
+    args.model may be "<model>;<quantization>"; only the first suffix is used
+    as the quantization value.
+    """
+    # other things to consider in future: grammar backend, etc.
+    model_path, *suffixes = args.model.split(";")
+    server_args = {"model_path": model_path}
+    server_args["quantization"] = suffixes[0] if suffixes else None
+    server_args["disable_cuda_graph"] = args.disable_cuda_graph
+    server_args["log_level"] = None if args.quiet else "debug"
+    return server_args
+
+
+def shut_down_engine():
+    """Drop the cached engine handles and kill every child process of this script.
+
+    Safe to call even when the engine was never launched.
+    """
+    # Assign without reading first, so a missing key cannot raise KeyError
+    _global_state["tokenizer_manager"] = None
+    _global_state["scheduler_info"] = None
+    try:
+        kill_process_tree(os.getpid(), include_parent=False)
+    except Exception as e:
+        # Best-effort cleanup: report the failure instead of hiding it, and let
+        # KeyboardInterrupt/SystemExit propagate (a bare except swallowed those too)
+        print(f"Failed to kill child processes: {e}")
+
+
+def is_debugging():
+    """Return the "debugging" flag from _global_state, or False when it was never set."""
+    if "debugging" in _global_state:
+        return _global_state["debugging"]
+    return False
+
+
+def set_debugging():
+    """Switch the "debugging" flag in _global_state on."""
+    _global_state.update(debugging=True)
+
+
+async def main():
+    """Run the whole pipeline: parse options, start the engine, collect and save logprobs."""
+    args = parse_args()
+    if args.debugging:
+        set_debugging()
+
+    # Show the effective options and optionally wait for the operator
+    print(args)
+    if args.interactive:
+        input("Press Enter when ready to continue...")
+
+    # Clear the GPU before starting our own engine
+    kill_gpu_processes()
+    launch_engine(args)
+
+    # The engine runs in-process, so the port argument is only a placeholder
+    results = await collect_logprobs(0, args)
+
+    print("writing to disk...")
+    write_to_file(args, results)
+
+    print("Finished!")
+
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    finally:
+        # Always tear the engine down, even when main() raised, so that its
+        # child processes do not outlive this script.
+        # shut_down_engine() reads these keys; make sure they exist when
+        # startup failed before launch_engine() could set them.
+        _global_state.setdefault("tokenizer_manager", None)
+        _global_state.setdefault("scheduler_info", None)
+        shut_down_engine()
+    sys.exit(0)
diff --git a/toploc-scripts/data_collection_scripts/collect_neg_log_likelihoods.py b/toploc-scripts/data_collection_scripts/collect_neg_log_likelihoods.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/toploc-scripts/data_collection_scripts/collect_toploc_fingerprints.py b/toploc-scripts/data_collection_scripts/collect_toploc_fingerprints.py
new file mode 100644
index 00000000000..65d512448db
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/collect_toploc_fingerprints.py
@@ -0,0 +1,181 @@
+import json
+import os
+import signal
+import subprocess
+import sys
+import time
+from argparse import ArgumentParser
+
+import openai
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from sglang.utils import (
+ launch_server_cmd,
+ print_highlight,
+ terminate_process,
+ wait_for_server,
+)
+
+load_dotenv()
+
+# Abort at import time when HF_TOKEN is missing.
+# (The former sys.exit(1) after the raise could never run and was removed.)
+if not os.getenv("HF_TOKEN"):
+    raise ValueError("HF_TOKEN not found in environment variables")
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+# Parent directory of the scripts folder. os.path.join("../", SCRIPT_DIR) returned
+# SCRIPT_DIR itself, because join() discards everything before an absolute component.
+ROOT_DIR = os.path.dirname(SCRIPT_DIR)
+
+
+def parse_args():
+    """Parse the command-line options of the fingerprint collection script.
+
+    Returns:
+        argparse.Namespace: the parsed options (machine, model, ultrachat_file,
+        N, disable_cuda_graph, seed, temperature, output_filename, quiet).
+    """
+    cli = ArgumentParser()
+    cli.add_argument("--machine", required=True, type=str, help="Machine name")
+    cli.add_argument(
+        "--model",
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        type=str,
+        help="Model to use",
+    )
+    cli.add_argument(
+        "--ultrachat_file", default="train_0.jsonl", type=str, help="ultrachat filename"
+    )
+    cli.add_argument(
+        "--N", required=True, type=int, help="Number of requests to process"
+    )
+    cli.add_argument(
+        "--disable-cuda-graph", action="store_true", help="Disable CUDA graph"
+    )
+    # Sampling controls: fixed seed and temperature 0.0 unless overridden
+    cli.add_argument(
+        "--seed",
+        default=42,
+        required=False,
+        type=int,
+        help="Random seed for sampling and generation",
+    )
+    cli.add_argument(
+        "--temperature",
+        default=0.0,
+        required=False,
+        type=float,
+        help="Temperature for sampling",
+    )
+    cli.add_argument(
+        "--output_filename", default=None, type=str, help="Output filename"
+    )
+    cli.add_argument("--quiet", action="store_true", help="Run in quiet mode")
+    options = cli.parse_args()
+    return options
+
+
+def kill_gpu_processes():
+    """Send SIGKILL to every process nvidia-smi lists as a GPU compute app.
+
+    Raises:
+        subprocess.CalledProcessError: if the nvidia-smi query exits non-zero.
+    """
+    cmd = "nvidia-smi --query-compute-apps=pid --format=csv,noheader"
+    result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE)
+    for line in result.stdout.splitlines():
+        pid = line.decode().strip()
+        # Skip blank or non-numeric rows instead of crashing in int()
+        if not pid.isdigit():
+            continue
+        print(f"Killing process {pid}")
+        try:
+            os.kill(int(pid), signal.SIGKILL)
+        except ProcessLookupError:
+            # The process exited between the query and the kill; nothing left to do
+            print(f"Process {pid} already exited")
+
+
+def start_server(args):
+    """
+    Launch an sglang server for args.model and wait until it answers.
+
+    Reads args.quiet, args.disable_cuda_graph and args.model (optionally
+    "<model>;<quantization>") and returns the (server_process, port) pair
+    produced by launch_server_cmd.
+
+    NOTE(review): the launch command does not pass --toploc-verification even
+    though the messages below say verification is enabled - confirm whether
+    the flag is needed for fingerprints to be returned.
+    """
+    print("Starting server with TopLoc fingerprint verification enabled...")
+    log_flag = "" if args.quiet else "--log-level debug"
+    cuda_graph_flag = "--disable-cuda-graph" if args.disable_cuda_graph else ""
+
+    # Only the first ";"-separated suffix is used as the quantization method
+    model, *suffixes = args.model.split(";")
+    quantization_flag = ""
+    if suffixes:
+        print(f"Quantization: {suffixes[0]}")
+        quantization_flag = f"--quantization {suffixes[0]}"
+
+    command = f"""
+        python -m sglang.launch_server --model-path {model} {quantization_flag} --host 0.0.0.0 {log_flag} {cuda_graph_flag}
+    """
+    server_process, port = launch_server_cmd(command)
+
+    print(f"Starting on port {port}...")
+
+    # Block until the server responds
+    wait_for_server(f"http://localhost:{port}")
+
+    # Extra grace period on top of the readiness check
+    print("Waiting 3 more seconds for server to be fully initialized...")
+    time.sleep(3)
+
+    return server_process, port
+
+
+def collect_N_fingerprints(port, args):
+    """Send the first args.N ultrachat prompts to the local server and record the fingerprints.
+
+    Args:
+        port: port of the local OpenAI-compatible server.
+        args: parsed CLI namespace; reads ultrachat_file, N, model, temperature, seed, machine.
+
+    Returns:
+        A list of dicts, one per prompt, holding the machine name, the prompt,
+        the full request, the dumped response, the model string and the last
+        entry of the response's "toploc_verification_fingerprints".
+    """
+    client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
+    source_path = os.path.join(ROOT_DIR, "ultrachat", args.ultrachat_file)
+    collected = []
+
+    with open(source_path, "r") as handle:
+        for index, raw_line in enumerate(tqdm(handle)):
+            if index >= args.N:
+                break
+
+            # Assuming the first element is the user prompt
+            prompt = json.loads(raw_line)["data"][0]
+
+            request = {
+                "model": args.model,
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": args.temperature,
+                "seed": args.seed,
+                "extra_body": {"return_verification_proofs": True},
+            }
+            response_dump = client.chat.completions.create(**request).model_dump()
+
+            # Only the last fingerprint of the first choice is kept
+            message = response_dump["choices"][0]["message"]
+            fingerprint = message["toploc_verification_fingerprints"][-1]
+
+            record = {
+                "machine": args.machine,
+                "prompt": prompt,
+                "complete_request": request,
+                "complete_response": response_dump,
+                "model": args.model,
+                "fingerprint": fingerprint,
+            }
+            collected.append(record)
+
+    return collected
+
+
+def write_to_file(args, fingerprints):
+    """Dump fingerprints as indented JSON under ROOT_DIR/fingerprints.
+
+    When args.output_filename is None it is derived from the model name and the
+    ultrachat file stem, and stored back on args.
+    """
+    if args.output_filename is None:
+        stem, _extension = os.path.splitext(args.ultrachat_file)
+        args.output_filename = f'{args.model.replace("/", "_")}_for_{stem}'
+    target_dir = os.path.join(ROOT_DIR, "fingerprints")
+    os.makedirs(target_dir, exist_ok=True)
+    with open(os.path.join(target_dir, args.output_filename), "w") as out:
+        json.dump(fingerprints, out, indent=4)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    kill_gpu_processes()
+    server_process, port = start_server(args)
+    try:
+        fingerprints = collect_N_fingerprints(port, args)
+        write_to_file(args, fingerprints)
+    finally:
+        # Always stop the server, even when collection or writing fails,
+        # so a failed run does not leave the server process behind
+        server_process.terminate()
+        print("Server terminated.")
diff --git a/toploc-scripts/data_collection_scripts/download_ultrachat.py b/toploc-scripts/data_collection_scripts/download_ultrachat.py
new file mode 100644
index 00000000000..5f6021f31c7
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/download_ultrachat.py
@@ -0,0 +1,40 @@
+import os
+import sys
+from pathlib import Path
+
+import requests
+from dotenv import load_dotenv
+from huggingface_hub import snapshot_download
+
+# Pick up variables from a .env file (if present) before checking the environment.
+load_dotenv()
+
+# Fail fast: this script refuses to run without HF_TOKEN in the environment.
+if not os.getenv("HF_TOKEN"):
+    raise ValueError("HF_TOKEN not found in environment variables")
+
+# Do not echo the token value: secrets must not end up in terminal logs.
+print("Downloading UltraChat dataset... (HF_TOKEN is set)")
+
+
+def download_ultrachat(repo_id="stingning/ultrachat", target_dir=None):
+    """
+    Download the UltraChat dataset from Hugging Face.
+
+    Args:
+        repo_id (str): The Hugging Face repository ID
+        target_dir (str, optional): Target directory to download files to.
+            A relative path is resolved against the parent of this script's
+            directory. Defaults to the current working directory.
+
+    Returns:
+        The local directory reported by ``snapshot_download``.
+    """
+    if target_dir is None:
+        # Joining the root directory with None raises TypeError, so the
+        # documented default has to be handled explicitly.
+        target_dir = Path.cwd()
+    else:
+        root_dir = Path(__file__).parent.parent.absolute()
+        target_dir = root_dir / target_dir
+    os.makedirs(target_dir, exist_ok=True)
+
+    print(f"Downloading UltraChat dataset from {repo_id}...")
+    local_dir = snapshot_download(
+        repo_id=repo_id, local_dir=target_dir, repo_type="dataset"
+    )
+
+    print(f"UltraChat dataset downloaded to {local_dir}")
+    return local_dir
+
+
+# Script entry point: download the dataset into the "ultrachat" target directory.
+if __name__ == "__main__":
+ download_ultrachat(target_dir="ultrachat")
diff --git a/toploc-scripts/data_collection_scripts/dummy_request.json b/toploc-scripts/data_collection_scripts/dummy_request.json
new file mode 100644
index 00000000000..42808ceae2f
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/dummy_request.json
@@ -0,0 +1,4 @@
+{
+ "prompt": "dummy prompt",
+ "model": "dummy-model"
+}
diff --git a/toploc-scripts/data_collection_scripts/perform-all-replications.sh b/toploc-scripts/data_collection_scripts/perform-all-replications.sh
new file mode 100755
index 00000000000..798e9a7fe7c
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/perform-all-replications.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+
+# Replays every *.inference file under toploc-scripts/inferences_to_replicate
+# against each model in MODELS and stores one .replication file per pair.
+# All paths are relative to the current working directory.
+
+# Check if machine name is provided
+if [ $# -ne 1 ]; then
+    echo "Usage: $0 <machine_name>"
+    exit 1
+fi
+
+
+source ../sglang-clean/.venv/bin/activate
+# python-dotenv is the distribution that provides `from dotenv import load_dotenv`.
+pip install python-dotenv
+pip install huggingface-hub
+pip install tabulate
+
+MACHINE=$1
+
+if [ ! -d "toploc-scripts/inferences_to_replicate" ]; then
+    echo "Error: inferences_to_replicate directory not found"
+    exit 1
+fi
+
+MODELS=("meta-llama/Llama-3.1-8B-Instruct" "context-labs/neuralmagic-llama-3.1-8b-instruct-FP8" "meta-llama/Llama-3.2-3B-Instruct")
+
+# Loop over all .inference files
+for inference_filepath in toploc-scripts/inferences_to_replicate/*.inference; do
+    if [ -f "$inference_filepath" ]; then
+        inference_file=$(basename "$inference_filepath")
+        for REPLICATION_MODEL in "${MODELS[@]}"; do
+            echo "Performing replications for inference: $inference_filepath"
+            SANITIZED_REPLICATION_MODEL=$(echo "$REPLICATION_MODEL" | tr '/' '_')
+            output_file="${MACHINE}_${SANITIZED_REPLICATION_MODEL}_for_${inference_file}.replication"
+            output_filepath="toploc-scripts/replications/$output_file"
+            # Skip pairs that were already replicated so the script can be re-run.
+            if [ -f "$output_filepath" ]; then
+                echo "Output file already exists: $output_file"
+                continue
+            fi
+            echo "Output file: $output_file"
+            python toploc-scripts/data_collection_scripts/perform_replications.py --override-model "$REPLICATION_MODEL" --input-file "$inference_file" --machine "$MACHINE" --output-file "$output_file" --disable-cuda-graph --quiet
+
+            # Optional: Add a small delay between replications
+            sleep 1
+        done
+    else
+        # Reached when the glob matched nothing and was left unexpanded.
+        echo "File $inference_filepath does not exist"
+    fi
+done
+
+echo "All inferences replicated!"
diff --git a/toploc-scripts/data_collection_scripts/perform_replications.py b/toploc-scripts/data_collection_scripts/perform_replications.py
new file mode 100644
index 00000000000..c412da87a75
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/perform_replications.py
@@ -0,0 +1,313 @@
+import argparse
+import json
+import os
+import signal
+import subprocess
+import sys
+import time
+
+import openai
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from sglang.utils import (
+ launch_server_cmd,
+ print_highlight,
+ terminate_process,
+ wait_for_server,
+)
+
+# Pick up variables from a .env file (if present) before checking the environment.
+load_dotenv()
+
+# Abort at import time when no HF_TOKEN is configured.
+if not os.getenv("HF_TOKEN"):
+ print_highlight("HF_TOKEN not found in environment variables!", color="red")
+ sys.exit(1)
+
+# Directory of this script and its parent; input and output folders are
+# resolved against ROOT_DIR.
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ROOT_DIR = os.path.dirname(SCRIPT_DIR)
+
+
+def parse_args():
+    """Define the replication script's command-line options and parse them."""
+    cli = argparse.ArgumentParser()
+    add = cli.add_argument
+
+    # Which machine runs the replication, and what to read and write.
+    add("--machine", type=str, required=True, help="Machine name")
+    add(
+        "--override-model",
+        type=str,
+        default=None,
+        help="Override model to use when replicating",
+    )
+    add(
+        "--input-file",
+        type=str,
+        required=True,
+        help="JSON filename containing inferences to replicate (from the inferences_to_replicate directory)",
+    )
+    add(
+        "--output-file",
+        type=str,
+        default=None,
+        help="Filename to write the results (goes to replications directory)",
+    )
+
+    # Server launch behaviour.
+    add("--quiet", action="store_true", help="Run in quiet mode")
+    add("--disable-cuda-graph", action="store_true", help="Disable CUDA graph")
+
+    # --limit trims the list when it is loaded; --N stops the replication loop.
+    add(
+        "--limit",
+        type=int,
+        default=None,
+        help="Limit number of inferences to replicate (whole file if not supplied)",
+    )
+    add(
+        "--N",
+        type=int,
+        default=None,
+        help="Number of inferences to replicate (whole file if not supplied)",
+    )
+    add("--skip-write", action="store_true")
+    add("--interactive", action="store_true")
+    return cli.parse_args()
+
+
+def kill_gpu_processes():
+    """
+    Send SIGKILL to every process nvidia-smi lists as a GPU compute app.
+
+    Raises:
+        subprocess.CalledProcessError: if nvidia-smi exits with a non-zero status.
+    """
+    # An argument list needs no shell, so none is spawned.
+    cmd = ["nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader"]
+    result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE)
+    for line in result.stdout.decode().splitlines():
+        pid = line.strip()
+        if not pid:
+            # A blank line would otherwise crash int().
+            continue
+        print(f"Killing process {pid}")
+        try:
+            os.kill(int(pid), signal.SIGKILL)
+        except ProcessLookupError:
+            # The process exited between the query and the kill.
+            pass
+
+
+def start_server(args, model_path):
+    """
+    Launch an SGLang server for ``model_path`` and wait until it answers.
+
+    ``model_path`` may have the form ``"<model>;<quantization>"``; the segment
+    following the first semicolon is forwarded as ``--quantization``.
+    Returns the ``(server_process, port)`` pair from ``launch_server_cmd``.
+    """
+    MAYBE_NOISY = "" if args.quiet else "--log-level debug"
+    MAYBE_DISABLE_CUDA_GRAPH = (
+        "--disable-cuda-graph" if args.disable_cuda_graph else ""
+    )
+
+    print(f"Starting server with model {model_path}...")
+
+    segments = model_path.split(";")
+    model = segments[0]
+    MAYBE_QUANTIZATION = ""
+    if len(segments) > 1:
+        quantization = segments[1]
+        print(f"Quantization: {quantization}")
+        MAYBE_QUANTIZATION = f"--quantization {quantization}"
+
+    server_process, port = launch_server_cmd(
+        f"""
+        python -m sglang.launch_server --model-path {model} {MAYBE_QUANTIZATION} --host 0.0.0.0 {MAYBE_NOISY} {MAYBE_DISABLE_CUDA_GRAPH}
+        """
+    )
+
+    print(f"Starting on port {port}...")
+
+    # Block until the server responds.
+    wait_for_server(f"http://localhost:{port}")
+
+    # Extra grace period after the server first responds.
+    print("Waiting 3 more seconds for server to be fully initialized...")
+    time.sleep(3)
+
+    print(
+        f"Started server with model {model} on port {port} {MAYBE_QUANTIZATION}, {MAYBE_NOISY}, {MAYBE_DISABLE_CUDA_GRAPH}"
+    )
+    if args.interactive:
+        input("Press Enter when ready to continue...")
+
+    return server_process, port
+
+
+def load_inferences(args):
+    """
+    Read the list of inferences named by ``args.input_file``.
+
+    A relative filename is looked up in the "inferences_to_replicate"
+    directory under ROOT_DIR. A positive ``args.limit`` keeps only the first
+    ``limit`` entries. In interactive mode the function waits for Enter.
+    """
+    if os.path.isabs(args.input_file):
+        input_filepath = args.input_file
+    else:
+        input_filepath = os.path.join(
+            ROOT_DIR, "inferences_to_replicate", args.input_file
+        )
+
+    print(f"Loading inferences from {input_filepath}")
+    with open(input_filepath, "r") as source:
+        inferences = json.load(source)
+
+    wants_limit = args.limit is not None and args.limit > 0
+    if wants_limit:
+        inferences = inferences[: args.limit]
+        print(f"Limited to first {args.limit} inferences")
+
+    print(f"Loaded {len(inferences)} inferences from {input_filepath}")
+    if args.interactive:
+        input("Press Enter when ready to continue...")
+
+    return inferences
+
+
+def perform_replications(inferences, machine_name, args):
+    """
+    Rerun the prompts from the inferences and collect new responses.
+
+    The server is started lazily on the first item, using
+    ``args.override_model`` or, failing that, the model stored on that item.
+
+    Args:
+        inferences: list of dicts providing "prompt", "machine",
+            "complete_request", "complete_response" and optionally "model".
+        machine_name: stored as "replication_machine" on every result.
+        args: parsed CLI arguments; ``N`` and ``override_model`` are read here.
+
+    Returns:
+        ``(replication_results, server_process)``. The process is returned so
+        the caller can terminate it. A failed request is recorded as an entry
+        with an "error" key instead of a "replication_response".
+    """
+    replication_results = []
+    server_process = None
+    port = None
+    client = None
+
+    for i, item in enumerate(tqdm(inferences)):
+        if args.N is not None and i >= args.N:
+            break
+        prompt = item["prompt"]
+        model_to_use = args.override_model or item.get("model")
+        original_request = item["complete_request"]
+        original_response = item["complete_response"]
+
+        # Start server with the model from the first inference if not already running
+        if port is None:
+            print(f"Starting server with model {model_to_use}...")
+            kill_gpu_processes()
+            server_process, port = start_server(args, model_to_use)
+            client = openai.Client(
+                base_url=f"http://127.0.0.1:{port}/v1", api_key="None"
+            )
+
+        # Copy all parameters from the original request
+        request = dict(original_request)
+        request["model"] = model_to_use
+
+        try:
+            response = client.chat.completions.create(**request)
+            response_dump = response.model_dump()
+        except Exception as e:
+            # Only names that are always bound are used here; the previous
+            # handler referenced undefined variables and raised NameError.
+            print(f"Error replicating prompt {i}: {e}")
+            replication_results.append(
+                {
+                    "replication_machine": machine_name,
+                    "inference_machine": item["machine"],
+                    "prompt": prompt,
+                    "original_request": original_request,
+                    "original_response": original_response,
+                    "replication_request": request,
+                    "error": str(e),
+                }
+            )
+            continue
+
+        replication_results.append(
+            {
+                "replication_machine": machine_name,
+                "inference_machine": item["machine"],
+                "prompt": prompt,
+                "original_request": original_request,
+                "original_response": original_response,
+                "replication_request": request,
+                "replication_response": response_dump,
+            }
+        )
+
+        # A missing (None) content is treated as an empty string for reporting.
+        original_response_text = (
+            original_response["choices"][0]["message"]["content"] or ""
+        )
+        replication_response_text = (
+            response_dump["choices"][0]["message"]["content"] or ""
+        )
+
+        if original_response_text == replication_response_text:
+            print(f" >>> Prompt {i} matched original response")
+            continue
+
+        prefix_match_len = (
+            calculate_prefix_match_length(
+                original_response_text, replication_response_text
+            )
+            or 0
+        )
+        # Guard against an empty original response (division by zero).
+        if original_response_text:
+            prefix_match_percent = (
+                prefix_match_len / len(original_response_text) * 100
+            )
+        else:
+            prefix_match_percent = 0.0
+        print(
+            f" >>> Prompt {i} did not match original response (prefix %: {prefix_match_percent:.2f}, response lengths: {len(original_response_text)} : {len(replication_response_text)})"
+        )
+        print(
+            f"Divergence:\n\t{original_response_text[prefix_match_len:prefix_match_len+10]}\n\t{replication_response_text[prefix_match_len:prefix_match_len+10]}"
+        )
+
+    # Return both the results and the server process for cleanup
+    return replication_results, server_process
+
+
+def write_to_file(args, replication_results):
+    """
+    Write replication results to the output file.
+
+    The file goes into the "replications" directory under ROOT_DIR. When
+    ``args.output_file`` is None a name is derived from the replication model
+    and the input filename and stored back on ``args``.
+    """
+    if args.output_file is None:
+        if replication_results:
+            model = replication_results[0]["replication_request"]["model"]
+        else:
+            # No result to read the model from: indexing [0] would raise
+            # IndexError, so fall back to the override model or a placeholder.
+            model = getattr(args, "override_model", None) or "unknown_model"
+        model_prefix = model.replace("/", "_")
+        args.output_file = (
+            model_prefix + "_replications_for_" + os.path.basename(args.input_file)
+        )
+
+    replications_dir = os.path.join(ROOT_DIR, "replications")
+    os.makedirs(replications_dir, exist_ok=True)
+    output_filepath = os.path.join(replications_dir, args.output_file)
+    with open(output_filepath, "w") as f:
+        json.dump(replication_results, f, indent=4)
+    print(f"Replication results written to {output_filepath}")
+
+
+def main():
+    """Load the inferences, replicate them, write the results and print a summary."""
+    args = parse_args()
+    inferences = load_inferences(args)
+    print(f"Loaded {len(inferences)} inferences, preparing to replicate...")
+
+    server_process = None
+    try:
+        # perform_replications starts the server itself on the first inference.
+        replication_results, server_process = perform_replications(
+            inferences, args.machine, args
+        )
+
+        if not args.skip_write:
+            write_to_file(args, replication_results)
+
+        # Summary: entries carrying an "error" key count as failures.
+        total = len(replication_results)
+        succeeded = len([r for r in replication_results if "error" not in r])
+        summary_lines = (
+            "\nReplication summary:",
+            f"Total replications: {total}",
+            f"Replications with no errors: {succeeded}",
+            f"Replications with errors: {total - succeeded}",
+        )
+        for summary_line in summary_lines:
+            print(summary_line)
+    finally:
+        if server_process:
+            print("Terminating server...")
+            terminate_process(server_process)
+            print("Server terminated.")
+
+
+def calculate_prefix_match_length(original, replication):
+    """
+    Return the number of leading characters the two strings have in common.
+
+    Comparison stops at the first differing position, so the result can be
+    used directly as the index where the two texts diverge.
+    """
+    for index, (o, r) in enumerate(zip(original, replication)):
+        if o != r:
+            return index
+    # One string is a prefix of the other (or they are equal).
+    return min(len(original), len(replication))
+
+
+# Run the replication workflow only when executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/toploc-scripts/data_collection_scripts/verify-all-fingerprints.sh b/toploc-scripts/data_collection_scripts/verify-all-fingerprints.sh
new file mode 100755
index 00000000000..ac5c2120013
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/verify-all-fingerprints.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Verifies every *.fingerprint file under toploc-scripts/fingerprints against
+# each model in MODELS. All paths are relative to the current working directory.
+
+# Check if machine name is provided
+if [ $# -ne 1 ]; then
+    echo "Usage: $0 <machine_name>"
+    exit 1
+fi
+
+
+source .venv/bin/activate
+
+MACHINE=$1
+
+if [ ! -d "toploc-scripts/fingerprints" ]; then
+    echo "Error: fingerprints directory not found"
+    exit 1
+fi
+
+# Array of models to process
+MODELS=("meta-llama/Llama-3.1-8B-Instruct" "meta-llama/Llama-3.1-8B-Instruct;fp8" "meta-llama/Llama-3.2-3B-Instruct")
+
+# Loop over all .fingerprint files
+for fingerprint_filepath in toploc-scripts/fingerprints/*.fingerprint; do
+    if [ -f "$fingerprint_filepath" ]; then
+        fingerprint_file=$(basename "$fingerprint_filepath")
+        for MODEL in "${MODELS[@]}"; do
+            # Sanitize model name for filename
+            SANITIZED_MODEL=$(echo "$MODEL" | tr '/' '_')
+            echo "Verifying fingerprint: $fingerprint_filepath"
+            output_file="${MACHINE}_${SANITIZED_MODEL}_for_${fingerprint_file}.verification"
+            echo "Output file: $output_file"
+            # verify_fingerprints.py lives in data_collection_scripts/.
+            python toploc-scripts/data_collection_scripts/verify_fingerprints.py --input-file "$fingerprint_file" --machine "$MACHINE" --model "$MODEL" --output-file "$output_file" --disable-cuda-graph
+
+            # Optional: Add a small delay between verifications
+            sleep 1
+        done
+    else
+        # Reached when the glob matched nothing and was left unexpanded.
+        echo "File $fingerprint_filepath does not exist"
+    fi
+done
+
+echo "All fingerprints verified!"
diff --git a/toploc-scripts/data_collection_scripts/verify_fingerprints.py b/toploc-scripts/data_collection_scripts/verify_fingerprints.py
new file mode 100644
index 00000000000..8b8931c7359
--- /dev/null
+++ b/toploc-scripts/data_collection_scripts/verify_fingerprints.py
@@ -0,0 +1,228 @@
+import argparse
+import json
+import os
+import signal
+import subprocess
+import sys
+import time
+
+import openai
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from sglang.utils import (
+ launch_server_cmd,
+ print_highlight,
+ terminate_process,
+ wait_for_server,
+)
+
+# Pick up variables from a .env file (if present) before checking the environment.
+load_dotenv()
+
+# Abort at import time when no HF_TOKEN is configured.
+if not os.getenv("HF_TOKEN"):
+ print_highlight("HF_TOKEN not found in environment variables!", color="red")
+ sys.exit(1)
+
+# Directory of this script and its parent; the "fingerprints" and
+# "verifications" folders are resolved against ROOT_DIR.
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ROOT_DIR = os.path.dirname(SCRIPT_DIR)
+
+
+def parse_args():
+    """Define the verification script's command-line options and parse them."""
+    cli = argparse.ArgumentParser()
+    add = cli.add_argument
+
+    add("--machine", type=str, required=True, help="Machine name")
+    add(
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Model to use",
+    )
+    add(
+        "--input-file",
+        type=str,
+        required=True,
+        help="JSON filename containing fingerprints to analyze (from the fingerprints directory)",
+    )
+    add(
+        "--output-file",
+        type=str,
+        default=None,
+        help="Filename to write the results (goes to verifications directory)",
+    )
+    add("--quiet", action="store_true", help="Run in quiet mode")
+    add("--disable-cuda-graph", action="store_true", help="Disable CUDA graph")
+    add(
+        "--limit",
+        type=int,
+        default=None,
+        help="Limit number of fingerprints to analyze",
+    )
+    return cli.parse_args()
+
+
+def kill_gpu_processes():
+    """
+    Send SIGKILL to every process nvidia-smi lists as a GPU compute app.
+
+    Raises:
+        subprocess.CalledProcessError: if nvidia-smi exits with a non-zero status.
+    """
+    # An argument list needs no shell, so none is spawned.
+    cmd = ["nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader"]
+    result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE)
+    for line in result.stdout.decode().splitlines():
+        pid = line.strip()
+        if not pid:
+            # A blank line would otherwise crash int().
+            continue
+        print(f"Killing process {pid}")
+        try:
+            os.kill(int(pid), signal.SIGKILL)
+        except ProcessLookupError:
+            # The process exited between the query and the kill.
+            pass
+
+
+def start_server(args):
+    """
+    Start the SGL server with TopLoc fingerprint verification enabled.
+
+    ``args.model`` may have the form ``"<model>;<quantization>"``; the segment
+    following the first semicolon is forwarded as ``--quantization``.
+    Returns the ``(server_process, port)`` pair from ``launch_server_cmd``.
+    """
+    print("Starting server with TopLoc fingerprint verification enabled...")
+    if args.quiet:
+        MAYBE_NOISY = ""
+    else:
+        MAYBE_NOISY = "--log-level debug"
+
+    if args.disable_cuda_graph:
+        MAYBE_DISABLE_CUDA_GRAPH = "--disable-cuda-graph"
+    else:
+        MAYBE_DISABLE_CUDA_GRAPH = ""
+
+    # Logged once; this message used to be printed twice in a row.
+    print(f"Starting server with model {args.model}...")
+
+    model, *quantization = args.model.split(";")
+    if quantization:
+        quantization = quantization[0]
+        print(f"Quantization: {quantization}")
+        MAYBE_QUANTIZATION = f"--quantization {quantization}"
+    else:
+        MAYBE_QUANTIZATION = ""
+
+    server_process, port = launch_server_cmd(
+        f"""
+        python -m sglang.launch_server --model-path {model} {MAYBE_QUANTIZATION} --host 0.0.0.0 --toploc-verification {MAYBE_NOISY} {MAYBE_DISABLE_CUDA_GRAPH}
+        """
+    )
+
+    print(f"Starting on port {port}...")
+
+    # Wait for the server to start
+    wait_for_server(f"http://localhost:{port}")
+
+    # Add additional delay to ensure server is fully initialized
+    print("Waiting 3 more seconds for server to be fully initialized...")
+    time.sleep(3)
+
+    return server_process, port
+
+
+def load_fingerprints(args):
+    """
+    Read the list of fingerprints named by ``args.input_file``.
+
+    A relative filename is looked up in the "fingerprints" directory under
+    ROOT_DIR. A positive ``args.limit`` keeps only the first ``limit`` entries.
+    """
+    if os.path.isabs(args.input_file):
+        input_filepath = args.input_file
+    else:
+        input_filepath = os.path.join(ROOT_DIR, "fingerprints", args.input_file)
+
+    print(f"Loading fingerprints from {input_filepath}")
+    with open(input_filepath, "r") as source:
+        fingerprints = json.load(source)
+
+    wants_limit = args.limit is not None and args.limit > 0
+    if wants_limit:
+        fingerprints = fingerprints[: args.limit]
+        print(f"Limited to first {args.limit} fingerprints")
+
+    return fingerprints
+
+
+def verify_fingerprints(args, port, fingerprints):
+    """
+    Ask the server on ``port`` to validate each stored fingerprint.
+
+    For every item a zero-token chat completion is sent that contains the
+    prompt, the recorded reply and the fingerprint to validate. Items whose
+    request fails are reported on stdout and left out of the returned list.
+    """
+    client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
+    verification_results = []
+
+    for index, entry in enumerate(tqdm(fingerprints)):
+        prompt = entry["prompt"]
+        model = entry["model"]
+        fingerprint = entry["fingerprint"]
+        recorded_message = entry["complete_response"]["choices"][0]["message"]
+        original_response = recorded_message["content"]
+
+        # max_tokens=0 makes this a prefill-only operation.
+        request = {
+            "model": model,
+            "messages": [
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": original_response},
+            ],
+            "max_tokens": 0,
+            "extra_body": {"toploc_verification_fingerprint_to_validate": fingerprint},
+        }
+
+        try:
+            response_dump = client.chat.completions.create(**request).model_dump()
+            verification_results.append(
+                {
+                    "prompt": prompt,
+                    "verification_request": request,
+                    "verification_response": response_dump,
+                    "original_request": entry["complete_request"],
+                    "original_response": entry["complete_response"],
+                    "original_machine": entry["machine"],
+                    "original_model": entry["model"],
+                    "verification_machine": args.machine,
+                    "verification_model": args.model,
+                    "original_fingerprint": fingerprint,
+                    "verification_result": response_dump["choices"][0]["message"].get(
+                        "toploc_verification_fingerprint_validation_result", False
+                    ),
+                }
+            )
+        except Exception as e:
+            print(f"Error verifying fingerprint {index}: {e}")
+
+    return verification_results
+
+
+def write_to_file(args, verification_results):
+    """
+    Write verification results to the output file.
+
+    The file goes into the "verifications" directory under ROOT_DIR. When
+    ``args.output_file`` is None a name is derived from the model and the
+    input filename and stored back on ``args``.
+    """
+    if args.output_file is None:
+        # --input-file may be an absolute path (load_fingerprints accepts
+        # one); only its basename is usable inside a filename.
+        args.output_file = (
+            args.model.replace("/", "_")
+            + "_for_"
+            + os.path.basename(args.input_file)
+        )
+    verifications_dir = os.path.join(ROOT_DIR, "verifications")
+    output_filepath = os.path.join(verifications_dir, args.output_file)
+    os.makedirs(verifications_dir, exist_ok=True)
+    with open(output_filepath, "w") as f:
+        json.dump(verification_results, f, indent=4)
+    print(f"Verification results written to {output_filepath}")
+
+
+def main():
+    """Load the fingerprints, verify them against a fresh server and save the results."""
+    args = parse_args()
+    fingerprints = load_fingerprints(args)
+    print(f"Loaded {len(fingerprints)} fingerprints, preparing to verify...")
+
+    kill_gpu_processes()
+    server_process, port = start_server(args)
+
+    try:
+        verification_results = verify_fingerprints(args, port, fingerprints)
+        write_to_file(args, verification_results)
+
+        # Summary of what was verified.
+        total = len(verification_results)
+        print("\nVerification summary:")
+        print(f"Total fingerprints: {total}")
+    finally:
+        # The server is shut down whether or not verification succeeded.
+        print("Terminating server...")
+        terminate_process(server_process)
+        print("Server terminated.")
+
+
+# Run the verification workflow only when executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/toploc-scripts/minimal_example.py b/toploc-scripts/minimal_example.py
new file mode 100644
index 00000000000..de711650212
--- /dev/null
+++ b/toploc-scripts/minimal_example.py
@@ -0,0 +1,153 @@
+import argparse
+import csv
+import json
+import os
+import random
+import signal
+import sys
+import time
+from pathlib import Path
+
+import openai
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from sglang.utils import print_highlight, terminate_process, wait_for_server
+
+
+def do_verification_flow(args, port):
+    """
+    Walk through one generate-then-verify round trip against the local server.
+
+    A completion is requested together with its verification fingerprints; the
+    last fingerprint is then sent back with the prompt and the reply in a
+    zero-token request, and the returned error statistics are checked with
+    ``is_verified``. The function pauses for Enter after each stage.
+    """
+    model_name = "meta-llama/Llama-3.1-8B-Instruct"
+    client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
+    prompt = "What is the capital of Bulgaria?"
+    params = {"temperature": 0, "seed": args.seed}
+
+    # Stage 1: the request a user would send.
+    user_messages = [{"role": "user", "content": prompt}]
+    response = client.chat.completions.create(
+        model=model_name,
+        messages=user_messages,
+        **params,
+        extra_body={"return_verification_proofs": True},
+    )
+    first_message = response.model_dump()["choices"][0]["message"]
+    response_content = first_message["content"]
+    fingerprint = first_message["toploc_verification_fingerprints"][-1]
+    print("Prompt: ", prompt)
+    print("Response: ", response_content)
+    print("Fingerprint: ", fingerprint)
+    input("Press Enter to continue...")
+
+    # Stage 2: the verification request (prompt + reply + fingerprint).
+    verification_messages = [
+        {"role": "user", "content": prompt},
+        {"role": "assistant", "content": response_content},
+    ]
+    prefill_response = client.chat.completions.create(
+        model=model_name,
+        messages=verification_messages,
+        max_tokens=0,
+        **params,
+        extra_body={
+            "toploc_verification_fingerprint_to_validate": fingerprint,
+        },
+    )
+    prefill_message = prefill_response.model_dump()["choices"][0]["message"]
+    verification_result = prefill_message[
+        "toploc_verification_fingerprint_validation_result"
+    ]
+    error_statistics = json.loads(verification_result)
+    print("Verification Result", verification_result)
+
+    # Stage 3: apply the error thresholds.
+    verified = is_verified(**error_statistics)
+    print("Verified:", verified)
+    input("Press Enter to exit...")
+
+
+def parse_args():
+    """Define the example script's command-line options and parse them."""
+    cli = argparse.ArgumentParser(
+        description="Test TopLoc fingerprint verification with UltraChat dataset"
+    )
+    add = cli.add_argument
+    add(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for sampling and generation",
+    )
+    add("--quiet", action="store_true", help="Run in quiet mode")
+    add("--disable-cuda-graph", action="store_true", help="Disable CUDA graph")
+    return cli.parse_args()
+
+
+# Kill any GPU processes that might be running
+def kill_gpu_processes():
+    """
+    Send SIGKILL to every process nvidia-smi lists as a GPU compute app.
+
+    Raises:
+        subprocess.CalledProcessError: if nvidia-smi exits with a non-zero status.
+    """
+    import subprocess
+
+    # An argument list needs no shell, so none is spawned.
+    cmd = ["nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader"]
+    result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE)
+    for line in result.stdout.decode().splitlines():
+        pid = line.strip()
+        if not pid:
+            # A blank line would otherwise crash int().
+            continue
+        print(f"Killing process {pid}")
+        try:
+            os.kill(int(pid), signal.SIGKILL)
+        except ProcessLookupError:
+            # The process exited between the query and the kill.
+            pass
+
+
+# Is it verified, based on activation error statistics?
+def is_verified(exp_mismatches, mant_err_mean, mant_err_median):
+    """
+    Tell whether the three error statistics all stay within their limits.
+
+    The limits (90 exponent mismatches, mean mantissa error 10.0, median
+    mantissa error 8.0) are example values and are inclusive.
+    """
+    if not exp_mismatches <= 90:
+        return False
+    if not mant_err_mean <= 10.0:
+        return False
+    return mant_err_median <= 8.0
+
+
+# Fire up the server
+def start_server(args):
+    """
+    Launch the SGL server for Llama-3.1-8B-Instruct with TopLoc verification on.
+
+    Blocks until the server answers, waits three more seconds, and returns the
+    ``(server_process, port)`` pair from ``launch_server_cmd``.
+    """
+    from sglang.utils import launch_server_cmd
+
+    print("Starting server with TopLoc fingerprint verification enabled...")
+    MAYBE_NOISY = "" if args.quiet else "--log-level debug"
+    MAYBE_DISABLE_CUDA_GRAPH = (
+        "--disable-cuda-graph" if args.disable_cuda_graph else ""
+    )
+
+    launched = launch_server_cmd(
+        f"""
+        python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --toploc-verification {MAYBE_NOISY} {MAYBE_DISABLE_CUDA_GRAPH}
+        """
+    )
+    server_process, port = launched
+
+    # Block until the server responds.
+    wait_for_server(f"http://localhost:{port}")
+
+    # Extra grace period after the server first responds.
+    print("Waiting 3 more seconds for server to be fully initialized...")
+    time.sleep(3)
+
+    return server_process, port
+
+
+if __name__ == "__main__":
+    # Parse arguments first so that --help or a bad flag exits before any
+    # running GPU process is killed.
+    args = parse_args()
+    kill_gpu_processes()
+    server_process, port = start_server(args)
+    try:
+        do_verification_flow(args, port)
+    finally:
+        # Shut the server down even if the flow raises; terminate_process is
+        # the helper this file already imports from sglang.utils.
+        terminate_process(server_process)
diff --git a/toploc-scripts/prefill_attack/test_prefill_attack.py b/toploc-scripts/prefill_attack/test_prefill_attack.py
new file mode 100644
index 00000000000..11abe1ec1d6
--- /dev/null
+++ b/toploc-scripts/prefill_attack/test_prefill_attack.py
@@ -0,0 +1,215 @@
+import argparse
+import json
+import os
+import signal
+import subprocess
+import sys
+import time
+
+import openai
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+from sglang.utils import (
+ launch_server_cmd,
+ print_highlight,
+ terminate_process,
+ wait_for_server,
+)
+
+# Directory of this script and its parent; the "ultrachat" data folder is
+# resolved against ROOT_DIR.
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+ROOT_DIR = os.path.dirname(SCRIPT_DIR)
+
+# Inclusive upper bounds that is_verified() applies to the error statistics
+# of a verification result.
+ERROR_THRESHOLDS = {
+ "exp_mismatches": 90, # Maximum number of exponent mismatches allowed
+ "mant_err_mean": 10, # Maximum mean mantissa error allowed
+ "mant_err_median": 8, # Maximum median mantissa error allowed
+}
+
+# Pick up variables from a .env file (if present).
+load_dotenv()
+
+
+def start_server(args):
+    """
+    Start the SGL server with TopLoc fingerprint verification enabled.
+
+    ``args.model`` may have the form ``"<model>;<quantization>"``; the segment
+    following the first semicolon is forwarded as ``--quantization``.
+    Returns the ``(server_process, port)`` pair from ``launch_server_cmd``.
+    """
+    print("Starting server with TopLoc fingerprint verification enabled...")
+    if args.quiet:
+        MAYBE_NOISY = ""
+    else:
+        MAYBE_NOISY = "--log-level debug"
+
+    if args.disable_cuda_graph:
+        MAYBE_DISABLE_CUDA_GRAPH = "--disable-cuda-graph"
+    else:
+        MAYBE_DISABLE_CUDA_GRAPH = ""
+
+    # Logged once; this message used to be printed twice in a row.
+    print(f"Starting server with model {args.model}...")
+
+    model, *quantization = args.model.split(";")
+    if quantization:
+        quantization = quantization[0]
+        print(f"Quantization: {quantization}")
+        MAYBE_QUANTIZATION = f"--quantization {quantization}"
+    else:
+        MAYBE_QUANTIZATION = ""
+
+    server_process, port = launch_server_cmd(
+        f"""
+        python -m sglang.launch_server --model-path {model} {MAYBE_QUANTIZATION} --host 0.0.0.0 --toploc-verification {MAYBE_NOISY} {MAYBE_DISABLE_CUDA_GRAPH}
+        """
+    )
+
+    print(f"Starting on port {port}...")
+
+    # Wait for the server to start
+    wait_for_server(f"http://localhost:{port}")
+
+    # Add additional delay to ensure server is fully initialized
+    print("Waiting 3 more seconds for server to be fully initialized...")
+    time.sleep(3)
+
+    return server_process, port
+
+
+def kill_gpu_processes():
+    """
+    Send SIGKILL to every process nvidia-smi lists as a GPU compute app.
+
+    Raises:
+        subprocess.CalledProcessError: if nvidia-smi exits with a non-zero status.
+    """
+    # An argument list needs no shell, so none is spawned.
+    cmd = ["nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader"]
+    result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE)
+    for line in result.stdout.decode().splitlines():
+        pid = line.strip()
+        if not pid:
+            # A blank line would otherwise crash int().
+            continue
+        print(f"Killing process {pid}")
+        try:
+            os.kill(int(pid), signal.SIGKILL)
+        except ProcessLookupError:
+            # The process exited between the query and the kill.
+            pass
+
+
+def test_prefills(args, port):
+    """
+    Fingerprint a made-up reply via prefill and check whether it verifies.
+
+    For each of the first ``args.N`` lines of the ultrachat file, the prompt is
+    paired with a fixed spoofed assistant reply. One zero-token request asks
+    for verification proofs, a second one submits the last returned
+    fingerprint for validation, and the outcome is printed.
+    """
+    client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
+    ultrachat_file = os.path.join(ROOT_DIR, "ultrachat", args.ultrachat_file)
+    spoofed_response = "A made up response"
+
+    # read ultrachat file in a loop
+    with open(ultrachat_file, "r") as f:
+        for i, line in enumerate(f):
+            if i >= args.N:
+                break
+            # Only the prompt is needed; the dataset's own reply is replaced
+            # by the spoofed one.
+            prompt = json.loads(line)["data"][0]
+
+            # Both requests share everything except extra_body.
+            common = dict(
+                model=args.model,
+                messages=[
+                    {"role": "user", "content": prompt},
+                    {"role": "assistant", "content": spoofed_response},
+                ],
+                max_tokens=0,
+                temperature=args.temperature,
+                seed=args.seed,
+            )
+
+            # Generate fingerprint using a prefill
+            response = client.chat.completions.create(
+                **common, extra_body={"return_verification_proofs": True}
+            )
+            fingerprint = response.model_dump()["choices"][0]["message"][
+                "toploc_verification_fingerprints"
+            ][-1]
+
+            print(f"Fingerprint for prompt {i}: {fingerprint}")
+
+            # Verify fingerprint
+            response = client.chat.completions.create(
+                **common,
+                extra_body={"toploc_verification_fingerprint_to_validate": fingerprint},
+            )
+            verification_result = response.model_dump()["choices"][0]["message"].get(
+                "toploc_verification_fingerprint_validation_result", False
+            )
+
+            print(f"Verification fingerprint for prompt {i}: {verification_result}")
+
+            if not verification_result:
+                # Nothing to parse: json.loads(False) would raise TypeError.
+                print("Verified: False")
+                continue
+
+            verified = is_verified(json.loads(verification_result))
+
+            print(f"Verified: {verified}")
+
+
+def is_verified(verification_result):
+    """Return True when every error statistic is within its ERROR_THRESHOLDS bound."""
+    metrics = ("exp_mismatches", "mant_err_mean", "mant_err_median")
+    # A list (not a generator) so that all three lookups always happen.
+    within_bounds = [
+        verification_result[metric] <= ERROR_THRESHOLDS[metric]
+        for metric in metrics
+    ]
+    return all(within_bounds)
+
+
+def parse_args():
+    """Define the prefill-attack script's command-line options and parse them."""
+    cli = argparse.ArgumentParser()
+    add = cli.add_argument
+
+    add(
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Model to use",
+    )
+    add("--N", type=int, default=1, help="Number of prompts to test")
+    add(
+        "--ultrachat_file",
+        type=str,
+        default="train_0.jsonl",
+        help="ultrachat filename",
+    )
+    add("--disable-cuda-graph", action="store_true", help="Disable CUDA graph")
+    add(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for sampling and generation",
+    )
+    add(
+        "--temperature",
+        type=float,
+        default=0.0,
+        help="Temperature for sampling",
+    )
+    add("--output_filename", type=str, default=None, help="Output filename")
+    add("--quiet", action="store_true", help="Run in quiet mode")
+    return cli.parse_args()
+
+
+def main():
+    """Start a fresh server, run the prefill test and always shut the server down."""
+    args = parse_args()
+    kill_gpu_processes()
+    server_process, port = start_server(args)
+    try:
+        test_prefills(args, port)
+    finally:
+        # Shut the server down even if the test raises; terminate_process is
+        # the helper this file already imports from sglang.utils.
+        terminate_process(server_process)
+
+
+# Run the prefill test only when executed as a script.
+if __name__ == "__main__":
+ main()
diff --git a/toploc-scripts/setup.sh b/toploc-scripts/setup.sh
new file mode 100644
index 00000000000..ca97d86763f
--- /dev/null
+++ b/toploc-scripts/setup.sh
@@ -0,0 +1,13 @@
+# Installs sglang and the helper packages used by the toploc scripts into .venv.
+# NOTE(review): uses a relative .venv path and ends with `deactivate`, so it is
+# presumably meant to be sourced from the repository root. Confirm.
+source .venv/bin/activate
+pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
+pip install transformers==4.48.3
+pip install datasets
+
+# python-dotenv is the distribution that provides `from dotenv import load_dotenv`.
+pip install python-dotenv
+pip install huggingface-hub
+pip install tabulate
+pip install sentence-transformers
+pip install scikit-learn
+pip install matplotlib
+pip install seaborn
+deactivate
diff --git a/toploc-scripts/spotcheck.py b/toploc-scripts/spotcheck.py
new file mode 100644
index 00000000000..a11dbb563b5
--- /dev/null
+++ b/toploc-scripts/spotcheck.py
@@ -0,0 +1,47 @@
+import json
+import os
+
+# Directory containing this script; the "replications" folder is looked up inside it.
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def do_it(FILENAME):
+    """
+    Print how often the replicated response equals the original one.
+
+    Args:
+        FILENAME: name of a JSON file inside the "replications" directory next
+            to this script, holding a list of replication attempts.
+
+    Attempts that lack either response text are skipped and do not count
+    towards the total.
+    """
+    print(f"Performing spotcheck for {FILENAME}")
+
+    filepath = os.path.join(SCRIPT_DIR, "replications", FILENAME)
+    with open(filepath, "r") as f:
+        data = json.load(f)
+
+    N = 0
+    matches = 0
+
+    for replication_attempt in data:
+        try:
+            orig_response = replication_attempt["original_response"]["choices"][0][
+                "message"
+            ]["content"]
+            repl_response = replication_attempt["replication_response"]["choices"][0][
+                "message"
+            ]["content"]
+        except (KeyError, IndexError, TypeError):
+            # Only structurally incomplete entries are skipped; any other
+            # failure is a real bug and should surface.
+            continue
+
+        N += 1
+        if orig_response == repl_response:
+            matches += 1
+
+    if N == 0:
+        # Avoid dividing by zero when nothing was comparable.
+        print("Match rate: n/a (0/0)")
+        return
+
+    print(f"Match rate: {100 * matches / N:.2f}% ({matches}/{N})")
+
+
+if __name__ == "__main__":
+    # Spot-check every file found in the "replications" folder next to this script.
+    for filename in os.listdir(os.path.join(SCRIPT_DIR, "replications")):
+        do_it(filename)