diff --git a/pyproject.toml b/pyproject.toml
index 3a0d1fc..069aadb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ profile = "black"
 
 [project]
 name = "turftopic"
-version = "0.25.0"
+version = "0.25.1"
 description = "Topic modeling with contextual representations from sentence transformers."
 authors = [
     { name = "Márton Kardos ", email = "martonkardos@cas.au.dk" }
diff --git a/turftopic/late.py b/turftopic/late.py
index cb140d6..40e13c5 100644
--- a/turftopic/late.py
+++ b/turftopic/late.py
@@ -63,6 +63,9 @@ def _encode_tokens(
     ):
         batch = texts[start_index : start_index + batch_size]
         features = self.tokenize(batch)
+        features = {
+            key: value.to(self.device) for key, value in features.items()
+        }
         with torch.no_grad():
             output_features = self.forward(features)
         n_tokens = output_features["attention_mask"].sum(axis=1)