Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand All @@ -12,7 +12,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.12.1'
rev: 'v0.12.8'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
2 changes: 2 additions & 0 deletions delphi/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .contrastive_explainer import ContrastiveExplainer
from .default.default import DefaultExplainer
from .explainer import Explainer, explanation_loader, random_explanation_loader
from .graph_explainer import GraphExplainer
from .no_op_explainer import NoOpExplainer
from .single_token_explainer import SingleTokenExplainer

Expand All @@ -12,4 +13,5 @@
"random_explanation_loader",
"ContrastiveExplainer",
"NoOpExplainer",
"GraphExplainer",
]
30 changes: 28 additions & 2 deletions delphi/explainers/default/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
Guidelines:

You will be given a list of text examples on which special words are selected and between delimiters like <<this>>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <<just like this>>. How important each token is for the behavior is listed after each example in parentheses.

- Try to produce a concise final description. Simply describe the text latents that are common in the examples, and what patterns you found.
- If the examples are uninformative, you don't need to mention them. Don't focus on giving examples of important tokens, but try to summarize the patterns found in the examples.
- Do not mention the marker tokens (<< >>) in your explanation.
Expand Down Expand Up @@ -59,9 +58,36 @@
2. Write down general shared latents of the text examples. This could be related to the full sentence or to the words surrounding the marked words.

3. Formulate an hypothesis and write down the final explanation using [EXPLANATION]:.

"""

# System prompt for the graph-based explainer. It asks the model to consider
# four methods in turn (max-activating tokens, the tokens that follow them,
# top positive logits, top negative logits), to use GRAPH_PROMPT as extra
# context, and to finish with "[SELECTED METHOD]" and "[EXPLANATION]" tags.
SYSTEM_GRAPH = """You are explaining the behavior of a neuron in a neural network. You will be graded on the quality of your explanation in terms of simplicity and accuracy. Do not make mistakes and follow all instructions carefully.
### Instructions
Your response should be a very concise explanation that captures what the neuron detects or predicts by finding patterns in lists:
- The explanation should be specific. For example, "unique words" is not a specific enough pattern, nor is "foreign words"
- The explanation can contain multiple different elements if one pattern is not sufficient to cover all examples. Remember to use only a few words to describe each one
- There will be a few examples in each of the provided lists that are irrelevant to the true explanation and should be discarded when looking for the pattern. You must cut through the noise.

To explain the neuron, try all methods and then go back to a previous method that works best. The methods are listed in order of probability of being correct, but does not mean you should always choose method 1.
- Method 1: Look at MAX_ACTIVATING_TOKENS. If they share something specific in common, or are all the same token or a variation of the same token (like different cases or conjugations), respond with that token
- Method 2: Look at TOKENS_AFTER_MAX_ACTIVATING_TOKEN. Try to find a specific pattern or similarity in all the tokens. A common pattern is that they all start with the same letter.
- Method 3: Look at TOP_POSITIVE_LOGITS for similarities.
- Method 4: Look at TOP_NEGATIVE_LOGITS for similarities.

To further refine your explanation, follow these guidelines:
- Do not add unnecessary phrases like "words related to", "concepts related to", or "variations of the word"
- Do not mention "tokens" or "patterns" or the method used in your explanation
- Look at the GRAPH_PROMPT for additional context. Since there may be multiple possible explanations, use the GRAPH_PROMPT to narrow it down.

Follow these instructions step by step and then produce the response in the format specified below. Please use square brackets as specified in the format to help the grader score your answer.
Format:
Method 1: <Plausible explanation using this method and rationale for whether it is the best or insufficient>
Method 2: <Plausible explanation using this method and rationale for whether it is the best or insufficient>
Method 3: <Plausible explanation using this method and rationale for whether it is the best or insufficient>
Method 4: <Plausible explanation using this method and rationale for whether it is the best or insufficient>
Answer:
[SELECTED METHOD] <The method number you chose>
[EXPLANATION] <Your final refined explanation>
"""

### EXAMPLE 1 ###

Expand Down
11 changes: 10 additions & 1 deletion delphi/explainers/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ class ExplainerResult(NamedTuple):
explanation: str
"""Generated explanation for latent."""

prompt: str = ""
"""Prompt given to explainer model to generate explanation."""

response: str = ""
"""
The full response generated by the explainer model
(includes thinking process/rationale)
"""


@dataclass
class Explainer(ABC):
Expand Down Expand Up @@ -67,7 +76,7 @@ def parse_explanation(self, text: str) -> str:
if match:
return match.group(1).strip()
else:
return "Explanation could not be parsed."
return f"Explanation could not be parsed: {text}"
except Exception as e:
logger.error(f"Explanation parsing regex failed: {repr(e)}")
raise
Expand Down
223 changes: 223 additions & 0 deletions delphi/explainers/graph_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import asyncio
import json
import os
from dataclasses import dataclass
from typing import Optional

import torch

from delphi.explainers.default.prompts import SYSTEM_GRAPH
from delphi.explainers.explainer import Explainer, ExplainerResult, Response
from delphi.latents.latents import (
ActivatingExample,
LatentRecord,
)


@dataclass
class GraphExplainer(Explainer):
    """
    Explainer that prompts the model with token statistics taken from a
    latent's activating examples together with graph context: top/bottom
    logits, the prompt used to generate the graph, and explanations of
    parent nodes.
    """

    activations: bool = True
    """Whether to show activations to the explainer."""
    cot: bool = False
    """Whether to use chain of thought reasoning in the prompt."""
    max_examples: int = 15
    """Maximum number of activating examples to use."""
    graph_info_path: Optional[os.PathLike] = None
    """Path to the graph information file."""
    explanations_dir: Optional[os.PathLike] = None
    """Path to the directory where explanations will be saved."""
    graph_prompt: str = ""
    """The prompt used to generate the graph."""
    max_parent_explanations: int = 1
    """Maximum number of explanations from parent nodes to use."""
    top_logits: bool = True
    """Whether to show the top logits for a feature to the explainer."""
    bot_logits: bool = True
    """Whether to show the bottom logits for a feature to the explainer."""
    disable_thinking: bool = False
    """Appends /no_think to the user prompt to disable thinking on Qwen3 models."""

    async def __call__(self, record: LatentRecord) -> ExplainerResult:
        """
        Override the base __call__ method to build the prompt from
        - the first ``max_examples`` train (activating) examples
        - top/bottom logits stored on the record
        - the prompt used to generate the graph
        - explanations from the parent nodes

        Args:
            record: The latent record to explain.

        Returns:
            ExplainerResult: The explainer result containing the explanation.
            When ``graph_info_path`` is unset or does not exist, a result with
            a fixed error message is returned and the client is not called.
            When parsing the response fails, the raw response text is used as
            the explanation.
        """
        # Imported lazily, as the original code did for its single use.
        from delphi import logger

        activating_examples = record.train[: self.max_examples]

        top_logits = record.top_logits if self.top_logits else []
        bot_logits = record.bot_logits if self.bot_logits else []
        if self.graph_info_path is None or not os.path.exists(self.graph_info_path):
            logger.error("Graph info path is not set.")
            return ExplainerResult(
                record=record,
                explanation="Graph info path was not set to a valid file path.",
            )

        parent_explanations = self._load_parent_explanations(record)

        messages = self._build_prompt(
            activating_examples,
            top_logits,
            bot_logits,
            self.graph_prompt,
            parent_explanations,
        )
        response = await self.client.generate(
            messages, temperature=self.temperature, **self.generation_kwargs
        )

        # Normalise the response before parsing so that the fallback path
        # below never has to touch `.text` on a plain string.
        response_text = response.text if isinstance(response, Response) else response
        user_prompt = messages[1]["content"]

        try:
            explanation = self._parse_explanation(response_text)
        except Exception as e:
            logger.error(f"Explanation parsing failed: {repr(e)}")
            return ExplainerResult(
                record=record,
                explanation=response_text,
                prompt=user_prompt,
                response=response_text,
            )

        return ExplainerResult(
            record=record,
            explanation=explanation,
            prompt=user_prompt,
            response=response_text,
        )

    def _load_parent_explanations(self, record: LatentRecord) -> list[str]:
        """Read up to ``max_parent_explanations`` parent explanation files."""
        # Without a directory there is nothing to look up; os.path.join would
        # raise a TypeError on None.
        if self.explanations_dir is None:
            return []

        explanations = []
        # NOTE(review): load_graph_info stores each parent as a 2-tuple, so
        # str(parent) is the tuple's repr. Confirm that explanation files are
        # really named that way (or whether only the first element is meant).
        for parent in (record.parents or [])[: self.max_parent_explanations]:
            parent_path = os.path.join(self.explanations_dir, str(parent))
            if not os.path.exists(parent_path):
                continue
            with open(parent_path, "r") as f:
                explanations.append(f.read())
        return explanations

    def _parse_explanation(self, response: str) -> str:
        """
        Extract the final explanation from the model response.

        The text after the last ``[EXPLANATION]`` tag is used (or the whole
        response if the tag is absent). The number following
        ``[SELECTED METHOD]`` selects a prefix that is prepended to the
        explanation; a missing or out-of-range number selects no prefix.
        """
        import re

        if "[EXPLANATION]" in response:
            explanation = response.split("[EXPLANATION]")[-1].strip()
        else:
            explanation = response.strip()

        # Index = method number from the system prompt; index 0 = unknown.
        prefixes = ["", "[say] ", "", "[promote] ", "[supress] "]

        method = 0
        # Read the digits directly after the tag. Taking the last character of
        # the remaining response would pick up the end of the explanation,
        # because [EXPLANATION] follows [SELECTED METHOD] in the format.
        match = re.search(
            r"\[SELECTED METHOD\][\s:<#]*(?:method\s*)?(\d+)",
            response,
            re.IGNORECASE,
        )
        if match:
            selected = int(match.group(1))
            if selected < len(prefixes):
                method = selected

        return prefixes[method] + explanation

    def _parse_information(
        self, examples: list[ActivatingExample]
    ) -> tuple[list[str], list[str], list[str]]:
        """
        Collect token statistics from the activating examples.

        Returns:
            A tuple ``(activating_tokens, text_after_tokens, plain_examples)``:
            tokens with a positive normalized activation, the token following
            each of them ("" when the activating token is the last one), and
            each example's tokens joined by spaces.
        """
        activating_tokens = []
        text_after_tokens = []
        plain_examples = []
        for example in examples:
            # Positions whose normalized activation is non-zero (positive).
            activated = torch.where(example.normalized_activations > 0)[0].tolist()

            # Work on a padded copy so that `i + 1` is always a valid index
            # without mutating the example's own token list.
            padded_tokens = [*example.str_tokens, ""]

            activating_tokens.extend(padded_tokens[i] for i in activated)
            text_after_tokens.extend(padded_tokens[i + 1] for i in activated)
            plain_examples.append(" ".join(padded_tokens))

        return activating_tokens, text_after_tokens, plain_examples

    def _build_prompt(  # type: ignore
        self,
        examples: list[ActivatingExample],
        top_logits: list[str],
        bot_logits: list[str],
        prompt: str,
        parent_explanations: list[str],
    ) -> list[dict]:
        """
        Build a prompt with graph information

        Args:
            examples: List of activating examples.
            top_logits: List of top logits to include in the prompt.
            bot_logits: List of bottom logits to include in the prompt.
            prompt: The graph prompt to include in the message; skipped
                when empty.
            parent_explanations: List of explanations from parent nodes;
                the section is skipped when the list is empty.

        Returns:
            A two-element list of message dictionaries: the system prompt
            followed by the user message.
        """
        activating_tokens, text_after_tokens, plain_examples = (
            self._parse_information(examples)
        )

        sections = [
            "### INPUTS:",
            "\nMAX_ACTIVATING_TOKENS:",
            ", ".join(activating_tokens),
            "\nTOKENS_AFTER_MAX_ACTIVATING_TOKEN:",
            ", ".join(text_after_tokens),
            "\nTOP_POSITIVE_LOGITS:",
            ", ".join(top_logits),
            "\nTOP_NEGATIVE_LOGITS:",
            ", ".join(bot_logits),
            "\nTOP_ACTIVATING_TEXT:",
            ", ".join(plain_examples),
        ]

        # If there are parent explanations, add them to the prompt
        if parent_explanations:
            sections.append("\nTOP_PARENT_EXPLANATIONS:")
            sections.append(", ".join(parent_explanations))

        # If a prompt is provided, include it in the messages
        if prompt:
            sections.append(f"\nGRAPH_PROMPT: \n{prompt}")

        sections.append("\n### OUTPUT:")
        sections.append("/no_think" if self.disable_thinking else "")

        # Create messages array with the system prompt
        return [
            {"role": "system", "content": SYSTEM_GRAPH},
            {
                "role": "user",
                "content": "\n".join(sections),
            },
        ]

    def call_sync(self, record):
        """Synchronous wrapper for the asynchronous __call__ method."""
        return asyncio.run(self.__call__(record))

    def _log_prompt(self, prompt, feature, output):
        """Append one JSON line with the feature, prompt and output to
        ``prompt_log.jsonl`` in the current working directory."""
        log_entry = {
            "feature": feature,
            "prompt": prompt,
            "output": output,
        }
        with open("prompt_log.jsonl", "a") as f:
            f.write(json.dumps(log_entry) + "\n")
45 changes: 45 additions & 0 deletions delphi/latents/constructors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import hashlib
import json
import os
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -235,6 +236,7 @@ def constructor(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
all_data: Optional[dict[int, ActivationData]] = None,
seed: int = 42,
logits_directory: Optional[os.PathLike] = None,
) -> LatentRecord | None:
cache_ctx_len = tokens.shape[1]
example_ctx_len = constructor_cfg.example_ctx_len
Expand All @@ -260,6 +262,9 @@ def constructor(
# per context frequency
record.per_context_frequency = len(unique_batch_pos) / n_windows

# add top/bottom logits if available
if logits_directory:
record = load_graph_info(record, logits_directory)
# Add activation examples to the record in place
if constructor_cfg.center_examples:
token_windows, act_windows = pool_max_activation_windows(
Expand Down Expand Up @@ -332,6 +337,46 @@ def constructor(
return record


def load_graph_info(record: LatentRecord, logits_directory: str) -> LatentRecord:
    """
    Load top and bottom logits from a file based on the latent module name. Also
    loads the latent's parent connections from the same file.

    The file is looked up as ``<logits_directory>/<id>.json``, where ``<id>`` is
    the Cantor pairing of the first integer found in the module name (taken to
    be the layer number) and the latent index.

    Args:
        record: The latent record to update in place.
        logits_directory: Directory containing the per-latent JSON files.

    Returns:
        The same record. ``top_logits``, ``bot_logits`` and ``parents`` are set
        only when the file exists; otherwise a warning is logged and the record
        is returned unchanged.
    """
    import re

    from delphi import logger

    # Parent connections whose second element is at most this large in
    # absolute value are dropped.
    min_abs_connection = 0.00001

    def cantor(num1, num2):
        # Cantor pairing function: maps a pair of non-negative integers to a
        # single unique integer.
        return (num1 + num2) * (num1 + num2 + 1) // 2 + num2

    match = re.search(r"\d+", record.latent.module_name)
    if match is None:
        logger.warning(
            "Module name does not include layer number. Failed to load logits"
        )
        # Return right away: without a layer there is no file to look for, and
        # a second "file not found" warning would only be misleading.
        return record

    layer = int(match.group(0))
    file_id = cantor(layer, record.latent.latent_index)
    logits_file = f"{logits_directory}/{file_id}.json"

    if not os.path.exists(logits_file):
        logger.warning("Could not find graph info file. Failed to load logits/parents")
        return record

    with open(logits_file, "r") as file:
        data = json.load(file)

    record.top_logits = data.get("top_logits", [])
    record.bot_logits = data.get("bottom_logits", [])
    record.parents = [
        (f, i)
        for (f, i) in data.get("parent_connections", [])
        if abs(i) > min_abs_connection
    ]
    return record


def create_token_key(tokens_tensor, ctx_len):
"""
Create a file key based on token tensors without detokenization.
Expand Down
Loading