Skip to content

Commit 72a9449

Browse files
committed
Update lora def
1 parent: d0820e1 · commit: 72a9449

3 files changed

Lines changed: 20 additions & 34 deletions

File tree

backends/xnnpack/operators/node_visitor.py

Lines changed: 1 addition & 1 deletion
Original line | New line | Change
@@ -625,7 +625,7 @@ def get_serialized_buffer_index(
 625 625            f"Serializing constant data node {tensor} but tensor value has no bytes",
 626 626        )
 627 627        sha256_hash = hashlib.sha256(bytes(array))
 628     -     named_key = tensor.name + "_" + sha256_hash.hexdigest()
     628 +     named_key = sha256_hash.hexdigest()
 629 629
 630 630        size = const_val.untyped_storage().nbytes()
 631 631        xnn_graph.constant_data.append(

examples/models/llama/lora.py

Lines changed: 14 additions & 9 deletions
Original line | New line | Change
@@ -28,20 +28,25 @@ def __init__(
 28  28         self.use_bias = use_bias
 29  29         self.dropout = dropout
 30  30
 31      -      linear = nn.Linear(in_dim, out_dim, bias=use_bias)
 32      -      weight = linear.weight
 33      -      bias = linear.bias if self.use_bias else None
 34      -      self.register_parameter("weight", nn.Parameter(weight))
 35      -      self.register_parameter(
 36      -          "bias", nn.Parameter(bias) if bias is not None else None
 37      -      )
 38      -
     31  +     self.linear = nn.Linear(in_dim, out_dim, bias=use_bias)
 39  32         self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
 40  33         self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
 41  34         self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
 42  35
     36  +  @property
     37  +  def weight(self):
     38  +      return self.linear.weight
     39  +
     40  +  def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
     41  +      # Remap old-style "weight" key to "linear.weight" for backward compat
     42  +      old_key = prefix + "weight"
     43  +      new_key = prefix + "linear.weight"
     44  +      if old_key in state_dict and new_key not in state_dict:
     45  +          state_dict[new_key] = state_dict.pop(old_key)
     46  +      super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
     47  +
 43  48     def forward(self, x: torch.Tensor) -> torch.Tensor:
 44      -      out = torch.nn.functional.linear(x, self.weight, self.bias)
     49  +      out = self.linear(x)
 45  50         lora_out = self.lora_a(self.dropout(x))
 46  51         lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
 47  52
examples/models/llama/source_transformation/quantize.py

Lines changed: 5 additions & 24 deletions
Original line | New line | Change
@@ -144,30 +144,11 @@ def quantize(  # noqa C901
 144 144        from torchao.utils import unwrap_tensor_subclass
 145 145
 146 146        def filter_fn(m, fqn):
 147     -          # Check if it's a regular nn.Linear
 148     -          is_linear = isinstance(m, nn.Linear)
 149     -
 150     -          # Check if it's a LoRALinear (which has a base weight parameter to quantize)
 151     -          is_lora_linear = False
 152     -          try:
 153     -              from executorch.examples.models.llama.lora import LoRALinear
 154     -
 155     -              is_lora_linear = isinstance(m, LoRALinear)
 156     -          except ImportError:
 157     -              pass
 158     -
 159     -          # Check if the weight shape is compatible with group size
 160     -          has_shape_compatible_with_group_size = False
 161     -          if is_linear or is_lora_linear:
 162     -              if group_size == 0:
 163     -                  has_shape_compatible_with_group_size = True
 164     -              else:
 165     -                  has_shape_compatible_with_group_size = (
 166     -                      m.weight.shape[1] % group_size == 0
 167     -                  )
 168     -          return (
 169     -              is_linear or is_lora_linear
 170     -          ) and has_shape_compatible_with_group_size
     147 +          if not isinstance(m, nn.Linear):
     148 +              return False
     149 +          if group_size == 0:
     150 +              return True
     151 +          return m.weight.shape[1] % group_size == 0
 171 152
 172 153        weight_dtype = torch.int4 if qmode == "8da4w" else torch.int8
 173 154    quantize_(

0 commit comments

Comments (0)