
Commit 957ade6

Merge pull request #266 from lakshith-403/LoRA
2 parents 89a3ae8 + bc32b50 commit 957ade6

File tree

5 files changed: +554 −0 lines


labml_nn/transformers/LoRA/GPT2.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from labml_nn.transformers.LoRA import Linear, Embedding

tokenizer = AutoTokenizer.from_pretrained("gpt2")

config = {
    "layer_norm_epsilon": 1e-05,
    "n_embd": 768,
    "n_head": 12,
    "n_layer": 12,
    "n_positions": 1024,
    "vocab_size": 50257,
    "device": "cuda"
}


class FFN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
        self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
        self.act = nn.functional.gelu

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_dim = config['n_embd']
        self.num_heads = config['n_head']
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
        self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, hidden_states):
        batch_size, seq_length, _ = hidden_states.size()

        query, key, value = self.c_att(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,  # for the triangular mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

        attn_output = self.c_proj(attn_output)

        return attn_output


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.attn = MultiHeadAttention()
        self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
        self.ffn = FFN(config['n_embd'] * 4)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.pre_norm(hidden_states)

        attn_output = self.attn(hidden_states)

        hidden_states = attn_output + residual
        residual = hidden_states
        hidden_states = self.post_norm(hidden_states)
        feed_forward_output = self.ffn(hidden_states)
        hidden_states = feed_forward_output + residual

        return hidden_states


class GPTModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
        self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)

        self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

        self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])

        self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.size()

        token_embeddings = self.token_embedding(input_ids)  # (B, T, C)
        position_ids = torch.arange(seq_length, device=config['device'])  # (T,)
        position_embeddings = self.position_embedding(position_ids)  # (T, C), broadcast over the batch

        hidden_states = token_embeddings + position_embeddings

        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.final_norm(hidden_states)

        logits = self.lm_head(hidden_states)

        return logits
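
A quick way to see what the LoRA layers above buy is to count parameters: the pretrained GPT-2 weight matrices inside Linear and Embedding are created with requires_grad = False, so fine-tuning only updates the low-rank factors lora_a / lora_b (plus the plain nn.LayerNorm parameters). A minimal sketch, not part of the commit:

import torch
from labml_nn.transformers.LoRA.GPT2 import GPTModel

# Count trainable vs. frozen parameters of the LoRA-wrapped GPT-2.
model = GPTModel()
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {trainable:,}  frozen: {frozen:,}")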
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn


class Linear(nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            r: int,
            alpha: int = None):
        if alpha is None:
            alpha = r
        super().__init__()
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:
            self.bias = None

        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
        self.lora_b = nn.Parameter(torch.empty((r, out_features)))

        with torch.no_grad():
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        result = nn.functional.linear(x, self.weight, bias=self.bias)

        result += (x @ self.lora_a @ self.lora_b) * self.scaling

        return result


class Embedding(nn.Module):
    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            r: int,
            alpha: int = None,
    ):
        if alpha is None:
            alpha = r
        super().__init__()

        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

        with torch.no_grad():
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        result = nn.functional.embedding(x, self.weight)
        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling

        return result
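
Since lora_b is initialized to zero, each layer initially computes exactly the frozen projection W x + b; training only moves the low-rank update (alpha / r) * x A B, where A is lora_a and B is lora_b. A minimal usage sketch of the Linear layer, assuming only this commit's package layout:

import torch
from labml_nn.transformers.LoRA import Linear

layer = Linear(in_features=768, out_features=768, bias=True, r=32)
with torch.no_grad():
    layer.weight.normal_()  # stand-in for a pretrained weight matrix
    layer.bias.zero_()

x = torch.randn(2, 10, 768)
y = layer(x)
print(y.shape)  # torch.Size([2, 10, 768])
print([n for n, p in layer.named_parameters() if p.requires_grad])  # ['lora_a', 'lora_b']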
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
{
  "cells": [
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
        "import torch"
      ],
      "id": "cffa3ec341b4905a",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
      ],
      "id": "c2b0b7e18394ea9e",
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "id": "initial_id",
      "metadata": {
        "collapsed": true
      },
      "source": [
        "model = GPTModel()\n",
        "\n",
        "state_dict = torch.load('transformed.pth')\n",
        "\n",
        "missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
        "if missing_keys:\n",
        "    print(f\"Missing keys: {missing_keys}\")\n",
        "if unexpected_keys:\n",
        "    print(f\"Unexpected keys: {unexpected_keys}\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "prompt = \"hello how are you\"\n",
        "tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
        "tokenized['input_ids'] = tokenized['input_ids'].to('cuda')\n",
        "model = model.to('cuda')\n",
        "\n",
        "with torch.no_grad():\n",
        "    model.eval()\n",
        "    res = model(tokenized['input_ids'])\n",
        "\n",
        "output_ids = torch.argmax(res, dim=-1)\n",
        "for id in output_ids[0]:\n",
        "    print(tokenizer.decode(id))"
      ],
      "id": "f4f7826ec3729b66",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": "",
      "id": "c12776360008a974",
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 2
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython2",
      "version": "2.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
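
Note that the inference cell above takes the argmax of the logits at every position, i.e. it prints the model's single most likely next token after each prefix of the prompt rather than generating a continuation. A greedy generation loop under the same assumptions (model and tokenizer as created in the notebook, weights on CUDA) could look like the sketch below; it is not part of the commit:

import torch

def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 20) -> str:
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids)          # (1, T, vocab_size)
            next_id = logits[0, -1].argmax()   # most likely next token
            input_ids = torch.cat([input_ids, next_id.view(1, 1)], dim=1)
    return tokenizer.decode(input_ids[0])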

labml_nn/transformers/LoRA/load_hf.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

state_dict = model.state_dict()

mapping = {
    'transformer.wte.weight': 'token_embedding.weight',
    'transformer.wpe.weight': 'position_embedding.weight',
    'transformer.ln_f.weight': 'final_norm.weight',
    'transformer.ln_f.bias': 'final_norm.bias',
    'lm_head.weight': 'lm_head.weight'
}

for i in range(12):
    mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
    mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
    mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
    mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
    mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
    mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
    mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
    mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
    mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
    mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
    mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
    mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

new_state_dict = {}
for old_key, new_key in mapping.items():
    if old_key in state_dict:
        new_state_dict[new_key] = state_dict[old_key]

# transpose the weight matrices of the Conv1D layers so they can be used with linear layers instead
convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
                [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
                [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
                [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

for layer in convo_layers:
    new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

torch.save(new_state_dict, 'transformed.pth')
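
The transpose step is needed because Hugging Face's GPT-2 stores these projections as Conv1D modules with weights laid out as (in_features, out_features), while the LoRA Linear layer above calls nn.functional.linear, which expects (out_features, in_features). As a hedged sanity check, not part of the commit, loading the produced file into GPTModel should leave only the freshly initialized LoRA factors unmatched:

import torch
from labml_nn.transformers.LoRA.GPT2 import GPTModel

state_dict = torch.load('transformed.pth')
model = GPTModel()
missing, unexpected = model.load_state_dict(state_dict, strict=False)

assert not unexpected, f"unexpected keys: {unexpected}"
assert all('lora_' in k for k in missing), f"non-LoRA keys missing: {missing}"
print(f"loaded {len(state_dict)} tensors; {len(missing)} LoRA tensors left to train")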
