Commit 72b8762

fix: correctly define _tied_weights_keys as dict for safetensors compatibility
1 parent 9a76d88 · commit 72b8762

1 file changed: 1 addition & 5 deletions

meridian/model/modeling.py
@@ -468,6 +468,7 @@ class MeridianForCausalLM(PreTrainedModel):
     config_class = MeridianConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
 
     def __init__(self, config: MeridianConfig):
         super().__init__(config)
@@ -478,11 +479,6 @@ def __init__(self, config: MeridianConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def tie_weights(self):
-        """Tie input and output embeddings if configured."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def _init_weights(self, module: nn.Module) -> None:
         std = self.config.initializer_range
         if isinstance(module, nn.Linear):
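Why this works: with `_tied_weights_keys` defined as a dict, each key names a tied parameter and each value names the parameter that owns the storage (here, `lm_head.weight` aliases `model.embed_tokens.weight`). `save_pretrained` uses that map to drop the duplicate entry before serialization, because safetensors refuses to write tensors that share memory, and `from_pretrained` re-creates the alias on load. The removed `tie_weights` override was redundant: `PreTrainedModel.tie_weights`, which runs as part of `post_init()`, already ties the embeddings when `config.tie_word_embeddings` is set.

Below is a minimal sketch of the underlying safetensors behavior that motivates the fix, assuming only `torch` and `safetensors` are installed; `TinyLM` is a hypothetical stand-in for `MeridianForCausalLM`, not code from this repository:

```python
# Minimal sketch (not from this repo): why tied weights need special
# handling under safetensors. TinyLM is a hypothetical stand-in for
# MeridianForCausalLM.
import torch.nn as nn
from safetensors.torch import save_file, save_model

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the output projection to the input embedding: both
        # parameter names now alias the same underlying storage.
        self.lm_head.weight = self.embed_tokens.weight

model = TinyLM()

try:
    # save_file rejects state dicts containing aliased tensors.
    save_file(dict(model.state_dict()), "tiny.safetensors")
except RuntimeError as err:
    print(f"save_file refused shared tensors: {err}")

# save_model deduplicates shared tensors before writing; Transformers'
# save_pretrained does the same, using _tied_weights_keys to know which
# duplicate names it may safely drop and then re-tie on load.
save_model(model, "tiny.safetensors")
```

After a save/load round trip, the tie can be verified by confirming both names still alias one storage, e.g. `model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr()`.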
