Skip to content

Commit 23b7e2e

Browse files
committed
create experiment notebook and refactoring
1 parent c82529c commit 23b7e2e

File tree

3 files changed

+129
-34
lines changed

3 files changed

+129
-34
lines changed

labml_nn/transformers/LoRA/GPT2.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
}
1515

1616

17-
class HeadFFN(nn.Module): # todo rename
17+
class FFN(nn.Module):
1818
def __init__(self, dim):
1919
super().__init__()
2020
self.c_fc = nn.Linear(config['n_embd'], dim)
@@ -28,7 +28,7 @@ def forward(self, hidden_states):
2828
return hidden_states
2929

3030

31-
class MultiHead(nn.Module):
31+
class MultiHeadAttention(nn.Module):
3232
def __init__(self):
3333
super().__init__()
3434
self.embed_dim = config['n_embd']
@@ -65,7 +65,6 @@ def forward(self, hidden_states):
6565
is_causal=True, # for the triangular mask
6666
)
6767

68-
# todo why this?
6968
attn_output = attn_output.transpose(1, 2).contiguous()
7069
attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)
7170

@@ -78,9 +77,9 @@ class Block(nn.Module):
7877
def __init__(self):
7978
super().__init__()
8079
self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
81-
self.attn = MultiHead()
80+
self.attn = MultiHeadAttention()
8281
self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
83-
self.ffn = HeadFFN(config['n_embd'] * 4)
82+
self.ffn = FFN(config['n_embd'] * 4)
8483

8584
def forward(self, hidden_states):
8685
residual = hidden_states
@@ -98,7 +97,6 @@ def forward(self, hidden_states):
9897

9998

10099
class GPTModel(nn.Module):
101-
# todo ignored token type embeds, past key values
102100
def __init__(self):
103101
super().__init__()
104102

@@ -128,31 +126,3 @@ def forward(self, input_ids):
128126
logits = self.lm_head(hidden_states)
129127

130128
return logits
131-
132-
133-
model = GPTModel()
134-
135-
state_dict = torch.load('transformed.pth')
136-
137-
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
138-
if missing_keys:
139-
print(f"Missing keys: {missing_keys}")
140-
if unexpected_keys:
141-
print(f"Unexpected keys: {unexpected_keys}")
142-
143-
prompt = "hello how are you"
144-
tokenized = tokenizer(prompt, return_tensors="pt")
145-
146-
with torch.no_grad():
147-
model.eval()
148-
res = model(tokenized['input_ids'])
149-
150-
print(res)
151-
152-
output_ids = torch.argmax(res, dim=-1)
153-
154-
# Decode the token indices back to text
155-
output_text = tokenizer.decode(output_ids[0])
156-
157-
# Print the tokens of the output
158-
print(output_text)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
{
2+
"cells": [
3+
{
4+
"metadata": {
5+
"ExecuteTime": {
6+
"end_time": "2024-07-29T07:14:27.781097Z",
7+
"start_time": "2024-07-29T07:14:24.819976Z"
8+
}
9+
},
10+
"cell_type": "code",
11+
"source": [
12+
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
13+
"import torch"
14+
],
15+
"id": "cffa3ec341b4905a",
16+
"outputs": [],
17+
"execution_count": 1
18+
},
19+
{
20+
"metadata": {
21+
"ExecuteTime": {
22+
"end_time": "2024-07-29T07:14:28.183960Z",
23+
"start_time": "2024-07-29T07:14:27.782683Z"
24+
}
25+
},
26+
"cell_type": "code",
27+
"source": [
28+
"from transformers import AutoTokenizer\n",
29+
"\n",
30+
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
31+
],
32+
"id": "c2b0b7e18394ea9e",
33+
"outputs": [],
34+
"execution_count": 2
35+
},
36+
{
37+
"cell_type": "code",
38+
"id": "initial_id",
39+
"metadata": {
40+
"collapsed": true,
41+
"ExecuteTime": {
42+
"end_time": "2024-07-29T07:14:29.840925Z",
43+
"start_time": "2024-07-29T07:14:28.185080Z"
44+
}
45+
},
46+
"source": [
47+
"model = GPTModel()\n",
48+
"\n",
49+
"state_dict = torch.load('transformed.pth')\n",
50+
"\n",
51+
"missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
52+
"if missing_keys:\n",
53+
" print(f\"Missing keys: {missing_keys}\")\n",
54+
"if unexpected_keys:\n",
55+
" print(f\"Unexpected keys: {unexpected_keys}\")"
56+
],
57+
"outputs": [],
58+
"execution_count": 3
59+
},
60+
{
61+
"metadata": {
62+
"ExecuteTime": {
63+
"end_time": "2024-07-29T07:22:30.408855Z",
64+
"start_time": "2024-07-29T07:22:30.168376Z"
65+
}
66+
},
67+
"cell_type": "code",
68+
"source": [
69+
"prompt = \"hello how are you\"\n",
70+
"tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
71+
"\n",
72+
"with torch.no_grad():\n",
73+
" model.eval()\n",
74+
" res = model(tokenized['input_ids'])\n",
75+
"\n",
76+
"output_ids = torch.argmax(res, dim=-1)\n",
77+
"for id in output_ids[0]:\n",
78+
" print(tokenizer.decode(id))"
79+
],
80+
"id": "f4f7826ec3729b66",
81+
"outputs": [
82+
{
83+
"name": "stdout",
84+
"output_type": "stream",
85+
"text": [
86+
",\n",
87+
" to\n",
88+
" you\n",
89+
" doing\n"
90+
]
91+
}
92+
],
93+
"execution_count": 17
94+
},
95+
{
96+
"metadata": {},
97+
"cell_type": "code",
98+
"outputs": [],
99+
"execution_count": null,
100+
"source": "",
101+
"id": "c12776360008a974"
102+
}
103+
],
104+
"metadata": {
105+
"kernelspec": {
106+
"display_name": "Python (ml)",
107+
"language": "python",
108+
"name": "ml"
109+
},
110+
"language_info": {
111+
"codemirror_mode": {
112+
"name": "ipython",
113+
"version": 2
114+
},
115+
"file_extension": ".py",
116+
"mimetype": "text/x-python",
117+
"name": "python",
118+
"nbconvert_exporter": "python",
119+
"pygments_lexer": "ipython2",
120+
"version": "2.7.6"
121+
}
122+
},
123+
"nbformat": 4,
124+
"nbformat_minor": 5
125+
}

0 commit comments

Comments
 (0)