Skip to content

Commit 831eb33

Browse files
committed
collected activations are now returned as a dict
1 parent 9ee5d73 commit 831eb33

File tree

3 files changed

+48
-15
lines changed

3 files changed

+48
-15
lines changed

pyvene/models/intervenable_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,8 +1101,11 @@ def _wait_for_forward_with_serial_intervention(
11011101
unit_locations_base = unit_locations[group_key][1]
11021102

11031103
if activations_sources != None:
1104-
for key in keys:
1105-
self.activations[key] = activations_sources[key]
1104+
for passed_in_key, v in activations_sources.items():
1105+
assert (
1106+
passed_in_key in self.sorted_keys
1107+
), f"{passed_in_key} not in {self.sorted_keys}, {unit_locations}"
1108+
self.activations[passed_in_key] = torch.clone(v)
11061109
else:
11071110
keys_with_source = [
11081111
k for i, k in enumerate(keys) if unit_locations_source[i] != None

pyvene_101.ipynb

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,27 @@
126126
},
127127
{
128128
"cell_type": "code",
129-
"execution_count": 2,
129+
"execution_count": 1,
130130
"id": "17c7f2f6-b0d3-4fe2-8e4f-c044b93f3ef0",
131-
"metadata": {},
132-
"outputs": [],
131+
"metadata": {
132+
"metadata": {}
133+
},
134+
"outputs": [
135+
{
136+
"data": {
137+
"application/vnd.jupyter.widget-view+json": {
138+
"model_id": "fce745d6f2ca453b98f7b10868b1ab7d",
139+
"version_major": 2,
140+
"version_minor": 0
141+
},
142+
"text/plain": [
143+
"generation_config.json: 0%| | 0.00/124 [00:00<?, ?B/s]"
144+
]
145+
},
146+
"metadata": {},
147+
"output_type": "display_data"
148+
}
149+
],
133150
"source": [
134151
"import pyvene as pv\n",
135152
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
@@ -144,10 +161,10 @@
144161
" \"intervention_type\": pv.CollectIntervention}, model=gpt2)\n",
145162
"\n",
146163
"base = \"When John and Mary went to the shops, Mary gave the bag to\"\n",
147-
"collected_attn_w = pv_gpt2(\n",
164+
"(_, collected_attn_w), _ = pv_gpt2(\n",
148165
" base = tokenizer(base, return_tensors=\"pt\"\n",
149166
" ), unit_locations={\"base\": [h for h in range(12)]}\n",
150-
")[0][-1][0]"
167+
")"
151168
]
152169
},
153170
{
@@ -160,7 +177,7 @@
160177
},
161178
{
162179
"cell_type": "code",
163-
"execution_count": 5,
180+
"execution_count": 4,
164181
"id": "128be2dd-f089-4291-bfc5-7002d031b1e9",
165182
"metadata": {
166183
"metadata": {}
@@ -171,7 +188,7 @@
171188
"output_type": "stream",
172189
"text": [
173190
"loaded GPT2 model gpt2\n",
174-
"torch.Size([12, 14, 14])\n"
191+
"torch.Size([1, 12, 14, 14])\n"
175192
]
176193
}
177194
],
@@ -193,6 +210,7 @@
193210
" base = tokenizer(base, return_tensors=\"pt\"\n",
194211
" ), unit_locations={\"base\": [h for h in range(12)]}\n",
195212
")\n",
213+
"collected_attn_w = torch.stack(list(collected_attn_w.values()))\n",
196214
"print(collected_attn_w[0].shape)"
197215
]
198216
},
@@ -206,16 +224,28 @@
206224
},
207225
{
208226
"cell_type": "code",
209-
"execution_count": 22,
227+
"execution_count": 5,
210228
"id": "678dc46f",
211-
"metadata": {},
229+
"metadata": {
230+
"metadata": {}
231+
},
212232
"outputs": [
213233
{
214234
"name": "stdout",
215235
"output_type": "stream",
216236
"text": [
217-
"loaded model\n"
237+
"loaded GPT2 model gpt2\n"
218238
]
239+
},
240+
{
241+
"data": {
242+
"text/plain": [
243+
"True"
244+
]
245+
},
246+
"execution_count": 5,
247+
"metadata": {},
248+
"output_type": "execute_result"
219249
}
220250
],
221251
"source": [

tests/integration_tests/IntervenableBasicTestCase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -591,10 +591,10 @@ def test_customized_intervention_function_get(self):
591591
)
592592

593593
base = "When John and Mary went to the shops, Mary gave the bag to"
594-
collected_attn_w = pv_gpt2(
594+
(_, collected_attn_w), _ = pv_gpt2(
595595
base=tokenizer(base, return_tensors="pt"),
596596
unit_locations={"base": [h for h in range(12)]},
597-
)[0][-1][0]
597+
)
598598

599599
cached_w = {}
600600

@@ -608,7 +608,7 @@ def pv_patcher(b, s):
608608

609609
base = "When John and Mary went to the shops, Mary gave the bag to"
610610
_ = pv_gpt2(tokenizer(base, return_tensors="pt"))
611-
torch.allclose(collected_attn_w, cached_w["attn_w"].unsqueeze(dim=0))
611+
torch.allclose(list(collected_attn_w.values())[0], cached_w["attn_w"].unsqueeze(dim=0))
612612

613613
def test_customized_intervention_function_zeroout(self):
614614

0 commit comments

Comments (0)