# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


class LoRALinear(nn.Module):
    """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models
    <https://arxiv.org/abs/2106.09685>`_.

    The base ``nn.Linear`` projection is applied as usual, and a low-rank update
    ``(alpha / rank) * lora_b(lora_a(dropout(x)))`` is added to its output.

    Args:
        in_dim (int): input dimension.
        out_dim (int): output dimension.
        rank (int): rank of the low-rank decomposition.
        alpha (float): scaling factor for the low-rank update.
        dropout (float): dropout probability applied to the input of the LoRA branch. Default: 0.0.
        use_bias (bool): whether the base linear layer has a bias. Default: False.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int,
        alpha: float,
        dropout: float = 0.0,
        use_bias: bool = False,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.rank = rank
        self.alpha = alpha
        self.use_bias = use_bias
        self.linear = nn.Linear(in_dim, out_dim, bias=use_bias)
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
        # Low-rank decomposition: A projects down to `rank`, B projects back up to `out_dim`.
        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)

    @property
    def weight(self):
        return self.linear.weight

    @property
    def bias(self):
        return self.linear.bias

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Remap base-layer keys ("weight"/"bias") to "linear.*" so checkpoints saved
        # from a plain nn.Linear load into the wrapped layer.
        for attr in ("weight", "bias"):
            old_key = prefix + attr
            new_key = prefix + "linear." + attr
            if old_key in state_dict and new_key not in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base projection.
        out = self.linear(x)
        # Low-rank update: dropout -> A -> B, scaled by alpha / rank.
        lora_out = self.lora_a(self.dropout(x))
        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
        return out + lora_out
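

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above); the shapes and
# hyperparameters are illustrative assumptions. The LoRA paper's convention of
# initializing A randomly and B to zeros (so the adapter starts as a no-op) is
# applied here as an assumption; the class itself relies on nn.Linear's
# default initialization.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    layer = LoRALinear(in_dim=512, out_dim=512, rank=8, alpha=16.0, dropout=0.05)

    # Conventional LoRA init (assumption): random A, zero B, so the initial update is zero.
    nn.init.kaiming_uniform_(layer.lora_a.weight, a=5**0.5)
    nn.init.zeros_(layer.lora_b.weight)

    # Freeze the base projection so only the low-rank adapters receive gradients.
    layer.linear.weight.requires_grad_(False)

    x = torch.randn(2, 10, 512)  # (batch, seq_len, in_dim)
    y = layer(x)
    print(y.shape)  # torch.Size([2, 10, 512])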