57 changes: 56 additions & 1 deletion examples/models/llama/static_attention.py
@@ -13,6 +13,7 @@
ForwardOptions,
register_attention,
)
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

@@ -861,6 +862,17 @@ def from_attention_mha(
rms_norm_class=torch.nn.RMSNorm,
**kwargs: Any,
) -> "StaticAttention":
has_lora = any(
isinstance(proj, LoRALinear)
for proj in [other.wq, other.wk, other.wv, other.wo]
)

if has_lora and split_mha:
raise ValueError(
"split_mha=True is not supported when the source AttentionMHA "
"contains LoRALinear modules. Use split_mha=False instead."
)

config = ModelArgs(
dim=other.dim,
n_layers=1, # Not used in attention layer
@@ -882,6 +894,49 @@
split_mha=split_mha,
**kwargs,
)

# Replace nn.Linear with LoRALinear where the source uses LoRA.
if has_lora:
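            # Each tuple: (attr on the instance, source projection, in_dim, out_dim, bias flag).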
for attr, proj, in_dim, out_dim, bias in [
(
"wqs",
other.wq,
other.dim,
other.n_heads * other.head_dim,
other.attention_qkv_bias,
),
(
"wks",
other.wk,
other.dim,
other.n_kv_heads * other.head_dim,
other.attention_qkv_bias,
),
(
"wvs",
other.wv,
other.dim,
other.n_kv_heads * other.head_dim,
other.attention_qkv_bias,
),
]:
if isinstance(proj, LoRALinear):
getattr(instance, attr)[0] = LoRALinear(
in_dim=in_dim,
out_dim=out_dim,
rank=proj.rank,
alpha=proj.alpha,
use_bias=bias,
)
if isinstance(other.wo, LoRALinear):
instance.wo = LoRALinear(
in_dim=other.n_heads * other.head_dim,
out_dim=other.dim,
rank=other.wo.rank,
alpha=other.wo.alpha,
use_bias=other.wo.use_bias,
)

instance.load_weights_from_attention_mha(other, rms_norm_class=rms_norm_class)

return instance
@@ -1120,7 +1175,7 @@ def load_weights_from_attention_mha(
self.wks[0].load_state_dict(other.wk.state_dict())
self.wvs[0].load_state_dict(other.wv.state_dict())

-        self.wo.weight.data.copy_(other.wo.weight)  # pyre-ignore[6]
+        self.wo.load_state_dict(other.wo.state_dict())

if other.use_qk_norm:
self.use_qk_norm = True
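For reference on why `from_attention_mha` only needs `rank`, `alpha`, and `use_bias` to rebuild each projection before `load_state_dict` restores the tensors: a LoRA linear is the frozen base projection plus a scaled low-rank update, so the module is fully determined by those three hyperparameters and its state dict. Below is a minimal sketch of that standard formulation; the internal attribute names of ExecuTorch's `LoRALinear` are an assumption here, not taken from this diff.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinearSketch(nn.Module):
    """Standard LoRA: y = x @ W^T (+ b) + (alpha / rank) * B(A(x)).

    Sketch only; attribute names are assumed, not ExecuTorch's actual layout.
    """

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float, use_bias: bool = False):
        super().__init__()
        self.rank, self.alpha, self.use_bias = rank, alpha, use_bias
        # Base projection; real values arrive later via load_state_dict().
        self.weight = nn.Parameter(torch.zeros(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim)) if use_bias else None
        # Low-rank adapters: A projects down to `rank`, B projects back up.
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = F.linear(x, self.weight, self.bias)
        return base + (self.alpha / self.rank) * self.lora_b(self.lora_a(x))
```

This is also why the `wo` copy in `load_weights_from_attention_mha` switches from `weight.data.copy_` to `load_state_dict`: the former copies only the base weight, while the latter also carries the LoRA adapter tensors when the source is a `LoRALinear`.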
97 changes: 97 additions & 0 deletions examples/models/llama/tests/test_static_attention.py
@@ -6,6 +6,7 @@
import torch
from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import (
@@ -361,3 +362,99 @@ def test_batched_export_with_backprop(self):
static_transformer, example_inputs
).module()
non_batched_gm.load_state_dict(batched_gm.state_dict())

def test_lora_split_mha_raises(self):
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
r=4,
lora_alpha=8,
target_modules=["q_proj"],
)
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope)
with self.assertRaises(ValueError):
StaticAttention.from_attention_mha(attn_mha, split_mha=True)

def test_lora_without_cache(self):
torch.manual_seed(42)
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
r=4,
lora_alpha=8,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()

self.assertIsInstance(attn_mha.wq, LoRALinear)
self.assertIsInstance(attn_mha.wk, LoRALinear)
self.assertIsInstance(attn_mha.wv, LoRALinear)
self.assertIsInstance(attn_mha.wo, LoRALinear)

static_attn = StaticAttention.from_attention_mha(
attn_mha, split_mha=False
).eval()

self.assertIsInstance(static_attn.wqs[0], LoRALinear)
self.assertIsInstance(static_attn.wks[0], LoRALinear)
self.assertIsInstance(static_attn.wvs[0], LoRALinear)
self.assertIsInstance(static_attn.wo, LoRALinear)

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
expected, _ = attn_mha(x, freqs_cos, freqs_sin)

mask = torch.triu(
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
diagonal=1,
)
y, _ = static_attn(x, freqs_cos, freqs_sin, masks={0: mask})
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())

def test_lora_partial_projections(self):
torch.manual_seed(42)
config = ModelArgs(
dim=64,
n_heads=4,
n_kv_heads=2,
max_seq_len=8,
r=4,
lora_alpha=8,
target_modules=["q_proj", "v_proj"],
)
layer_id = 0
rope = Rope(config)
attn_mha = AttentionMHA(config, layer_id, rope).eval()

self.assertIsInstance(attn_mha.wq, LoRALinear)
self.assertIsInstance(attn_mha.wk, torch.nn.Linear)
self.assertIsInstance(attn_mha.wv, LoRALinear)
self.assertIsInstance(attn_mha.wo, torch.nn.Linear)

static_attn = StaticAttention.from_attention_mha(
attn_mha, split_mha=False
).eval()

self.assertIsInstance(static_attn.wqs[0], LoRALinear)
self.assertIsInstance(static_attn.wks[0], torch.nn.Linear)
self.assertIsInstance(static_attn.wvs[0], LoRALinear)
self.assertIsInstance(static_attn.wo, torch.nn.Linear)

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
expected, _ = attn_mha(x, freqs_cos, freqs_sin)

mask = torch.triu(
torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
diagonal=1,
)
y, _ = static_attn(x, freqs_cos, freqs_sin, masks={0: mask})
self.assertTrue(torch.isclose(y, expected, rtol=1e-3).all())
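Taken together, `test_lora_without_cache` and `test_lora_partial_projections` document the intended usage: build a LoRA-enabled `AttentionMHA`, convert it with `split_mha=False`, and check numerical parity against the source module. A condensed sketch of that flow, with config values copied from the tests above:

```python
import torch
from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import StaticAttention

config = ModelArgs(
    dim=64, n_heads=4, n_kv_heads=2, max_seq_len=8,
    r=4, lora_alpha=8,  # LoRA rank and scaling factor
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
rope = Rope(config)
attn_mha = AttentionMHA(config, 0, rope).eval()

# split_mha=True raises ValueError for LoRA sources (see the guard above).
static_attn = StaticAttention.from_attention_mha(attn_mha, split_mha=False).eval()

x = torch.rand(1, config.max_seq_len, config.dim)
freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len)
expected, _ = attn_mha(x, freqs_cos, freqs_sin)

# Causal mask: -inf above the diagonal, 0 elsewhere.
mask = torch.triu(
    torch.full((1, config.max_seq_len, config.max_seq_len), float("-inf")),
    diagonal=1,
)
y, _ = static_attn(x, freqs_cos, freqs_sin, masks={0: mask})
assert torch.isclose(y, expected, rtol=1e-3).all()
```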