2323from model2vec .train .base import FinetunableStaticModel , TextDataset
2424
2525logger = logging .getLogger (__name__ )
26- _RANDOM_SEED = 42
26+ _DEFAULT_RANDOM_SEED = 42
2727
2828LabelType = TypeVar ("LabelType" , list [str ], list [list [str ]])
2929
@@ -158,6 +158,7 @@ def fit( # noqa: C901 # Complexity is bad.
158158 X_val : list [str ] | None = None ,
159159 y_val : LabelType | None = None ,
160160 class_weight : torch .Tensor | None = None ,
161+ random_seed : int = _DEFAULT_RANDOM_SEED ,
161162 ) -> StaticModelForClassification :
162163 """
163164 Fit a model.
@@ -187,14 +188,14 @@ def fit( # noqa: C901 # Complexity is bad.
187188 :param y_val: The labels to be used for validation.
188189 :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
189190 have the same length as the number of classes.
191+ :param random_seed: The random seed to use. Defaults to 42.
190192 :return: The fitted model.
191193 :raises ValueError: If either X_val or y_val are provided, but not both.
192194 """
193- pl .seed_everything (_RANDOM_SEED )
195+ pl .seed_everything (random_seed )
194196 logger .info ("Re-initializing model." )
195197
196198 # Determine whether the task is multilabel based on the type of y.
197-
198199 self ._initialize (y )
199200
200201 if (X_val is not None ) != (y_val is not None ):
@@ -380,14 +381,13 @@ def to_pipeline(self) -> StaticModelPipeline:
380381 """Convert the model to an sklearn pipeline."""
381382 static_model = self .to_static_model ()
382383
383- random_state = np .random .RandomState (_RANDOM_SEED )
384+ random_state = np .random .RandomState (_DEFAULT_RANDOM_SEED )
384385 n_items = len (self .classes )
385386 X = random_state .randn (n_items , static_model .dim )
386387 y = self .classes
387388
388- converted = make_pipeline (MLPClassifier (hidden_layer_sizes = (self .hidden_dim ,) * self .n_layers ))
389- converted .fit (X , y )
390- mlp_head : MLPClassifier = converted [- 1 ]
389+ mlp_head = MLPClassifier (hidden_layer_sizes = (self .hidden_dim ,) * self .n_layers )
390+ mlp_head .fit (X , y )
391391
392392 for index , layer in enumerate ([module for module in self .head if isinstance (module , nn .Linear )]):
393393 mlp_head .coefs_ [index ] = layer .weight .detach ().cpu ().numpy ().T
@@ -401,7 +401,8 @@ def to_pipeline(self) -> StaticModelPipeline:
401401 # Set to softmax or sigmoid
402402 mlp_head .out_activation_ = "logistic" if self .multilabel else "softmax"
403403
404- return StaticModelPipeline (static_model , converted )
404+ pipeline = make_pipeline (mlp_head )
405+ return StaticModelPipeline (static_model , pipeline )
405406
406407
407408class _ClassifierLightningModule (pl .LightningModule ):
0 commit comments