18 | 18 | from torch.distributed.fsdp import StateDictType |
19 | 19 |
20 | 20 | from modalities.__main__ import Main |
21 | | -from modalities.config.config import ProcessGroupBackendType |
22 | | -from modalities.config.pydantic_if_types import PydanticFSDP1ModuleType, PydanticFSDP2ModuleType |
| 21 | +from modalities.config.component_factory import ComponentFactory |
| 22 | +from modalities.config.config import ProcessGroupBackendType, load_app_config_dict |
| 23 | +from modalities.config.pydantic_if_types import ( |
| 24 | + PydanticFSDP1ModuleType, |
| 25 | + PydanticFSDP2ModuleType, |
| 26 | + PydanticPytorchModuleType, |
| 27 | +) |
| 28 | +from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block |
| 29 | +from modalities.registry.components import COMPONENTS |
| 30 | +from modalities.registry.registry import Registry |
23 | 31 | from tests.end2end_tests.custom_components import MultiProcessingCudaEnv |
24 | 32 |
25 | 33 |
@@ -493,3 +501,83 @@ def _get_fdsp2_state_dict(model: FSDP2) -> dict[str, Any]: |
493 | 501 | model=model, optimizers=[], options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True) |
494 | 502 | )[0] |
495 | 503 | return model_state |
| 504 | + |
| 505 | + |
| 506 | +class TestLlama3LikeInitialization: |
| 507 | + @pytest.mark.parametrize("depth_init", [True, False]) |
| 508 | + def test_llama3_like_initialization(self, depth_init: bool): |
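| | + """Build the model from the YAML config and verify the per-parameter init statistics for depth-scaled (depth_init=True) and uniform (depth_init=False) residual scaling.""" |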
| 509 | + config_file_path = Path(__file__).parent / "test_yaml_configs/llama3_config_initalization.yaml" |
| 510 | + n_layer = 4 |
| 511 | + n_embd = 256 |
| 512 | + model = self._get_components(config_file_path=config_file_path, depth_init=depth_init) |
| 513 | + self._test_wte(model=model) |
| 514 | + self._test_lm_head(model=model, n_embd=n_embd) |
| 515 | + |
| 516 | + for layer_id, (_, block) in enumerate(model.transformer["h"].items()): |
| 517 | + self._test_qkv_proj(gpt2_block=block) |
| 518 | + self._test_c_proj(gpt2_block=block, depth_init=depth_init, n_layer=n_layer, layer_id=layer_id) |
| 519 | + self._test_swiglu_proj(gpt2_block=block, depth_init=depth_init, n_layer=n_layer, layer_id=layer_id) |
| 520 | + |
| 521 | + def _get_components(self, config_file_path: Path, depth_init: bool) -> GPT2LLM: |
| 522 | + config_dict = load_app_config_dict( |
| 523 | + config_file_path=config_file_path, |
| 524 | + ) |
| 525 | + config_dict["initialized_model"]["config"]["model_initializer"]["config"]["depth_init"] = depth_init |
| 526 | + registry = Registry(COMPONENTS) |
| 527 | + component_factory = ComponentFactory(registry=registry) |
| 528 | + |
| 529 | + class ComponentsInstantiationModel(BaseModel): |
| 530 | + initialized_model: PydanticPytorchModuleType |
| 531 | + |
| 532 | + components: ComponentsInstantiationModel = component_factory.build_components( |
| 533 | + config_dict=config_dict, components_model_type=ComponentsInstantiationModel |
| 534 | + ) |
| 535 | + return components.initialized_model |
| 536 | + |
| 537 | + def _test_wte(self, model: GPT2LLM): |
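| | + # token embedding weights are expected to be standard normal (mean 0, std 1) |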
| 538 | + assert model.transformer.wte.weight.std().detach().cpu() == pytest.approx(1, abs=1e-2) |
| 539 | + assert model.transformer.wte.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-2) |
| 540 | + |
| 541 | + def _test_lm_head(self, model: GPT2LLM, n_embd: int): |
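| | + # lm_head: truncated normal with std 1/sqrt(n_embd), clipped to +-3 std |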
| 542 | + assert model.transformer.lm_head.weight.std().detach().cpu() == pytest.approx(1 / math.sqrt(n_embd), abs=1e-3) |
| 543 | + assert model.transformer.lm_head.weight.max().detach().cpu() <= 3 / math.sqrt(n_embd) |
| 544 | + assert model.transformer.lm_head.weight.min().detach().cpu() >= -3 / math.sqrt(n_embd) |
| 545 | + assert model.transformer.lm_head.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) |
| 546 | + |
| 547 | + def _test_qkv_proj(self, gpt2_block: GPT2Block): |
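| | + # q/k/v projections: std 0.02; the +-2 bounds match trunc_normal_'s default cutoffs |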
| 548 | + layers = (gpt2_block.attn.q_attn, gpt2_block.attn.k_attn, gpt2_block.attn.v_attn) |
| 549 | + for layer in layers: |
| 550 | + assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) |
| 551 | + assert layer.weight.max().detach().cpu() <= 2 |
| 552 | + assert layer.weight.min().detach().cpu() >= -2 |
| 553 | + assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) |
| 554 | + |
| 555 | + def _test_c_proj(self, gpt2_block: GPT2Block, depth_init: bool, n_layer: int, layer_id: int): |
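| | + # attention output projection: std 0.02/sqrt(2*(layer_id+1)) when depth-scaled, else 0.02/sqrt(2*n_layer) for all layers |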
| 556 | + layer = gpt2_block.attn.c_proj |
| 557 | + if depth_init: |
| 558 | + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * (layer_id + 1)), abs=1e-3) |
| 559 | + else: |
| 560 | + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) |
| 561 | + |
| 562 | + assert layer.weight.max().detach().cpu() <= 2 |
| 563 | + assert layer.weight.min().detach().cpu() >= -2 |
| 564 | + assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) |
| 565 | + |
| 566 | + def _test_swiglu_proj(self, gpt2_block: GPT2Block, depth_init: bool, n_layer: int, layer_id: int): |
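| | + # SwiGLU: V and W_2 use the same depth scaling as c_proj; the W projection keeps std 0.02 |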
| 567 | + layers = (gpt2_block.mlp.V, gpt2_block.mlp.W_2) |
| 568 | + for layer in layers: |
| 569 | + if depth_init: |
| 570 | + assert layer.weight.std().detach().cpu() == pytest.approx( |
| 571 | + 0.02 / math.sqrt(2 * (layer_id + 1)), abs=1e-3 |
| 572 | + ) |
| 573 | + else: |
| 574 | + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) |
| 575 | + assert layer.weight.max().detach().cpu() <= 2 |
| 576 | + assert layer.weight.min().detach().cpu() >= -2 |
| 577 | + assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) |
| 578 | + |
| 579 | + layer = gpt2_block.mlp.W |
| 580 | + assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) |
| 581 | + assert layer.weight.max().detach().cpu() <= 2 |
| 582 | + assert layer.weight.min().detach().cpu() >= -2 |
| 583 | + assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) |
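
For reference, the statistics asserted above correspond to the following scaling rules. This is a minimal sketch of the Llama3-style scheme the tests encode, not the modalities initializer itself; the helper name and the `kind` labels are illustrative only, and the ±2 bounds are simply `trunc_normal_`'s default cutoffs.

import math

import torch
import torch.nn as nn


def init_llama3_like(weight: torch.Tensor, kind: str, *, layer_id: int = 0,
                     n_layer: int = 1, n_embd: int = 256, depth_init: bool = True) -> None:
    # Hypothetical helper that reproduces the statistics asserted in the tests above.
    if kind == "wte":
        nn.init.normal_(weight)  # standard normal: mean 0, std 1
    elif kind == "lm_head":
        std = 1 / math.sqrt(n_embd)
        nn.init.trunc_normal_(weight, std=std, a=-3 * std, b=3 * std)
    elif kind in ("qkv", "mlp_W"):  # q/k/v projections and the SwiGLU W projection
        nn.init.trunc_normal_(weight, std=0.02)  # default cutoffs are a=-2.0, b=2.0
    else:  # residual-path projections: attn c_proj, SwiGLU V and W_2
        denom = 2 * (layer_id + 1) if depth_init else 2 * n_layer
        nn.init.trunc_normal_(weight, std=0.02 / math.sqrt(denom))

With depth_init=True, the residual-path projections of deeper layers are initialized progressively smaller, which is exactly what _test_c_proj and _test_swiglu_proj assert per layer_id.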