import torch
import torch.nn as nn
from typing import Optional


class Linear(nn.Module):
    """nn.Linear with a frozen weight and a trainable low-rank (LoRA) update."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        r: int,
        alpha: Optional[int] = None,
    ):
        super().__init__()
        if alpha is None:
            alpha = r

        # Frozen pretrained weight; expected to be filled from a checkpoint.
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:
            self.bias = None

        # Trainable rank-r factors: A is (in_features, r), B is (r, out_features).
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((in_features, r)))
        self.lora_b = nn.Parameter(torch.empty((r, out_features)))

        with torch.no_grad():
            # B starts at zero, so the LoRA branch is a no-op before training.
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen base projection plus the scaled low-rank correction.
        result = nn.functional.linear(x, self.weight, bias=self.bias)
        result += (x @ self.lora_a @ self.lora_b) * self.scaling
        return result


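# A minimal usage sketch (my addition, not part of the commit; the helper name
# and shapes are arbitrary): the frozen weight is copied from a pretrained
# nn.Linear, and because lora_b starts at zero the wrapped layer initially
# reproduces the base layer's output exactly.
def _linear_sanity_check():
    base = nn.Linear(16, 8, bias=True)
    lora = Linear(16, 8, bias=True, r=4)
    with torch.no_grad():
        lora.weight.copy_(base.weight)
        lora.bias.copy_(base.bias)
    x = torch.randn(2, 16)
    assert torch.allclose(lora(x), base(x))

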
class Embedding(nn.Module):
    """nn.Embedding with a frozen table and a trainable low-rank (LoRA) update."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int,
        alpha: Optional[int] = None,
    ):
        super().__init__()
        if alpha is None:
            alpha = r

        # Frozen pretrained embedding table; expected to be filled from a checkpoint.
        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

        # Trainable rank-r factors: A is (num_embeddings, r), B is (r, embedding_dim).
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
        self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

        with torch.no_grad():
            # lora_b starts at zero, so the LoRA branch is a no-op before training.
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # Frozen lookup plus a low-rank lookup through A, projected up by B.
        result = nn.functional.embedding(x, self.weight)
        result += (nn.functional.embedding(x, self.lora_a) @ self.lora_b) * self.scaling
        return result
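

# A hedged sketch of weight merging (my addition; merge_lora is a hypothetical
# helper, not part of this commit): after training, the low-rank update can be
# folded into the frozen weights so inference runs at the cost of the plain
# layer, assuming the usual LoRA deployment step.
@torch.no_grad()
def merge_lora(module: nn.Module):
    if isinstance(module, Linear):
        # delta_W = (A @ B)^T * scaling has shape (out_features, in_features).
        module.weight += (module.lora_a @ module.lora_b).T * module.scaling
        nn.init.zeros_(module.lora_b)  # keeps forward() consistent after merging
    elif isinstance(module, Embedding):
        # delta_W = (A @ B) * scaling has shape (num_embeddings, embedding_dim).
        module.weight += (module.lora_a @ module.lora_b) * module.scaling
        nn.init.zeros_(module.lora_b)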