@@ -92,18 +92,22 @@ def test_distill_removal_pattern_all_tokens(
9292 mock_model_info .return_value = type ("ModelInfo" , (object ,), {"cardData" : {"language" : "en" }})
9393 mock_auto_model .return_value = mock_transformer
9494
95- with pytest .raises (ValueError ):
96- distill_from_model (
97- model = mock_transformer ,
98- tokenizer = mock_berttokenizer ,
99- vocabulary = None ,
100- device = "cpu" ,
101- token_remove_pattern = r".*" ,
102- )
95+ # Even if we remove all tokens, we can't remove the [UNK] token
96+ model = distill_from_model (
97+ model = mock_transformer ,
98+ tokenizer = mock_berttokenizer ,
99+ vocabulary = None ,
100+ device = "cpu" ,
101+ token_remove_pattern = r".*" ,
102+ )
103+
104+ # So the only token left is the [UNK] token.
105+ assert model .tokens == ("[UNK]" ,)
103106
104107
105108@patch .object (import_module ("model2vec.distill.distillation" ), "model_info" )
106109@patch ("transformers.AutoModel.from_pretrained" )
110+ @pytest .mark .parametrize ("mock_transformer" , [{"vocab_size" : 35022 }], indirect = True )
107111def test_distill_removal_pattern (
108112 mock_auto_model : MagicMock ,
109113 mock_model_info : MagicMock ,
@@ -114,7 +118,8 @@ def test_distill_removal_pattern(
114118 mock_model_info .return_value = type ("ModelInfo" , (object ,), {"cardData" : {"language" : "en" }})
115119 mock_auto_model .return_value = mock_transformer
116120
117- expected_vocab_size = mock_berttokenizer .vocab_size
121+ # Because the added [MASK], [CLS] and [SEP] get removed
122+ expected_vocab_size = mock_berttokenizer .vocab_size - 3
118123
119124 static_model = distill_from_model (
120125 model = mock_transformer ,
@@ -159,18 +164,19 @@ def test_distill_removal_pattern(
159164@pytest .mark .parametrize (
160165 "vocabulary, pca_dims, sif_coefficient, expected_shape" ,
161166 [
162- (None , 256 , None , (30522 , 256 )), # PCA applied, SIF off
163- (None , "auto" , None , (30522 , 768 )), # PCA 'auto', SIF off
164- (None , "auto" , 1e-4 , (30522 , 768 )), # PCA 'auto', SIF on
167+ (None , 256 , None , (30519 , 256 )), # PCA applied, SIF off
168+ (None , "auto" , None , (30519 , 768 )), # PCA 'auto', SIF off
169+ (None , "auto" , 1e-4 , (30519 , 768 )), # PCA 'auto', SIF on
165170 (None , "auto" , 0 , None ), # invalid SIF (too low) -> raises
166171 (None , "auto" , 1 , None ), # invalid SIF (too high) -> raises
167- (None , 1024 , None , (30522 , 768 )), # PCA set high (no reduction)
168- (["wordA" , "wordB" ], 4 , None , (30524 , 4 )), # Custom vocab, PCA applied
169- (None , None , None , (30522 , 768 )), # No PCA, SIF off
172+ (None , 1024 , None , (30519 , 768 )), # PCA set high (no reduction)
173+ (["wordA" , "wordB" ], 4 , None , (30521 , 4 )), # Custom vocab, PCA applied
174+ (None , None , None , (30519 , 768 )), # No PCA, SIF off
170175 ],
171176)
172177@patch .object (import_module ("model2vec.distill.distillation" ), "model_info" )
173178@patch ("transformers.AutoModel.from_pretrained" )
179+ @pytest .mark .parametrize ("mock_transformer" , [{"vocab_size" : 30522 }], indirect = True )
174180def test_distill (
175181 mock_auto_model : MagicMock ,
176182 mock_model_info : MagicMock ,
0 commit comments