Skip to content

Commit 29ffe8d

Browse files
sxu authored and facebook-github-bot committed
Support LoRALinear in StaticAttention transform when split_mha=False (#18074)
Summary: Enable it when split_mha=False, which is the case for ANE. To support split_mha=True, we'd need to split both the base linear and the LoRA adapter, which is more tricky; throw an error for now. Reviewed By: billmguo Differential Revision: D95991428
1 parent 096f10c commit 29ffe8d

2 files changed

Lines changed: 153 additions & 1 deletion

File tree

examples/models/llama/static_attention.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ForwardOptions,
1414
register_attention,
1515
)
16+
from executorch.examples.models.llama.lora import LoRALinear
1617
from executorch.examples.models.llama.model_args import ModelArgs
1718
from executorch.examples.models.llama.rope import Rope
1819

@@ -861,6 +862,17 @@ def from_attention_mha(
861862
rms_norm_class=torch.nn.RMSNorm,
862863
**kwargs: Any,
863864
) -> "StaticAttention":
865+
has_lora = any(
866+
isinstance(proj, LoRALinear)
867+
for proj in [other.wq, other.wk, other.wv, other.wo]
868+
)
869+
870+
if has_lora and split_mha:
871+
raise ValueError(
872+
"split_mha=True is not supported when the source AttentionMHA "
873+
"contains LoRALinear modules. Use split_mha=False instead."
874+
)
875+
864876
config = ModelArgs(
865877
dim=other.dim,
866878
n_layers=1, # Not used in attention layer
@@ -882,6 +894,49 @@ def from_attention_mha(
882894
split_mha=split_mha,
883895
**kwargs,
884896
)
897+
898+
# Replace nn.Linear with LoRALinear where the source uses LoRA.
899+
if has_lora:
900+
for attr, proj, in_dim, out_dim, bias in [
901+
(
902+
"wqs",
903+
other.wq,
904+
other.dim,
905+
other.n_heads * other.head_dim,
906+
other.attention_qkv_bias
907+
),
908+
(
909+
"wks",
910+
other.wk,
911+
other.dim,
912+
other.n_kv_heads * other.head_dim,
913+
other.attention_qkv_bias,
914+
),
915+
(
916+
"wvs",
917+
other.wv,
918+
other.dim,
919+
other.n_kv_heads * other.head_dim,
920+
other.attention_qkv_bias
921+
),
922+
]:
923+
if isinstance(proj, LoRALinear):
924+
getattr(instance, attr)[0] = LoRALinear(
925+
in_dim=in_dim,
926+
out_dim=out_dim,
927+
rank=proj.rank,
928+
alpha=proj.alpha,
929+
use_bias=bias,
930+
)
931+
if isinstance(other.wo, LoRALinear):
932+
instance.wo = LoRALinear(
933+
in_dim=other.n_heads * other.head_dim,
934+
out_dim=other.dim,
935+
rank=other.wo.rank,
936+
alpha=other.wo.alpha,
937+
use_bias=other.wo.use_bias,
938+
)
939+
885940
instance.load_weights_from_attention_mha(other, rms_norm_class=rms_norm_class)
886941

887942
return instance
@@ -1120,7 +1175,7 @@ def load_weights_from_attention_mha(
11201175
self.wks[0].load_state_dict(other.wk.state_dict())
11211176
self.wvs[0].load_state_dict(other.wv.state_dict())
11221177

1123-
self.wo.weight.data.copy_(other.wo.weight) # pyre-ignore[6]
1178+
self.wo.load_state_dict(other.wo.state_dict())
11241179

11251180
if other.use_qk_norm:
11261181
self.use_qk_norm = True

examples/models/llama/tests/test_static_attention.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from executorch.examples.models.llama.attention import AttentionMHA
88
from executorch.examples.models.llama.llama_transformer import construct_transformer
9+
from executorch.examples.models.llama.lora import LoRALinear
910
from executorch.examples.models.llama.model_args import ModelArgs
1011
from executorch.examples.models.llama.rope import Rope
1112
from executorch.examples.models.llama.static_attention import (
@@ -361,3 +362,99 @@ def test_batched_export_with_backprop(self):
361362
static_transformer, example_inputs
362363
).module()
363364
non_batched_gm.load_state_dict(batched_gm.state_dict())
365+
366+
def test_lora_split_mha_raises(self):
367+
config = ModelArgs(
368+
dim=64,
369+
n_heads=4,
370+
n_kv_heads=2,
371+
max_seq_len=8,
372+
r=4,
373+
lora_alpha=8,
374+
target_modules=["q_proj"],
375+
)
376+
layer_id = 0
377+
rope = Rope(config)
378+
attn_mha = AttentionMHA(config, layer_id, rope)
379+
with self.assertRaises(ValueError):
380+
StaticAttention.from_attention_mha(attn_mha, split_mha=True)
381+
382+
def test_lora_without_cache(self):
383+
torch.manual_seed(42)
384+
config = ModelArgs(
385+
dim=64,
386+
n_heads=4,
387+
n_kv_heads=2,
388+
max_seq_len=8,
389+
r=4,
390+
lora_alpha=8,
391+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
392+
)
393+
layer_id = 0
394+
rope = Rope(config)
395+
attn_mha = AttentionMHA(config, layer_id, rope).eval()
396+
397+
self.assertIsInstance(attn_mha.wq, LoRALinear)
398+
self.assertIsInstance(attn_mha.wk, LoRALinear)
399+
self.assertIsInstance(attn_mha.wv, LoRALinear)
400+
self.assertIsInstance(attn_mha.wo, LoRALinear)
401+
402+
static_attn = StaticAttention.from_attention_mha(
403+
attn_mha, split_mha=False
404+
).eval()
405+
406+
self.assertIsInstance(static_attn.wqs[0], LoRALinear)
407+
self.assertIsInstance(static_attn.wks[0], LoRALinear)
408+
self.assertIsInstance(static_attn.wvs[0], LoRALinear)
409+
self.assertIsInstance(static_attn.wo, LoRALinear)
410+
411+
x = torch.rand(1, config.max_seq_len, config.dim)
412+
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
413+
expected, _ = attn_mha(x, freqs_cos, freqs_sin)
414+
415+
mask = torch.triu(
416+
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
417+
diagonal=1,
418+
)
419+
y, _ = static_attn(x, freqs_cos, freqs_sin, masks={0: mask})
420+
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
421+
422+
def test_lora_partial_projections(self):
423+
torch.manual_seed(42)
424+
config = ModelArgs(
425+
dim=64,
426+
n_heads=4,
427+
n_kv_heads=2,
428+
max_seq_len=8,
429+
r=4,
430+
lora_alpha=8,
431+
target_modules=["q_proj", "v_proj"],
432+
)
433+
layer_id = 0
434+
rope = Rope(config)
435+
attn_mha = AttentionMHA(config, layer_id, rope).eval()
436+
437+
self.assertIsInstance(attn_mha.wq, LoRALinear)
438+
self.assertIsInstance(attn_mha.wk, torch.nn.Linear)
439+
self.assertIsInstance(attn_mha.wv, LoRALinear)
440+
self.assertIsInstance(attn_mha.wo, torch.nn.Linear)
441+
442+
static_attn = StaticAttention.from_attention_mha(
443+
attn_mha, split_mha=False
444+
).eval()
445+
446+
self.assertIsInstance(static_attn.wqs[0], LoRALinear)
447+
self.assertIsInstance(static_attn.wks[0], torch.nn.Linear)
448+
self.assertIsInstance(static_attn.wvs[0], LoRALinear)
449+
self.assertIsInstance(static_attn.wo, torch.nn.Linear)
450+
451+
x = torch.rand(1, config.max_seq_len, config.dim)
452+
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
453+
expected, _ = attn_mha(x, freqs_cos, freqs_sin)
454+
455+
mask = torch.triu(
456+
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
457+
diagonal=1,
458+
)
459+
y, _ = static_attn(x, freqs_cos, freqs_sin, masks={0: mask})
460+
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

0 commit comments

Comments (0)