|
6 | 6 | import torch |
7 | 7 | from executorch.examples.models.llama.attention import AttentionMHA |
8 | 8 | from executorch.examples.models.llama.llama_transformer import construct_transformer |
| 9 | +from executorch.examples.models.llama.lora import LoRALinear |
9 | 10 | from executorch.examples.models.llama.model_args import ModelArgs |
10 | 11 | from executorch.examples.models.llama.rope import Rope |
11 | 12 | from executorch.examples.models.llama.static_attention import ( |
@@ -361,3 +362,99 @@ def test_batched_export_with_backprop(self): |
361 | 362 | static_transformer, example_inputs |
362 | 363 | ).module() |
363 | 364 | non_batched_gm.load_state_dict(batched_gm.state_dict()) |
| 365 | + |
def test_lora_split_mha_raises(self):
    """Expect ValueError when converting a LoRA-configured MHA with split_mha=True.

    The attention module is built from a ModelArgs carrying LoRA settings
    (r, lora_alpha, target_modules) and handed to
    StaticAttention.from_attention_mha with split_mha=True.
    """
    lora_args = ModelArgs(
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        max_seq_len=8,
        r=4,
        lora_alpha=8,
        target_modules=["q_proj"],
    )
    mha = AttentionMHA(lora_args, 0, Rope(lora_args))
    # Callable form of assertRaises: the conversion itself must raise.
    self.assertRaises(
        ValueError,
        StaticAttention.from_attention_mha,
        mha,
        split_mha=True,
    )
| 381 | + |
def test_lora_without_cache(self):
    """Check StaticAttention against AttentionMHA with LoRA on q/k/v/o.

    All four projections are listed in target_modules. The test verifies
    that both the source module and the converted module (split_mha=False)
    expose LoRALinear projections, then compares their outputs on the same
    random input.
    """
    torch.manual_seed(42)
    args = ModelArgs(
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        max_seq_len=8,
        r=4,
        lora_alpha=8,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    seq_len = args.max_seq_len
    rope = Rope(args)
    reference = AttentionMHA(args, 0, rope).eval()

    # Every projection on the source module should be a LoRA layer.
    for proj in (reference.wq, reference.wk, reference.wv, reference.wo):
        self.assertIsInstance(proj, LoRALinear)

    converted = StaticAttention.from_attention_mha(reference, split_mha=False)
    converted = converted.eval()

    # The converted module should carry LoRA layers as well.
    for proj in (
        converted.wqs[0],
        converted.wks[0],
        converted.wvs[0],
        converted.wo,
    ):
        self.assertIsInstance(proj, LoRALinear)

    x = torch.rand(1, seq_len, args.dim)
    freqs_cos, freqs_sin = rope.get_freqs(None, seq_len)
    expected, _ = reference(x, freqs_cos, freqs_sin)

    # -inf strictly above the main diagonal, zeros on and below it.
    neg_inf = torch.full((1, seq_len, seq_len), float("-inf"))
    causal_mask = torch.triu(neg_inf, diagonal=1)

    y, _ = converted(x, freqs_cos, freqs_sin, masks={0: causal_mask})
    elementwise_close = torch.isclose(y, expected, rtol=1e-3)
    self.assertTrue(elementwise_close.all())
| 421 | + |
def test_lora_partial_projections(self):
    """Check StaticAttention against AttentionMHA with LoRA on q and v only.

    Only q_proj and v_proj are listed in target_modules. The test verifies
    that wq/wv are LoRALinear while wk/wo are torch.nn.Linear instances on
    both the source and the converted module (split_mha=False), then
    compares their outputs on the same random input.
    """
    torch.manual_seed(42)
    args = ModelArgs(
        dim=64,
        n_heads=4,
        n_kv_heads=2,
        max_seq_len=8,
        r=4,
        lora_alpha=8,
        target_modules=["q_proj", "v_proj"],
    )
    seq_len = args.max_seq_len
    rope = Rope(args)
    reference = AttentionMHA(args, 0, rope).eval()

    # Expected layer type for the q, k, v, o projections, in that order.
    expected_types = (LoRALinear, torch.nn.Linear, LoRALinear, torch.nn.Linear)

    source_projs = (reference.wq, reference.wk, reference.wv, reference.wo)
    for proj, layer_type in zip(source_projs, expected_types):
        self.assertIsInstance(proj, layer_type)

    converted = StaticAttention.from_attention_mha(reference, split_mha=False)
    converted = converted.eval()

    converted_projs = (
        converted.wqs[0],
        converted.wks[0],
        converted.wvs[0],
        converted.wo,
    )
    for proj, layer_type in zip(converted_projs, expected_types):
        self.assertIsInstance(proj, layer_type)

    x = torch.rand(1, seq_len, args.dim)
    freqs_cos, freqs_sin = rope.get_freqs(None, seq_len)
    expected, _ = reference(x, freqs_cos, freqs_sin)

    # -inf strictly above the main diagonal, zeros on and below it.
    neg_inf = torch.full((1, seq_len, seq_len), float("-inf"))
    causal_mask = torch.triu(neg_inf, diagonal=1)

    y, _ = converted(x, freqs_cos, freqs_sin, masks={0: causal_mask})
    elementwise_close = torch.isclose(y, expected, rtol=1e-3)
    self.assertTrue(elementwise_close.all())
0 commit comments