@@ -28,20 +28,25 @@ def __init__(
2828 self .use_bias = use_bias
2929 self .dropout = dropout
3030
31- linear = nn .Linear (in_dim , out_dim , bias = use_bias )
32- weight = linear .weight
33- bias = linear .bias if self .use_bias else None
34- self .register_parameter ("weight" , nn .Parameter (weight ))
35- self .register_parameter (
36- "bias" , nn .Parameter (bias ) if bias is not None else None
37- )
38-
31+ self .linear = nn .Linear (in_dim , out_dim , bias = use_bias )
3932 self .dropout = nn .Dropout (p = dropout ) if dropout > 0.0 else nn .Identity ()
4033 self .lora_a = nn .Linear (in_features = in_dim , out_features = rank , bias = False )
4134 self .lora_b = nn .Linear (in_features = rank , out_features = out_dim , bias = False )
4235
@property
def weight(self):
    """Expose the wrapped linear layer's weight.

    Kept for backward compatibility with code written against the old
    layout, where the weight was registered directly on this module
    via ``register_parameter("weight", ...)``.
    """
    return self.linear.weight

@property
def bias(self):
    """Expose the wrapped linear layer's bias.

    The old layout registered ``bias`` directly on this module (or
    ``None`` when ``use_bias=False``); after the refactor it lives on
    ``self.linear``, so mirror it here just like ``weight`` so callers
    that read ``module.bias`` keep working. ``nn.Linear`` stores
    ``bias=None`` when constructed with ``bias=False``, matching the
    old ``None`` behavior.
    """
    return self.linear.bias
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Remap old-style parameter keys to the new ``linear.*`` layout.

    The previous implementation registered ``weight`` and (when
    ``use_bias=True``) ``bias`` directly on this module; they now live
    on the wrapped ``self.linear``. Rewrite both keys — not just
    ``weight`` — so checkpoints saved with the old layout still load
    without "unexpected key"/"missing key" errors.
    """
    for name in ("weight", "bias"):
        old_key = prefix + name
        new_key = prefix + "linear." + name
        # Only remap when the old key exists and the new one doesn't,
        # so new-style checkpoints pass through untouched.
        if old_key in state_dict and new_key not in state_dict:
            state_dict[new_key] = state_dict.pop(old_key)
    super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
47+
4348 def forward (self , x : torch .Tensor ) -> torch .Tensor :
44- out = torch . nn . functional . linear (x , self . weight , self . bias )
49+ out = self . linear (x )
4550 lora_out = self .lora_a (self .dropout (x ))
4651 lora_out = (self .alpha / self .rank ) * self .lora_b (lora_out )
4752
0 commit comments