Skip to content

Commit b137dc6

Browse files
authored
feat: expose random seed in training
* feat: expose random seed, small typing fix * small reshuffle for typing: * seed -> random_seed
1 parent 8f46692 commit b137dc6

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

model2vec/train/classifier.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from model2vec.train.base import FinetunableStaticModel, TextDataset
2424

2525
logger = logging.getLogger(__name__)
26-
_RANDOM_SEED = 42
26+
_DEFAULT_RANDOM_SEED = 42
2727

2828
LabelType = TypeVar("LabelType", list[str], list[list[str]])
2929

@@ -158,6 +158,7 @@ def fit( # noqa: C901 # Complexity is bad.
158158
X_val: list[str] | None = None,
159159
y_val: LabelType | None = None,
160160
class_weight: torch.Tensor | None = None,
161+
random_seed: int = _DEFAULT_RANDOM_SEED,
161162
) -> StaticModelForClassification:
162163
"""
163164
Fit a model.
@@ -187,14 +188,14 @@ def fit( # noqa: C901 # Complexity is bad.
187188
:param y_val: The labels to be used for validation.
188189
:param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
189190
have the same length as the number of classes.
191+
:param random_seed: The random seed to use. Defaults to 42.
190192
:return: The fitted model.
191193
:raises ValueError: If either X_val or y_val are provided, but not both.
192194
"""
193-
pl.seed_everything(_RANDOM_SEED)
195+
pl.seed_everything(random_seed)
194196
logger.info("Re-initializing model.")
195197

196198
# Determine whether the task is multilabel based on the type of y.
197-
198199
self._initialize(y)
199200

200201
if (X_val is not None) != (y_val is not None):
@@ -380,14 +381,13 @@ def to_pipeline(self) -> StaticModelPipeline:
380381
"""Convert the model to an sklearn pipeline."""
381382
static_model = self.to_static_model()
382383

383-
random_state = np.random.RandomState(_RANDOM_SEED)
384+
random_state = np.random.RandomState(_DEFAULT_RANDOM_SEED)
384385
n_items = len(self.classes)
385386
X = random_state.randn(n_items, static_model.dim)
386387
y = self.classes
387388

388-
converted = make_pipeline(MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers))
389-
converted.fit(X, y)
390-
mlp_head: MLPClassifier = converted[-1]
389+
mlp_head = MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers)
390+
mlp_head.fit(X, y)
391391

392392
for index, layer in enumerate([module for module in self.head if isinstance(module, nn.Linear)]):
393393
mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T
@@ -401,7 +401,8 @@ def to_pipeline(self) -> StaticModelPipeline:
401401
# Set to softmax or sigmoid
402402
mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax"
403403

404-
return StaticModelPipeline(static_model, converted)
404+
pipeline = make_pipeline(mlp_head)
405+
return StaticModelPipeline(static_model, pipeline)
405406

406407

407408
class _ClassifierLightningModule(pl.LightningModule):

0 commit comments

Comments
 (0)