Skip to content

Commit 0f2a9be

Browse files
committed
training loop
1 parent 23b7e2e commit 0f2a9be

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"id": "initial_id",
6+
"metadata": {
7+
"collapsed": true
8+
},
9+
"source": "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
10+
"outputs": [],
11+
"execution_count": null
12+
},
13+
{
14+
"metadata": {},
15+
"cell_type": "code",
16+
"source": [
17+
"with open('input.txt', 'r', encoding='utf-8') as f:\n",
18+
" text = f.read()"
19+
],
20+
"id": "3b1e507015ba6b81",
21+
"outputs": [],
22+
"execution_count": null
23+
},
24+
{
25+
"metadata": {},
26+
"cell_type": "code",
27+
"source": [
28+
"from transformers import AutoTokenizer\n",
29+
"\n",
30+
"tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
31+
"\n",
32+
"tokens = tokenizer.encode(text, add_special_tokens=False)"
33+
],
34+
"id": "ac8e51ae5bbfcae7",
35+
"outputs": [],
36+
"execution_count": null
37+
},
38+
{
39+
"metadata": {},
40+
"cell_type": "code",
41+
"source": [
42+
"context_length = 10\n",
43+
"batch_size = 64"
44+
],
45+
"id": "aeefcdf813e427e",
46+
"outputs": [],
47+
"execution_count": null
48+
},
49+
{
50+
"metadata": {},
51+
"cell_type": "code",
52+
"source": [
53+
"num_batches = len(tokens) // (batch_size * context_length)\n",
54+
"tokens = tokens[:num_batches * batch_size * context_length]"
55+
],
56+
"id": "a384b42274f008a2",
57+
"outputs": [],
58+
"execution_count": null
59+
},
60+
{
61+
"metadata": {},
62+
"cell_type": "code",
63+
"source": [
64+
"import torch\n",
65+
"\n",
66+
"input_ids = torch.tensor(tokens).view(-1, context_length)"
67+
],
68+
"id": "5c4cc78ac1a02c1d",
69+
"outputs": [],
70+
"execution_count": null
71+
},
72+
{
73+
"metadata": {},
74+
"cell_type": "code",
75+
"source": [
76+
"from torch.utils.data import DataLoader, TensorDataset\n",
77+
"from torch.optim import Adam\n",
78+
"print(input_ids.shape)\n",
79+
"dataset = TensorDataset(input_ids)\n",
80+
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
81+
],
82+
"id": "7037fd75e2161382",
83+
"outputs": [],
84+
"execution_count": null
85+
},
86+
{
87+
"metadata": {},
88+
"cell_type": "code",
89+
"source": [
90+
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
91+
"\n",
92+
"model = GPTModel()"
93+
],
94+
"id": "a98b7baa064b8494",
95+
"outputs": [],
96+
"execution_count": null
97+
},
98+
{
99+
"metadata": {},
100+
"cell_type": "code",
101+
"source": [
102+
"optimizer = Adam(model.parameters(), lr=5e-5)\n",
103+
"criterion = torch.nn.CrossEntropyLoss()\n",
104+
"\n",
105+
"model.eval()\n",
106+
"epochs = 3\n",
107+
"for epoch in range(epochs):\n",
108+
" for batch in dataloader:\n",
109+
" inputs = batch[0]\n",
110+
" labels = inputs.clone()\n",
111+
" \n",
112+
" outputs = model(inputs)\n",
113+
" \n",
114+
" shift_logits = outputs[..., :-1, :]\n",
115+
" shift_labels = labels[..., 1:]\n",
116+
" \n",
117+
" loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
118+
" \n",
119+
" optimizer.zero_grad()\n",
120+
" loss.backward()\n",
121+
" optimizer.step()\n",
122+
"\n",
123+
" print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
124+
" break\n",
125+
"\n",
126+
"print(\"Training complete.\")"
127+
],
128+
"id": "e2f5076894770740",
129+
"outputs": [],
130+
"execution_count": null
131+
},
132+
{
133+
"metadata": {},
134+
"cell_type": "code",
135+
"source": "",
136+
"id": "da2d4023002648dc",
137+
"outputs": [],
138+
"execution_count": null
139+
}
140+
],
141+
"metadata": {
142+
"kernelspec": {
143+
"display_name": "Python (ml)",
144+
"language": "python",
145+
"name": "ml"
146+
},
147+
"language_info": {
148+
"codemirror_mode": {
149+
"name": "ipython",
150+
"version": 2
151+
},
152+
"file_extension": ".py",
153+
"mimetype": "text/x-python",
154+
"name": "python",
155+
"nbconvert_exporter": "python",
156+
"pygments_lexer": "ipython2",
157+
"version": "2.7.6"
158+
}
159+
},
160+
"nbformat": 4,
161+
"nbformat_minor": 5
162+
}

0 commit comments

Comments
 (0)