File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -468,6 +468,7 @@ class MeridianForCausalLM(PreTrainedModel):
468468 config_class = MeridianConfig
469469 base_model_prefix = "model"
470470 supports_gradient_checkpointing = True
471+ _tied_weights_keys = {"lm_head.weight" : "model.embed_tokens.weight" }
471472
472473 def __init__ (self , config : MeridianConfig ):
473474 super ().__init__ (config )
@@ -478,11 +479,6 @@ def __init__(self, config: MeridianConfig):
478479 # Initialize weights and apply final processing
479480 self .post_init ()
480481
481- def tie_weights (self ):
482- """Tie input and output embeddings if configured."""
483- if self .config .tie_word_embeddings :
484- self .lm_head .weight = self .model .embed_tokens .weight
485-
486482 def _init_weights (self , module : nn .Module ) -> None :
487483 std = self .config .initializer_range
488484 if isinstance (module , nn .Linear ):
You can’t perform that action at this time.
0 commit comments