-
Notifications
You must be signed in to change notification settings - Fork 55
Expand file tree
/
Copy pathbase.py
More file actions
709 lines (631 loc) · 27.7 KB
/
base.py
File metadata and controls
709 lines (631 loc) · 27.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
# Copyright 2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import typing as tp
from collections.abc import Callable
from copy import deepcopy
from pathlib import Path
from tempfile import NamedTemporaryFile
import numpy as np
import torch
import typing_extensions as tpe
from pydantic import BeforeValidator, PlainSerializer
from pytorch_lightning import Trainer
from rectools import ExternalIds
from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig
from rectools.types import InternalIdsArray
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict
from ..item_net import (
CatFeaturesItemNet,
IdEmbeddingsItemNet,
ItemNetBase,
ItemNetConstructorBase,
SumOfEmbeddingsConstructor,
)
from .context_net import CatFeaturesContextNet, ContextNetBase
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
from .lightning import TransformerLightningModule, TransformerLightningModuleBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from .net_blocks import (
LearnableInversePositionalEncoding,
PositionalEncodingBase,
PreLNTransformerLayers,
TransformerLayersBase,
)
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone
# #### -------------- Transformer Model Config -------------- #### #
def _get_class_obj(spec: tp.Any) -> tp.Any:
if not isinstance(spec, str):
return spec
return import_object(spec)
def _get_class_obj_sequence(spec: tp.Sequence[tp.Any]) -> tp.Tuple[tp.Any, ...]:
    """Resolve every element of ``spec`` with ``_get_class_obj`` and collect the results into a tuple."""
    return tuple(_get_class_obj(element) for element in spec)
def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
    """Convert each type in ``obj`` with ``get_class_or_function_full_path`` and return a tuple of strings."""
    return tuple(get_class_or_function_full_path(type_obj) for type_obj in obj)
# Shared pydantic metadata for config fields that hold a class or a callable.
# Validation: a string value is resolved with ``import_object`` (via ``_get_class_obj``),
# any other value is kept as is.
# Serialization (JSON mode only): the object is replaced by ``get_class_or_function_full_path(obj)``.
_CLASS_OBJ_VALIDATOR = BeforeValidator(_get_class_obj)
_CLASS_OBJ_SERIALIZER = PlainSerializer(
    func=get_class_or_function_full_path,
    return_type=str,
    when_used="json",
)

PositionalEncodingType = tpe.Annotated[
    tp.Type[PositionalEncodingBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

TransformerLayersType = tpe.Annotated[
    tp.Type[TransformerLayersBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

TransformerLightningModuleType = tpe.Annotated[
    tp.Type[TransformerLightningModuleBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

SimilarityModuleType = tpe.Annotated[
    tp.Type[SimilarityModuleBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

TransformerBackboneType = tpe.Annotated[
    tp.Type[TransformerBackboneBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

ContextNetType = tpe.Annotated[
    tp.Type[ContextNetBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

TransformerDataPreparatorType = tpe.Annotated[
    tp.Type[TransformerDataPreparatorBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

TransformerNegativeSamplerType = tpe.Annotated[
    tp.Type[TransformerNegativeSamplerBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

ItemNetConstructorType = tpe.Annotated[
    tp.Type[ItemNetConstructorBase], _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER
]

# A sequence of item net block classes: every element is resolved / serialized individually.
# ``_serialize_type_sequence`` returns a tuple of strings, so ``return_type`` must describe a
# tuple; declaring ``str`` here made pydantic warn about an unexpected value type on every
# JSON dump and advertise a wrong serialization schema for this field.
ItemNetBlockTypes = tpe.Annotated[
    tp.Sequence[tp.Type[ItemNetBase]],
    BeforeValidator(_get_class_obj_sequence),
    PlainSerializer(
        func=_serialize_type_sequence,
        return_type=tp.Tuple[str, ...],
        when_used="json",
    ),
]

ValMaskCallable = Callable[..., np.ndarray]

ValMaskCallableSerialized = tpe.Annotated[ValMaskCallable, _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER]

TrainerCallable = Callable[..., Trainer]

TrainerCallableSerialized = tpe.Annotated[TrainerCallable, _CLASS_OBJ_VALIDATOR, _CLASS_OBJ_SERIALIZER]
class TransformerModelConfig(ModelConfig):
    """
    Transformer model base config.

    Field names match the ``TransformerModelBase.__init__`` arguments: the model is built from a
    config with ``cls(**config.model_dump())`` and a config is built back from attributes of the
    same names. ``*_type`` fields hold component classes (a string value is resolved with
    ``import_object``), and the matching ``*_kwargs`` fields hold extra keyword arguments
    forwarded to those components (``None`` means no extra arguments).
    """

    # Class of the data preparator (required, no default).
    data_preparator_type: TransformerDataPreparatorType
    # Architecture sizes forwarded to the transformer layers (``n_blocks``, ``n_heads``, ``n_factors``),
    # the backbone (``n_heads``) and the item / context nets and positional encoding (``n_factors``).
    n_blocks: int = 2
    n_heads: int = 4
    n_factors: int = 256
    # Passed to the positional encoding constructor.
    use_pos_emb: bool = True
    # Attention flags forwarded to the backbone.
    use_causal_attn: bool = False
    use_key_padding_mask: bool = False
    # Forwarded to item net, context net, transformer layers and backbone.
    dropout_rate: float = 0.2
    # Passed to the data preparator and the positional encoding.
    session_max_len: int = 100
    # Data loading options for the data preparator.
    dataloader_num_workers: int = 0
    batch_size: int = 128
    # Loss name; ``lightning_module_type.requires_negatives(loss)`` decides whether negatives are used.
    loss: str = "softmax"
    # Number of negatives; only forwarded to the data preparator / sampler when the loss requires them.
    n_negatives: int = 1
    # Forwarded to the lightning module; presumably the gBCE loss parameter — TODO confirm.
    gbce_t: float = 0.2
    # Learning rate forwarded to the lightning module.
    lr: float = 0.001
    # Used as both ``min_epochs`` and ``max_epochs`` of the default trainer.
    epochs: int = 3
    # Values > 0 enable progress bar, model summary and logger in the default trainer.
    verbose: int = 0
    # Forwarded to the default trainer.
    deterministic: bool = False
    # Options used at recommendation time.
    recommend_batch_size: int = 256
    recommend_torch_device: tp.Optional[str] = None
    # Forwarded to the data preparator.
    train_min_user_interactions: int = 2
    # Component classes used to build the model.
    item_net_block_types: ItemNetBlockTypes = (IdEmbeddingsItemNet, CatFeaturesItemNet)
    item_net_constructor_type: ItemNetConstructorType = SumOfEmbeddingsConstructor
    pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding
    transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
    lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
    negative_sampler_type: TransformerNegativeSamplerType = CatalogUniformSampler
    similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
    backbone_type: TransformerBackboneType = TransformerTorchBackbone
    context_net_type: ContextNetType = CatFeaturesContextNet
    # Optional validation-mask function forwarded to the data preparator.
    get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
    # Optional factory that replaces the default ``Trainer`` construction.
    get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
    # Extra keyword arguments for the callables and components above.
    get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None
    get_trainer_func_kwargs: tp.Optional[InitKwargs] = None
    data_preparator_kwargs: tp.Optional[InitKwargs] = None
    transformer_layers_kwargs: tp.Optional[InitKwargs] = None
    item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
    pos_encoding_kwargs: tp.Optional[InitKwargs] = None
    lightning_module_kwargs: tp.Optional[InitKwargs] = None
    negative_sampler_kwargs: tp.Optional[InitKwargs] = None
    similarity_module_kwargs: tp.Optional[InitKwargs] = None
    backbone_kwargs: tp.Optional[InitKwargs] = None
    context_net_kwargs: tp.Optional[InitKwargs] = None
# Type variable for the config class of a concrete transformer model; bound to TransformerModelConfig.
TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
# #### -------------- Transformer Model Base -------------- #### #
class TransformerModelBase(ModelBase[TransformerModelConfig_T]):  # pylint: disable=too-many-instance-attributes
    """
    Base model for all recommender algorithms that work on transformer architecture (e.g. SASRec, Bert4Rec).

    To create a custom transformer model, inherit from this class, set ``config_class`` and provide
    a concrete ``data_preparator_type``; the data preparator instance is built from it in
    ``_init_data_preparator``. Every network component (item net, context net, positional encoding,
    transformer layers, similarity module, backbone and lightning module) is created from the
    corresponding ``*_type`` class, with the matching ``*_kwargs`` forwarded to its constructor.
    """

    config_class: tp.Type[TransformerModelConfig_T]
    train_loss_name: str = "train_loss"
    val_loss_name: str = "val_loss"

    def __init__(  # pylint: disable=too-many-arguments, too-many-locals
        self,
        data_preparator_type: tp.Type[TransformerDataPreparatorBase],
        transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
        n_blocks: int = 2,
        n_heads: int = 4,
        n_factors: int = 256,
        use_pos_emb: bool = True,
        use_causal_attn: bool = False,
        use_key_padding_mask: bool = False,
        dropout_rate: float = 0.2,
        session_max_len: int = 100,
        dataloader_num_workers: int = 0,
        batch_size: int = 128,
        loss: str = "softmax",
        n_negatives: int = 1,
        gbce_t: float = 0.2,
        lr: float = 0.001,
        epochs: int = 3,
        verbose: int = 0,
        deterministic: bool = False,
        recommend_batch_size: int = 256,
        recommend_torch_device: tp.Optional[str] = None,
        train_min_user_interactions: int = 2,
        item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
        item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
        pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
        lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
        negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler,
        similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
        backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
        context_net_type: tp.Type[ContextNetBase] = CatFeaturesContextNet,
        get_val_mask_func: tp.Optional[ValMaskCallable] = None,
        get_trainer_func: tp.Optional[TrainerCallable] = None,
        get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
        get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
        data_preparator_kwargs: tp.Optional[InitKwargs] = None,
        transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
        item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
        pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
        lightning_module_kwargs: tp.Optional[InitKwargs] = None,
        negative_sampler_kwargs: tp.Optional[InitKwargs] = None,
        similarity_module_kwargs: tp.Optional[InitKwargs] = None,
        backbone_kwargs: tp.Optional[InitKwargs] = None,
        context_net_kwargs: tp.Optional[InitKwargs] = None,
        **kwargs: tp.Any,
    ) -> None:
        """
        Store hyper-parameters, then build the data preparator and the template trainer.

        Every argument is saved as an attribute with the same name (``verbose`` is passed to
        ``ModelBase.__init__``). ``*_type`` arguments are the classes used to build the
        corresponding components; ``*_kwargs`` are extra keyword arguments forwarded to their
        constructors (``None`` means no extras). Additional ``**kwargs`` are accepted and ignored.
        """
        super().__init__(verbose=verbose)
        self.transformer_layers_type = transformer_layers_type
        self.data_preparator_type = data_preparator_type
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.n_factors = n_factors
        self.use_pos_emb = use_pos_emb
        self.use_causal_attn = use_causal_attn
        self.use_key_padding_mask = use_key_padding_mask
        self.dropout_rate = dropout_rate
        self.session_max_len = session_max_len
        self.dataloader_num_workers = dataloader_num_workers
        self.batch_size = batch_size
        self.loss = loss
        self.n_negatives = n_negatives
        self.gbce_t = gbce_t
        self.lr = lr
        self.epochs = epochs
        self.deterministic = deterministic
        self.recommend_batch_size = recommend_batch_size
        self.recommend_torch_device = recommend_torch_device
        self.train_min_user_interactions = train_min_user_interactions
        self.similarity_module_type = similarity_module_type
        self.item_net_block_types = item_net_block_types
        self.item_net_constructor_type = item_net_constructor_type
        self.pos_encoding_type = pos_encoding_type
        self.lightning_module_type = lightning_module_type
        self.negative_sampler_type = negative_sampler_type
        self.backbone_type = backbone_type
        self.context_net_type = context_net_type
        self.get_val_mask_func = get_val_mask_func
        self.get_trainer_func = get_trainer_func
        self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
        self.get_trainer_func_kwargs = get_trainer_func_kwargs
        self.data_preparator_kwargs = data_preparator_kwargs
        self.transformer_layers_kwargs = transformer_layers_kwargs
        self.item_net_constructor_kwargs = item_net_constructor_kwargs
        self.pos_encoding_kwargs = pos_encoding_kwargs
        self.lightning_module_kwargs = lightning_module_kwargs
        self.negative_sampler_kwargs = negative_sampler_kwargs
        self.similarity_module_kwargs = similarity_module_kwargs
        self.backbone_kwargs = backbone_kwargs
        self.context_net_kwargs = context_net_kwargs

        self._init_data_preparator()
        self._init_trainer()

        # Type declarations only (no assignment): ``lightning_model`` is assigned in
        # ``_init_lightning_model``; ``data_preparator`` was assigned by ``_init_data_preparator``.
        self.lightning_model: TransformerLightningModuleBase
        self.data_preparator: TransformerDataPreparatorBase
        # Copy of ``self._trainer`` used for fitting; stays ``None`` until fit / fit_partial runs.
        self.fit_trainer: tp.Optional[Trainer] = None

    @staticmethod
    def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
        """Return ``actual_kwargs``, or an empty dict when it is ``None``."""
        if actual_kwargs is None:
            return {}
        return actual_kwargs

    def _init_data_preparator(self) -> None:
        """Build ``self.data_preparator``; a negative sampler is created only if the loss requires negatives."""
        requires_negatives = self.lightning_module_type.requires_negatives(self.loss)
        self.data_preparator = self.data_preparator_type(
            session_max_len=self.session_max_len,
            batch_size=self.batch_size,
            dataloader_num_workers=self.dataloader_num_workers,
            train_min_user_interactions=self.train_min_user_interactions,
            negative_sampler=self._init_negative_sampler() if requires_negatives else None,
            n_negatives=self.n_negatives if requires_negatives else None,
            get_val_mask_func=self.get_val_mask_func,
            get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
            **self._get_kwargs(self.data_preparator_kwargs),
        )

    def _init_trainer(self) -> None:
        """Build the template trainer ``self._trainer``, either the default one or via ``get_trainer_func``."""
        if self.get_trainer_func is None:
            self._trainer = Trainer(
                max_epochs=self.epochs,
                min_epochs=self.epochs,
                deterministic=self.deterministic,
                enable_progress_bar=self.verbose > 0,
                enable_model_summary=self.verbose > 0,
                logger=self.verbose > 0,
                enable_checkpointing=False,
                devices=1,
            )
        else:
            self._trainer = self.get_trainer_func(**self._get_kwargs(self.get_trainer_func_kwargs))

    def _init_negative_sampler(self) -> TransformerNegativeSamplerBase:
        """Create the negative sampler from ``negative_sampler_type``."""
        return self.negative_sampler_type(
            n_negatives=self.n_negatives,
            **self._get_kwargs(self.negative_sampler_kwargs),
        )

    def _construct_item_net(self, dataset: Dataset) -> ItemNetBase:
        """Build the item net from a dataset using ``item_net_constructor_type.from_dataset``."""
        return self.item_net_constructor_type.from_dataset(
            dataset,
            self.n_factors,
            self.dropout_rate,
            self.item_net_block_types,
            **self._get_kwargs(self.item_net_constructor_kwargs),
        )

    def _construct_context_net(self, dataset_schema: DatasetSchema) -> tp.Optional[ContextNetBase]:
        """Build the context net from a dataset schema; return ``None`` when the schema has no interactions."""
        if dataset_schema.interactions is None:
            return None
        return self.context_net_type.from_dataset_schema(
            dataset_schema,
            self.n_factors,
            self.dropout_rate,
            **self._get_kwargs(self.context_net_kwargs),
        )

    def _construct_item_net_from_dataset_schema(self, dataset_schema: DatasetSchema) -> ItemNetBase:
        """Build the item net from a dataset schema (used when restoring from a checkpoint)."""
        return self.item_net_constructor_type.from_dataset_schema(
            dataset_schema,
            self.n_factors,
            self.dropout_rate,
            self.item_net_block_types,
            **self._get_kwargs(self.item_net_constructor_kwargs),
        )

    def _init_pos_encoding_layer(self) -> PositionalEncodingBase:
        """Create the positional encoding layer."""
        return self.pos_encoding_type(
            self.use_pos_emb,
            self.session_max_len,
            self.n_factors,
            **self._get_kwargs(self.pos_encoding_kwargs),
        )

    def _init_transformer_layers(self) -> TransformerLayersBase:
        """Create the transformer layers."""
        return self.transformer_layers_type(
            n_blocks=self.n_blocks,
            n_factors=self.n_factors,
            n_heads=self.n_heads,
            dropout_rate=self.dropout_rate,
            **self._get_kwargs(self.transformer_layers_kwargs),
        )

    def _init_similarity_module(self) -> SimilarityModuleBase:
        """Create the similarity module."""
        return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs))

    def _init_torch_model(
        self, item_model: ItemNetBase, context_net: tp.Optional[ContextNetBase]
    ) -> TransformerBackboneBase:
        """Assemble the torch backbone from the item net, context net and freshly created sub-modules."""
        pos_encoding_layer = self._init_pos_encoding_layer()
        transformer_layers = self._init_transformer_layers()
        similarity_module = self._init_similarity_module()
        return self.backbone_type(
            n_heads=self.n_heads,
            dropout_rate=self.dropout_rate,
            item_model=item_model,
            context_net=context_net,
            pos_encoding_layer=pos_encoding_layer,
            transformer_layers=transformer_layers,
            similarity_module=similarity_module,
            use_causal_attn=self.use_causal_attn,
            use_key_padding_mask=self.use_key_padding_mask,
            **self._get_kwargs(self.backbone_kwargs),
        )

    def _init_lightning_model(
        self,
        torch_model: TransformerBackboneBase,
        dataset_schema: DatasetSchemaDict,
        item_external_ids: ExternalIds,
        model_config: tp.Dict[str, tp.Any],
    ) -> None:
        """
        Wrap ``torch_model`` into the lightning module and store it in ``self.lightning_model``.

        ``adam_betas`` defaults to ``(0.9, 0.98)`` but can be overridden through
        ``lightning_module_kwargs`` (previously that raised a duplicate-keyword ``TypeError``).
        """
        module_kwargs: tp.Dict[str, tp.Any] = {"adam_betas": (0.9, 0.98)}
        module_kwargs.update(self._get_kwargs(self.lightning_module_kwargs))
        self.lightning_model = self.lightning_module_type(
            torch_model=torch_model,
            dataset_schema=dataset_schema,
            item_external_ids=item_external_ids,
            item_extra_tokens=self.data_preparator.item_extra_tokens,
            data_preparator=self.data_preparator,
            model_config=model_config,
            lr=self.lr,
            loss=self.loss,
            gbce_t=self.gbce_t,
            verbose=self.verbose,
            train_loss_name=self.train_loss_name,
            val_loss_name=self.val_loss_name,
            **module_kwargs,
        )

    def _build_model_from_dataset(self, dataset: Dataset) -> None:
        """Prepare training data from ``dataset`` and build the torch and lightning models on top of it."""
        self.data_preparator.process_dataset_train(dataset)
        train_dataset = self.data_preparator.train_dataset
        # Computed once and reused both for the context net and for the lightning module.
        dataset_schema = train_dataset.get_schema()
        item_model = self._construct_item_net(train_dataset)
        context_net = self._construct_context_net(DatasetSchema.model_validate(dataset_schema))
        torch_model = self._init_torch_model(item_model, context_net)
        item_external_ids = train_dataset.item_id_map.external_ids
        model_config = self.get_config(simple_types=True)
        self._init_lightning_model(
            torch_model=torch_model,
            dataset_schema=dataset_schema,
            item_external_ids=item_external_ids,
            model_config=model_config,
        )

    def _fit(
        self,
        dataset: Dataset,
    ) -> None:
        """Build the model from ``dataset`` and train it with a fresh copy of the template trainer."""
        self._build_model_from_dataset(dataset)
        train_dataloader = self.data_preparator.get_dataloader_train()
        val_dataloader = self.data_preparator.get_dataloader_val()

        self.fit_trainer = deepcopy(self._trainer)
        self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)

    def _custom_transform_dataset_u2i(
        self, dataset: Dataset, users: ExternalIds, on_unsupported_targets: ErrorBehaviour
    ) -> Dataset:
        """Delegate u2i dataset transformation to the data preparator (``on_unsupported_targets`` is unused)."""
        return self.data_preparator.transform_dataset_u2i(dataset, users)

    def _custom_transform_dataset_i2i(
        self, dataset: Dataset, target_items: ExternalIds, on_unsupported_targets: ErrorBehaviour
    ) -> Dataset:
        """Delegate i2i dataset transformation to the data preparator (targets and error mode are unused)."""
        return self.data_preparator.transform_dataset_i2i(dataset)

    def _fit_partial(
        self,
        dataset: Dataset,
        min_epochs: int,
        max_epochs: int,
    ) -> None:
        """
        Continue training for between ``min_epochs`` and ``max_epochs`` more epochs.

        The model is built on the first call. If the model is fitted but has no ``fit_trainer``
        (e.g. restored from a checkpoint), ``dataset`` is processed and a new trainer is created.
        Otherwise ``dataset`` is not processed again and training continues on the data prepared
        by an earlier call.
        """
        if not self.is_fitted:
            self._build_model_from_dataset(dataset)
            self.fit_trainer = deepcopy(self._trainer)
        elif self.fit_trainer is None:
            self.data_preparator.process_dataset_train(dataset)
            self.fit_trainer = deepcopy(self._trainer)

        train_dataloader = self.data_preparator.get_dataloader_train()
        val_dataloader = self.data_preparator.get_dataloader_val()

        self.lightning_model.train()
        # Epoch limits are relative to the epochs already run by this trainer.
        self.fit_trainer.fit_loop.max_epochs = self.fit_trainer.current_epoch + max_epochs
        self.fit_trainer.fit_loop.min_epochs = self.fit_trainer.current_epoch + min_epochs
        self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)

    def _recommend_u2i(
        self,
        user_ids: InternalIdsArray,
        dataset: Dataset,  # [n_rec_users x n_items + n_item_extra_tokens]
        k: int,
        filter_viewed: bool,
        sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],  # model_internal
    ) -> InternalRecoTriplet:
        """Recommend top-``k`` items for users; defaults to all known items when no whitelist is given."""
        if sorted_item_ids_to_recommend is None:
            sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids()  # model internal

        recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset, self.recommend_batch_size)
        return self.lightning_model._recommend_u2i(  # pylint: disable=protected-access
            user_ids=user_ids,
            recommend_dataloader=recommend_dataloader,
            sorted_item_ids_to_recommend=sorted_item_ids_to_recommend,
            k=k,
            filter_viewed=filter_viewed,
            dataset=dataset,
            torch_device=self.recommend_torch_device,
        )

    def _recommend_i2i(
        self,
        target_ids: InternalIdsArray,  # model internal
        dataset: Dataset,
        k: int,
        sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],
    ) -> InternalRecoTriplet:
        """Recommend top-``k`` similar items; defaults to all known items when no whitelist is given."""
        if sorted_item_ids_to_recommend is None:
            sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids()

        return self.lightning_model._recommend_i2i(  # pylint: disable=protected-access
            target_ids=target_ids,
            sorted_item_ids_to_recommend=sorted_item_ids_to_recommend,
            k=k,
            torch_device=self.recommend_torch_device,
        )

    @property
    def torch_model(self) -> TransformerBackboneBase:
        """Pytorch model."""
        return self.lightning_model.torch_model

    @classmethod
    def _from_config(cls, config: TransformerModelConfig_T) -> tpe.Self:
        """Instantiate the model from a config: every field except ``cls`` becomes an init argument."""
        params = config.model_dump()
        params.pop("cls")
        return cls(**params)

    def _get_config(self) -> TransformerModelConfig_T:
        """Build a config from attributes named after the config's (serialization-schema) fields."""
        attrs = self.config_class.model_json_schema(mode="serialization")["properties"].keys()
        params = {attr: getattr(self, attr) for attr in attrs if attr != "cls"}
        params["cls"] = self.__class__
        return self.config_class(**params)

    @classmethod
    def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
        """Create model from loaded Lightning checkpoint."""
        model_config = checkpoint["hyper_parameters"]["model_config"]
        loaded = cls.from_config(model_config)
        loaded.is_fitted = True
        dataset_schema = checkpoint["hyper_parameters"]["dataset_schema"]
        dataset_schema = DatasetSchema.model_validate(dataset_schema)

        # Update data preparator
        item_external_ids = checkpoint["hyper_parameters"]["item_external_ids"]
        loaded.data_preparator.item_id_map = IdMap(item_external_ids)
        loaded.data_preparator._init_extra_token_ids()  # pylint: disable=protected-access

        # Init and update torch model and lightning model
        item_model = loaded._construct_item_net_from_dataset_schema(dataset_schema)
        context_net = loaded._construct_context_net(dataset_schema)
        torch_model = loaded._init_torch_model(item_model, context_net)
        loaded._init_lightning_model(
            torch_model=torch_model,
            dataset_schema=dataset_schema,
            item_external_ids=item_external_ids,
            model_config=model_config,
        )
        loaded.lightning_model.is_fitted = True
        loaded.lightning_model.load_state_dict(checkpoint["state_dict"])

        return loaded

    def __getstate__(self) -> object:
        """
        Pickle support.

        A fitted model is stored as the raw bytes of a Lightning checkpoint written by
        ``fit_trainer``; an unfitted model is stored as its simple-types config.

        Raises
        ------
        RuntimeError
            If the model is fitted but has no ``fit_trainer``.
        """
        if self.is_fitted:
            if self.fit_trainer is None:
                explanation = """
                Model is fitted but has no `fit_trainer`. Most likely it was just loaded from the
                checkpoint. Model that was loaded from checkpoint cannot be saved without being
                fitted again.
                """
                raise RuntimeError(explanation)
            # The temporary file is closed before the trainer writes to it by name: on Windows an
            # open NamedTemporaryFile cannot be opened a second time. It is removed in ``finally``.
            with NamedTemporaryFile(delete=False) as tmp_file:
                tmp_path = Path(tmp_file.name)
            try:
                self.fit_trainer.save_checkpoint(str(tmp_path))
                checkpoint = tmp_path.read_bytes()
            finally:
                tmp_path.unlink(missing_ok=True)
            state: tp.Dict[str, tp.Any] = {"fitted_checkpoint": checkpoint}
            return state
        state = {"model_config": self.get_config(simple_types=True)}
        return state

    def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
        """Restore the model from a state produced by ``__getstate__``."""
        if "fitted_checkpoint" in state:
            # NOTE: weights_only=False performs full unpickling — only load trusted data
            # (the same trust requirement as unpickling the model itself).
            checkpoint = torch.load(io.BytesIO(state["fitted_checkpoint"]), weights_only=False)
            loaded = self._model_from_checkpoint(checkpoint)
        else:
            loaded = self.from_config(state["model_config"])

        self.__dict__.update(loaded.__dict__)

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: tp.Union[str, Path],
        map_location: tp.Optional[tp.Union[str, torch.device]] = None,
        model_params_update: tp.Optional[tp.Dict[str, tp.Any]] = None,
    ) -> tpe.Self:
        """Load model from Lightning checkpoint path.

        Parameters
        ----------
        checkpoint_path: Union[str, Path]
            Path to checkpoint location.
        map_location: Union[str, torch.device], optional
            Target device to load the checkpoint (e.g., 'cpu', 'cuda:0').
            If None, will use the device the checkpoint was saved on.
        model_params_update: Dict[str, tp.Any], optional
            Contains custom values for checkpoint['hyper_parameters']['model_config'].
            Has to be flattened with 'dot' reducer, before passed.
            You can use this argument to remove training-specific parameters that are not needed anymore.
            e.g. 'get_trainer_func'

        Returns
        -------
        Model instance.
        """
        # NOTE: weights_only=False performs full unpickling — only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
        if model_params_update:
            prev_model_config = checkpoint["hyper_parameters"]["model_config"]
            prev_config_flatten = make_dict_flat(prev_model_config)
            prev_config_flatten.update(model_params_update)
            checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
        loaded = cls._model_from_checkpoint(checkpoint)
        return loaded

    def load_weights_from_checkpoint(
        self,
        checkpoint_path: tp.Union[str, Path],
        map_location: tp.Optional[tp.Union[str, torch.device]] = None,
    ) -> None:
        """
        Load model weights from Lightning checkpoint path.

        Parameters
        ----------
        checkpoint_path: Union[str, Path]
            Path to checkpoint location.
        map_location: Union[str, torch.device], optional
            Target device to load the checkpoint (e.g., 'cpu', 'cuda:0').
            If None, will use the device the checkpoint was saved on.

        Raises
        ------
        RuntimeError
            If the model has neither been fitted nor restored from a checkpoint.
        """
        # A model restored from a checkpoint / pickle is fitted but has no ``fit_trainer``;
        # its lightning model exists, so weights can be loaded into it as well.
        if not self.is_fitted and self.fit_trainer is None:
            raise RuntimeError("Model weights cannot be loaded from checkpoint into unfitted model")
        # NOTE: weights_only=False performs full unpickling — only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
        self.lightning_model.load_state_dict(checkpoint["state_dict"])