From b1074b5777bd1de1e002cd894bf70e959349e166 Mon Sep 17 00:00:00 2001 From: Julian Assmann Date: Mon, 14 Jul 2025 12:05:26 +0200 Subject: [PATCH 1/4] Add optional interactive token visualization to function `attention_heads` --- python/Demonstration.ipynb | 14011 +++++++++++++++++++---- python/circuitsvis/attention.py | 29 + react/src/attention/AttentionHeads.tsx | 75 +- 3 files changed, 12042 insertions(+), 2073 deletions(-) diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb index 52ac7c12..a39d648d 100644 --- a/python/Demonstration.ipynb +++ b/python/Demonstration.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# CircuitsVis Demonstration" + "# CircuitsVis Demonstration\n" ] }, { @@ -15,7 +15,7 @@ "source": [ "## Setup/Imports\n", "\n", - "__Note:__ To run Jupyter directly within this repo, you may need to run `poetry run pip install jupyter`." + "**Note:** To run Jupyter directly within this repo, you may need to run `poetry run pip install jupyter`.\n" ] }, { @@ -27,8 +27,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "The autoreload extension is already loaded. 
To reload it, use:\n", - " %reload_ext autoreload\n" + "In dev mode: True\n" ] } ], @@ -39,96 +38,9984 @@ "\n", "# Imports\n", "import numpy as np\n", - "from circuitsvis.attention import attention_patterns, attention_pattern\n", + "from circuitsvis.attention import attention_patterns, attention_pattern, attention_heads\n", "from circuitsvis.activations import text_neuron_activations\n", "from circuitsvis.examples import hello\n", "from circuitsvis.tokens import colored_tokens\n", "from circuitsvis.topk_tokens import topk_tokens\n", - "from circuitsvis.topk_samples import topk_samples" + "from circuitsvis.topk_samples import topk_samples\n", + "\n", + "from circuitsvis.utils.render import is_in_dev_mode\n", + "print(\"In dev mode:\", is_in_dev_mode())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Built In Visualizations\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Activations\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Text Neuron Activations (single sample)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n", + "n_layers = 3\n", + "n_neurons_per_layer = 4\n", + "activations = np.random.normal(size=(len(tokens), n_layers, n_neurons_per_layer))\n", + "activations = np.exp(activations) / np.exp(activations).sum(axis=0, keepdims=True)\n", + "text_neuron_activations(tokens=tokens, activations=activations)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Text Neuron Activations (multiple samples)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens = [['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example'], ['This', ' is', ' another', ' example', ' of', ' colored', ' tokens'], ['And', ' here', ' another', ' example', ' of', ' colored', ' tokens', ' with', ' more', ' words.'], ['This', ' is', ' another', ' example', ' of', ' tokens.']]\n", + "n_layers = 3\n", + "n_neurons_per_layer = 4\n", + "activations = []\n", + "for sample in tokens:\n", + " sample_activations = np.random.normal(size=(len(sample), n_layers, n_neurons_per_layer)) * 5\n", + " activations.append(sample_activations)\n", + "text_neuron_activations(tokens=tokens, activations=activations)" ] }, { "cell_type": "markdown", - "metadata": { - "tags": [] - }, + "metadata": {}, "source": [ - "## Built In Visualizations" + "### Attention\n" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 25, "metadata": {}, + "outputs": [], "source": [ - "### Activations" + "def normalize_attention(attention):\n", + " \"\"\"Apply causal mask and normalize attention to realistic values\"\"\"\n", + " # Apply causal mask (lower triangular)\n", + " attention = np.tril(attention)\n", + "\n", + " # Add small amount of realistic noise\n", + " noise = np.random.normal(0, 0.02, attention.shape)\n", + " attention += noise\n", + " attention = np.maximum(attention, 0) # Ensure non-negative\n", + "\n", + " # Normalize rows to sum to 1 (like softmax)\n", + " row_sums = attention.sum(axis=1, keepdims=True)\n", + " row_sums[row_sums == 0] = 1 # Avoid division by zero\n", + " attention = attention / row_sums\n", + "\n", + " return attention\n", + "\n", + "def create_previous_token_head(n_tokens):\n", + " \"\"\"Previous token head: attends to the immediately previous token\"\"\"\n", + " attention = np.zeros((n_tokens, n_tokens))\n", + " for i in range(1, n_tokens):\n", + " 
attention[i, i-1] = 0.8 # Strong attention to previous token\n", + " if i >= 2:\n", + " attention[i, i-2] = 0.2 # Weak attention to token before that\n", + " return normalize_attention(attention)\n", + "\n", + "def create_first_token_head(n_tokens):\n", + " \"\"\"First token head: always attends to the first token (BOS-like behavior)\"\"\"\n", + " attention = np.zeros((n_tokens, n_tokens))\n", + " attention[:, 0] = 1.0 # All positions attend to first token\n", + " return attention\n", + "\n", + "def create_induction_head(n_tokens):\n", + " \"\"\"Simulated induction head: attends to tokens that came after similar contexts\"\"\"\n", + " attention = np.zeros((n_tokens, n_tokens))\n", + " # Simulate pattern where positions attend to what came after previous similar contexts\n", + " for i in range(2, n_tokens):\n", + " if i >= 3:\n", + " # Create induction-like pattern: attend to tokens 2-3 positions back\n", + " attention[i, max(0, i-3):i-1] = np.linspace(0.3, 0.7, min(2, i-1))\n", + " elif i == 2:\n", + " attention[i, 0] = 1.0\n", + " return normalize_attention(attention)\n", + "\n", + "def create_local_attention_head(n_tokens, window=2):\n", + " \"\"\"Local attention head: attends to nearby tokens within a window\"\"\"\n", + " attention = np.zeros((n_tokens, n_tokens))\n", + " for i in range(n_tokens):\n", + " start = max(0, i - window)\n", + " end = min(n_tokens, i + 1) # Causal: can't attend to future\n", + " # Uniform attention within the window\n", + " attention[i, start:end] = 1.0 / (end - start)\n", + " return normalize_attention(attention)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "#### Text Neuron Activations (single sample)" + "#### Attention Heads\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 2, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n", - "n_layers = 3\n", - "n_neurons_per_layer = 4\n", - "activations = np.random.normal(size=(len(tokens), n_layers, n_neurons_per_layer))\n", - "activations = np.exp(activations) / np.exp(activations).sum(axis=0, keepdims=True) \n", - "text_neuron_activations(tokens=tokens, activations=activations)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Text Neuron Activations (multiple samples)" + "n_tokens = len(tokens)\n", + "single_head_attention = create_previous_token_head(n_tokens)\n", + "attention_heads(tokens=tokens, attention=single_head_attention, attention_head_names=['Previous Token Head'], show_tokens=True)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tokens = [['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example'], ['This', ' is', ' another', ' example', ' of', ' colored', ' tokens'], ['And', ' here', ' another', ' example', ' of', ' colored', ' tokens', ' with', ' more', ' words.'], ['This', ' is', ' another', ' example', ' of', ' tokens.']]\n", - "n_layers = 3\n", - "n_neurons_per_layer = 4\n", - "activations = []\n", - "for sample in tokens:\n", - " sample_activations = np.random.normal(size=(len(sample), n_layers, n_neurons_per_layer)) * 5\n", - " activations.append(sample_activations)\n", - "text_neuron_activations(tokens=tokens, activations=activations)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Attention" + "heads_info = [\n", + " (\"Previous Token\", create_previous_token_head(n_tokens)),\n", + " (\"Local Context (n=4)\", create_local_attention_head(n_tokens, window=4)),\n", + " (\"First Token\", create_first_token_head(n_tokens)),\n", + " (\"Induction\", create_induction_head(n_tokens)),\n", + " (\"Local Context (n=2)\", create_local_attention_head(n_tokens, window=2)),\n", + "]\n", + "\n", + "multi_head_attention = np.stack([head[1] for head in heads_info], axis=0)\n", + "head_names = [head[0] for head in heads_info]\n", + "\n", + "attention_heads(tokens=tokens, attention=multi_head_attention, attention_head_names=head_names, show_tokens=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Attention Pattern (single head)" + "#### Attention Pattern (single head) [deprecated]\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 4, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -14797,67 +24670,67 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### Attention Patterns" + "#### Attention Patterns [deprecated]\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -19701,74 +29574,74 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Tokens" + "### Tokens\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Colored Tokens" + "#### Colored Tokens\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 6, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -24613,67 +34486,67 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Topk Tokens Table" + "### Topk Tokens Table\n" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -29527,67 +39400,67 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Topk Samples" + "### Topk Samples\n" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -34437,25 +44310,24 @@ "activations = []\n", "for neuron in range(len(tokens)):\n", " neuron_acts = []\n", - " \n", + "\n", " for k in range(len(tokens[0])):\n", " acts = (np.random.normal(size=(len(tokens[neuron][k]))) * 5).tolist()\n", " neuron_acts.append(acts)\n", " activations.append(neuron_acts)\n", - " \n", + "\n", "# Assume we have an arbitrary selection of neurons\n", "neuron_labels = [2, 7, 9]\n", "# Wrap tokens and activations in an outer list to represent the single layer\n", - "topk_samples(tokens=[tokens], activations=[activations], zeroth_dimension_name=\"Layer\", first_dimension_name=\"Neuron\", first_dimension_labels=neuron_labels)\n", - "\n" + "topk_samples(tokens=[tokens], activations=[activations], zeroth_dimension_name=\"Layer\", first_dimension_name=\"Neuron\", first_dimension_labels=neuron_labels)" ] } ], "metadata": { "kernelspec": { - "display_name": "circuitsvis-env", + "display_name": ".venv", "language": "python", - "name": "circuitsvis-env" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -34467,12 +44339,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" - }, - "vscode": { - "interpreter": { - "hash": "ada5ea967828749ea6c7f5c93ea14cd73d82db7939f837b7070fa8806da132ee" - } + "version": "3.12.3" } }, "nbformat": 4, diff --git a/python/circuitsvis/attention.py b/python/circuitsvis/attention.py index 3da78d4e..ed26b36e 100644 --- a/python/circuitsvis/attention.py +++ b/python/circuitsvis/attention.py @@ -3,6 +3,7 @@ import numpy as np import torch + from circuitsvis.utils.render import RenderedHTML, render @@ -15,6 +16,7 @@ def attention_heads( negative_color: Optional[str] = None, positive_color: Optional[str] = None, mask_upper_tri: Optional[bool] = None, + show_tokens: Optional[bool] = True, ) -> RenderedHTML: """Attention 
Heads @@ -41,10 +43,36 @@ def attention_heads( mask_upper_tri: Whether or not to mask the upper triangular portion of the attention patterns. Should be true for causal attention, false for bidirectional attention. + attention: Attention head activations of the shape [heads x dest_tokens x src_tokens] + or [dest_tokens x src_tokens] (will be expanded to single head) Returns: Html: Attention pattern visualization """ + + # Convert attention to numpy array + if isinstance(attention, torch.Tensor): + attention = attention.detach().cpu().numpy() + elif not isinstance(attention, np.ndarray): + attention = np.array(attention) + + # Ensure attention is 3D (num_heads, dest_len, src_len) + if attention.ndim == 2: + attention = attention[np.newaxis, :, :] + elif attention.ndim != 3: + raise ValueError( + f"Attention tensor must be 2D or 3D, got {attention.ndim}D tensor." + ) + + num_heads, dest_len, src_len = attention.shape + + # Validate token count matches attention dimensions + if len(tokens) != dest_len or len(tokens) != src_len: + raise ValueError( + f"Token count ({len(tokens)}) doesn't match attention dimensions " + f"(dest: {dest_len}, src: {src_len}). For causal attention, these should all be equal." 
+ ) + kwargs = { "attention": attention, "attentionHeadNames": attention_head_names, @@ -54,6 +82,7 @@ def attention_heads( "positiveColor": positive_color, "tokens": tokens, "maskUpperTri": mask_upper_tri, + "showTokens": show_tokens, } return render( diff --git a/react/src/attention/AttentionHeads.tsx b/react/src/attention/AttentionHeads.tsx index 67251047..f9a0df2a 100644 --- a/react/src/attention/AttentionHeads.tsx +++ b/react/src/attention/AttentionHeads.tsx @@ -1,6 +1,8 @@ -import React from "react"; +import React, { useMemo, useState } from "react"; import { Col, Container, Row } from "react-grid-system"; import { AttentionPattern } from "./AttentionPattern"; +import { colorAttentionTensors } from "./AttentionPatterns"; +import { Tokens, TokensView } from "./components/AttentionTokens"; import { useHoverLock, UseHoverLockState } from "./components/useHoverLock"; /** @@ -115,14 +117,40 @@ export function AttentionHeads({ negativeColor, positiveColor, maskUpperTri = true, + showTokens = true, tokens }: AttentionHeadsProps) { // Attention head focussed state const { focused, onClick, onMouseEnter, onMouseLeave } = useHoverLock(0); + // State for the token view type + const [tokensView, setTokensView] = useState( + TokensView.DESTINATION_TO_SOURCE + ); + + // State for which token is focussed + const { + focused: focussedToken, + onClick: onClickToken, + onMouseEnter: onMouseEnterToken, + onMouseLeave: onMouseLeaveToken + } = useHoverLock(); + const headNames = attentionHeadNames || attention.map((_, idx) => `Head ${idx}`); + // Color the attention values (by head) for interactive tokens + const coloredAttention = useMemo(() => { + if (!showTokens || !attention || attention.length === 0) return null; + const numHeads = attention.length; + const numDestTokens = attention[0]?.length || 0; + const numSrcTokens = attention[0]?.[0]?.length || 0; + + if (numDestTokens === 0 || numSrcTokens === 0 || numHeads === 0) + return null; + return 
colorAttentionTensors(attention); + }, [attention, showTokens]); + return (

@@ -176,6 +204,42 @@ export function AttentionHeads({ + {showTokens && coloredAttention && ( + + +
+

+ Tokens + (click to focus) +

+ +
+ +
+
+ +
+ )} + ); @@ -262,6 +326,15 @@ export interface AttentionHeadsProps { */ showAxisLabels?: boolean; + /** + * Show interactive tokens + * + * Whether to show interactive token visualization where hovering over tokens shows attention strength to other tokens. + * + * @default true + */ + showTokens?: boolean; + /** * List of tokens * From 594560b1b083daeb3d7681099c5236228355efd6 Mon Sep 17 00:00:00 2001 From: Julian Assmann Date: Mon, 14 Jul 2025 12:22:47 +0200 Subject: [PATCH 2/4] Updated attention_heads function to include match_color parameter for visual consistency in attention patterns and token visualization --- python/Demonstration.ipynb | 2455 ++++++++++++------------ python/circuitsvis/attention.py | 14 +- react/src/attention/AttentionHeads.tsx | 28 +- 3 files changed, 1261 insertions(+), 1236 deletions(-) diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb index a39d648d..b3a49bc7 100644 --- a/python/Demonstration.ipynb +++ b/python/Demonstration.ipynb @@ -20,13 +20,15 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n", "In dev mode: True\n" ] } @@ -40,7 +42,6 @@ "import numpy as np\n", "from circuitsvis.attention import attention_patterns, attention_pattern, attention_heads\n", "from circuitsvis.activations import text_neuron_activations\n", - "from circuitsvis.examples import hello\n", "from circuitsvis.tokens import colored_tokens\n", "from circuitsvis.topk_tokens import topk_tokens\n", "from circuitsvis.topk_samples import topk_samples\n", @@ -75,62 +76,62 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 2, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -4983,62 +4984,62 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -9892,7 +9893,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -9960,62 +9961,62 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 30, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -14858,62 +14859,62 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 28, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -19771,62 +19772,62 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 22, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -24675,62 +24676,62 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 23, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -29586,62 +29587,62 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 24, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -34491,62 +34492,62 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -39405,62 +39406,62 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 10, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } diff --git a/python/circuitsvis/attention.py b/python/circuitsvis/attention.py index ed26b36e..34b68d37 100644 --- a/python/circuitsvis/attention.py +++ b/python/circuitsvis/attention.py @@ -16,7 +16,8 @@ def attention_heads( negative_color: Optional[str] = None, positive_color: Optional[str] = None, mask_upper_tri: Optional[bool] = None, - show_tokens: Optional[bool] = True, + show_tokens: Optional[bool] = None, + match_color: Optional[bool] = None, ) -> RenderedHTML: """Attention Heads @@ -26,8 +27,8 @@ def attention_heads( is then shown in full size. Args: - attention: Attention head activations of the shape [dest_tokens x - src_tokens] + attention: Attention head activations of the shape [heads x dest_tokens x src_tokens] + or [dest_tokens x src_tokens] (will be expanded to single head) tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length as the list of values. max_value: Maximum value. Used to determine how dark the token color is @@ -43,8 +44,10 @@ def attention_heads( mask_upper_tri: Whether or not to mask the upper triangular portion of the attention patterns. Should be true for causal attention, false for bidirectional attention. - attention: Attention head activations of the shape [heads x dest_tokens x src_tokens] - or [dest_tokens x src_tokens] (will be expanded to single head) + show_tokens: Whether to show interactive token visualization where + hovering over tokens shows attention strength to other tokens. + match_color: Whether to match colors between attention patterns, token + visualization, and head headers for visual consistency. 
Returns: Html: Attention pattern visualization @@ -83,6 +86,7 @@ def attention_heads( "tokens": tokens, "maskUpperTri": mask_upper_tri, "showTokens": show_tokens, + "matchColor": match_color, } return render( diff --git a/react/src/attention/AttentionHeads.tsx b/react/src/attention/AttentionHeads.tsx index f9a0df2a..abc43893 100644 --- a/react/src/attention/AttentionHeads.tsx +++ b/react/src/attention/AttentionHeads.tsx @@ -37,6 +37,7 @@ export function AttentionHeadsSelector({ onMouseLeave, positiveColor, maskUpperTri, + matchColor, tokens }: AttentionHeadsProps & { attentionHeadNames: string[]; @@ -90,8 +91,12 @@ export function AttentionHeadsSelector({ showAxisLabels={false} maxValue={maxValue} minValue={minValue} - negativeColor={negativeColor} - positiveColor={positiveColor} + negativeColor={matchColor ? undefined : negativeColor} + positiveColor={ + matchColor + ? attentionHeadColor(idx, attention.length) + : positiveColor + } maskUpperTri={maskUpperTri} />
@@ -118,6 +123,7 @@ export function AttentionHeads({ positiveColor, maskUpperTri = true, showTokens = true, + matchColor = false, tokens }: AttentionHeadsProps) { // Attention head focussed state @@ -169,6 +175,7 @@ export function AttentionHeads({ onMouseLeave={onMouseLeave} positiveColor={positiveColor} maskUpperTri={maskUpperTri} + matchColor={matchColor} tokens={tokens} /> @@ -194,8 +201,12 @@ export function AttentionHeads({ attention={attention[focused]} maxValue={maxValue} minValue={minValue} - negativeColor={negativeColor} - positiveColor={positiveColor} + negativeColor={matchColor ? undefined : negativeColor} + positiveColor={ + matchColor + ? attentionHeadColor(focused, attention.length) + : positiveColor + } zoomed={true} maskUpperTri={maskUpperTri} tokens={tokens} @@ -335,6 +346,15 @@ export interface AttentionHeadsProps { */ showTokens?: boolean; + /** + * Match colors + * + * Whether to match colors between attention patterns, token visualization, and head headers for visual consistency. + * + * @default false + */ + matchColor?: boolean; + /** * List of tokens * From 3d5c5927222e9d6868cc83962e1fe61b47e0a73e Mon Sep 17 00:00:00 2001 From: Julian Assmann Date: Mon, 14 Jul 2025 12:30:16 +0200 Subject: [PATCH 3/4] Update Demonstration.ipynb to showcase color matching --- python/Demonstration.ipynb | 213 +++++++++++++------------------------ 1 file changed, 72 insertions(+), 141 deletions(-) diff --git a/python/Demonstration.ipynb b/python/Demonstration.ipynb index b3a49bc7..84e9cd18 100644 --- a/python/Demonstration.ipynb +++ b/python/Demonstration.ipynb @@ -20,15 +20,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The autoreload extension is already loaded. 
To reload it, use:\n", - " %reload_ext autoreload\n", "In dev mode: True\n" ] } @@ -76,13 +74,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 13, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -4984,13 +4982,13 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 14, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -9891,67 +9889,6 @@ "### Attention\n" ] }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "def normalize_attention(attention):\n", - " \"\"\"Apply causal mask and normalize attention to realistic values\"\"\"\n", - " # Apply causal mask (lower triangular)\n", - " attention = np.tril(attention)\n", - "\n", - " # Add small amount of realistic noise\n", - " noise = np.random.normal(0, 0.02, attention.shape)\n", - " attention += noise\n", - " attention = np.maximum(attention, 0) # Ensure non-negative\n", - "\n", - " # Normalize rows to sum to 1 (like softmax)\n", - " row_sums = attention.sum(axis=1, keepdims=True)\n", - " row_sums[row_sums == 0] = 1 # Avoid division by zero\n", - " attention = attention / row_sums\n", - "\n", - " return attention\n", - "\n", - "def create_previous_token_head(n_tokens):\n", - " \"\"\"Previous token head: attends to the immediately previous token\"\"\"\n", - " attention = np.zeros((n_tokens, n_tokens))\n", - " for i in range(1, n_tokens):\n", - " attention[i, i-1] = 0.8 # Strong attention to previous token\n", - " if i >= 2:\n", - " attention[i, i-2] = 0.2 # Weak attention to token before that\n", - " return normalize_attention(attention)\n", - "\n", - "def create_first_token_head(n_tokens):\n", - " \"\"\"First token head: always attends to the first token (BOS-like behavior)\"\"\"\n", - " attention = np.zeros((n_tokens, n_tokens))\n", - " attention[:, 0] = 1.0 # All positions attend to first token\n", - " return attention\n", - "\n", - "def create_induction_head(n_tokens):\n", - " \"\"\"Simulated induction head: attends to tokens that came after similar contexts\"\"\"\n", - " attention = np.zeros((n_tokens, n_tokens))\n", - " # Simulate pattern where positions attend to what came after previous similar contexts\n", - " for i in range(2, 
n_tokens):\n", - " if i >= 3:\n", - " # Create induction-like pattern: attend to tokens 2-3 positions back\n", - " attention[i, max(0, i-3):i-1] = np.linspace(0.3, 0.7, min(2, i-1))\n", - " elif i == 2:\n", - " attention[i, 0] = 1.0\n", - " return normalize_attention(attention)\n", - "\n", - "def create_local_attention_head(n_tokens, window=2):\n", - " \"\"\"Local attention head: attends to nearby tokens within a window\"\"\"\n", - " attention = np.zeros((n_tokens, n_tokens))\n", - " for i in range(n_tokens):\n", - " start = max(0, i - window)\n", - " end = min(n_tokens, i + 1) # Causal: can't attend to future\n", - " # Uniform attention within the window\n", - " attention[i, start:end] = 1.0 / (end - start)\n", - " return normalize_attention(attention)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -9961,13 +9898,13 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 16, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -14853,19 +14790,19 @@ "source": [ "tokens = ['Hi', ' and', ' welcome', ' to', ' the', ' Attention', ' Patterns', ' example']\n", "n_tokens = len(tokens)\n", - "single_head_attention = create_previous_token_head(n_tokens)\n", - "attention_heads(tokens=tokens, attention=single_head_attention, attention_head_names=['Previous Token Head'], show_tokens=True)" + "single_head_attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(n_tokens, n_tokens)))\n", + "attention_heads(tokens=tokens, attention=single_head_attention, show_tokens=True)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 17, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "heads_info = [\n", - " (\"Previous Token\", create_previous_token_head(n_tokens)),\n", - " (\"Local Context (n=4)\", create_local_attention_head(n_tokens, window=4)),\n", - " (\"First Token\", create_first_token_head(n_tokens)),\n", - " (\"Induction\", create_induction_head(n_tokens)),\n", - " (\"Local Context (n=2)\", create_local_attention_head(n_tokens, window=2)),\n", - "]\n", - "\n", - "multi_head_attention = np.stack([head[1] for head in heads_info], axis=0)\n", - "head_names = [head[0] for head in heads_info]\n", + "n_layers = 3\n", + "n_heads = 4\n", "\n", - "attention_heads(tokens=tokens, attention=multi_head_attention, attention_head_names=head_names, show_tokens=True)" + "multi_head_attention = np.tril(np.random.normal(loc=0.3, scale=0.2, size=(n_layers * n_heads, n_tokens, n_tokens)))\n", + "head_names = [f\"L{layer}H{head}\" for layer in range(n_layers) for head in range(n_heads)]\n", + "attention_heads(tokens=tokens, attention=multi_head_attention, attention_head_names=head_names, show_tokens=True, match_color=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Attention Pattern (single head) [deprecated]\n" + "#### Attention Pattern (single head)\n" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 18, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -24676,13 +24607,13 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 19, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -29587,13 +29518,13 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 20, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -34492,13 +34423,13 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 21, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -39406,13 +39337,13 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 22, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } From 6668e569dddcdfadfbc8218af43a7a57a1aeb9b3 Mon Sep 17 00:00:00 2001 From: Julian Assmann Date: Mon, 14 Jul 2025 12:52:14 +0200 Subject: [PATCH 4/4] Fix mypy error --- python/circuitsvis/attention.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/python/circuitsvis/attention.py b/python/circuitsvis/attention.py index 34b68d37..317902b9 100644 --- a/python/circuitsvis/attention.py +++ b/python/circuitsvis/attention.py @@ -53,21 +53,24 @@ def attention_heads( Html: Attention pattern visualization """ - # Convert attention to numpy array + # Convert attention to numpy array if it's not already + attention_np: np.ndarray if isinstance(attention, torch.Tensor): - attention = attention.detach().cpu().numpy() - elif not isinstance(attention, np.ndarray): - attention = np.array(attention) + attention_np = attention.detach().cpu().numpy() + elif isinstance(attention, np.ndarray): + attention_np = attention + else: + attention_np = np.array(attention) # Ensure attention is 3D (num_heads, dest_len, src_len) - if attention.ndim == 2: - attention = attention[np.newaxis, :, :] - elif attention.ndim != 3: + if attention_np.ndim == 2: + attention_np = attention_np[np.newaxis, :, :] + elif attention_np.ndim != 3: raise ValueError( - f"Attention tensor must be 2D or 3D, got {attention.ndim}D tensor." + f"Attention tensor must be 2D or 3D, got {attention_np.ndim}D tensor." ) - num_heads, dest_len, src_len = attention.shape + num_heads, dest_len, src_len = attention_np.shape # Validate token count matches attention dimensions if len(tokens) != dest_len or len(tokens) != src_len: @@ -77,7 +80,7 @@ def attention_heads( ) kwargs = { - "attention": attention, + "attention": attention_np, "attentionHeadNames": attention_head_names, "maxValue": max_value, "minValue": min_value,