
Commit 8e756f2

lora layers
1 parent d1e8daa commit 8e756f2

File tree

1 file changed: +68 −0 lines changed


docs/transformers/LoRA/__init__.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
from typing import Optional


class Linear(nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            r: int,
            alpha: Optional[int] = None):
        # If alpha is not given, default it to the rank so that scaling = 1
        if alpha is None:
            alpha = r
        super().__init__()

        # Pretrained weight, kept frozen during fine-tuning
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False

        if bias:
            # Pretrained bias, also frozen
            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:
            self.bias = None

        # Low-rank adaptation matrices A (in_features x r) and B (r x out_features)
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
        self.lora_b = nn.Parameter(torch.empty((r, out_features)))

        with torch.no_grad():
            # A gets a Kaiming-uniform init; B starts at zero so the update is zero initially
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen base projection
        result = nn.functional.linear(x, self.weight, bias=self.bias)

        # Add the scaled low-rank update x A B
        result += (x @ self.lora_a @ self.lora_b) * self.scaling

        return result


class Embedding(nn.Module):
    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            r: int,
            alpha: Optional[int] = None,
    ):
        # If alpha is not given, default it to the rank so that scaling = 1
        if alpha is None:
            alpha = r
        super().__init__()

        # Pretrained embedding table, kept frozen
        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

        # Low-rank adaptation matrices A (num_embeddings x r) and B (r x embedding_dim)
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

        with torch.no_grad():
            # A gets a normal init; B starts at zero so the update is zero initially
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen base embedding lookup
        result = nn.functional.embedding(x, self.weight)

        # Add the scaled low-rank update: embed with A, then project with B
        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling

        return result
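
For reference, a minimal usage sketch of the Linear layer above: copy a pretrained nn.Linear's parameters into the frozen weight and bias, then run a forward pass. The 768-dimensional shapes, r=8, and the batch sizes are illustrative values, not part of the commit.

import torch
import torch.nn as nn

# Hypothetical pretrained projection whose weights we want to adapt with LoRA
pretrained = nn.Linear(768, 768, bias=True)

lora_linear = Linear(in_features=768, out_features=768, bias=True, r=8)
with torch.no_grad():
    lora_linear.weight.copy_(pretrained.weight)   # frozen pretrained weight
    lora_linear.bias.copy_(pretrained.bias)       # frozen pretrained bias

x = torch.randn(2, 16, 768)                       # (batch, seq, features)
y = lora_linear(x)                                # base output + scaled low-rank update
print(y.shape)                                    # torch.Size([2, 16, 768])

# Only the LoRA matrices require gradients, so an optimizer would train just them
trainable = [p for p in lora_linear.parameters() if p.requires_grad]
assert len(trainable) == 2                        # lora_a and lora_b

Because lora_b is zero-initialized, the output at initialization equals the frozen base layer's output, and training only updates the two small rank-r matrices.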
