Skip to content

Commit 2146f68

Browse files
Added C-Top2Vec
1 parent aaa0e17 commit 2146f68

2 files changed

Lines changed: 103 additions & 2 deletions

File tree

turftopic/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
from turftopic._datamapplot import build_datamapplot
33
from turftopic.base import ContextualModel
44
from turftopic.error import NotInstalled
5-
from turftopic.models.cluster import BERTopic, ClusteringTopicModel, Top2Vec
5+
from turftopic.models.cluster import (
6+
BERTopic,
7+
ClusteringTopicModel,
8+
CTop2Vec,
9+
Top2Vec,
10+
)
611
from turftopic.models.cvp import ConceptVectorProjection
712
from turftopic.models.decomp import S3, SemanticSignalSeparation
813
from turftopic.models.fastopic import FASTopic
@@ -29,6 +34,7 @@
2934
"ContextualModel",
3035
"FASTopic",
3136
"Top2Vec",
37+
"CTop2Vec",
3238
"BERTopic",
3339
"load_model",
3440
"build_datamapplot",

turftopic/models/cluster.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import webbrowser
66
from datetime import datetime
77
from pathlib import Path
8-
from typing import Any, Iterable, Literal, Optional, Sequence, Union
8+
from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union
99

1010
import numpy as np
1111
from rich.console import Console
@@ -30,6 +30,7 @@
3030
npmi,
3131
soft_ctf_idf,
3232
)
33+
from turftopic.late import LateSentenceTransformer, LateWrapper
3334
from turftopic.models._hierarchical_clusters import (
3435
VALID_LINKAGE_METHODS,
3536
ClusterNode,
@@ -43,6 +44,7 @@
4344
)
4445
from turftopic.types import VALID_DISTANCE_METRICS, DistanceMetric
4546
from turftopic.utils import safe_binarize
47+
from turftopic.vectorizers import PhraseVectorizer
4648
from turftopic.vectorizers.default import default_vectorizer
4749

4850
integer_message = """
@@ -865,3 +867,96 @@ def __init__(
865867
reduction_distance_metric=reduction_distance_metric,
866868
reduction_topic_representation=reduction_topic_representation,
867869
)
870+
871+
872+
class CTop2Vec(LateWrapper):
873+
"""Convenience function to construct a CTop2Vec model in Turftopic.
874+
The model is essentially the same as ClusteringTopicModel in a Late Wrapper
875+
with defaults that resemble CTop2Vec. This includes:
876+
877+
1. A late interaction embedding model, with windowed aggregation
878+
2. UMAP reduction
879+
3. HDBSCAN clustering
880+
4. Centroid term importance
881+
5. Phrase vectorizer
882+
883+
```bash
884+
pip install turftopic[umap-learn]
885+
```
886+
887+
```python
888+
from turftopic import CTop2Vec
889+
890+
corpus: list[str] = ["some text", "more text", ...]
891+
892+
model = CTop2Vec().fit(corpus)
893+
model.print_topics()
894+
```
895+
"""
896+
897+
def __init__(
898+
self,
899+
encoder: Union[
900+
Encoder, str, MultimodalEncoder
901+
] = "sentence-transformers/all-MiniLM-L6-v2",
902+
vectorizer: Optional[CountVectorizer] = None,
903+
dimensionality_reduction: Optional[TransformerMixin] = None,
904+
clustering: Optional[ClusterMixin] = None,
905+
feature_importance: WordImportance = "centroid",
906+
n_reduce_to: Optional[int] = None,
907+
reduction_method: LinkageMethod = "smallest",
908+
reduction_distance_metric: DistanceMetric = "cosine",
909+
reduction_topic_representation: TopicRepresentation = "centroid",
910+
window_size: Optional[int] = 50,
911+
step_size: Optional[int] = 40,
912+
pooling: Optional[Callable] = np.mean,
913+
random_state: Optional[int] = None,
914+
):
915+
if dimensionality_reduction is None:
916+
try:
917+
from umap import UMAP
918+
except ModuleNotFoundError as e:
919+
raise ModuleNotFoundError(
920+
"UMAP is not installed in your environment, but Top2Vec requires it."
921+
) from e
922+
dimensionality_reduction = UMAP(
923+
n_neighbors=15,
924+
n_components=5,
925+
min_dist=0.0,
926+
metric="cosine",
927+
random_state=random_state,
928+
)
929+
if clustering is None:
930+
clustering = HDBSCAN(
931+
min_cluster_size=15,
932+
metric="euclidean",
933+
cluster_selection_method="eom",
934+
)
935+
self.encoder = encoder
936+
self.vectorizer = vectorizer
937+
self.dimensionality_reduction = dimensionality_reduction
938+
self.clustering = clustering
939+
self.feature_importance = feature_importance
940+
self.n_reduce_to = n_reduce_to
941+
self.reduction_method = reduction_method
942+
self.reduction_distance_metric = reduction_distance_metric
943+
self.reduction_topic_representation = reduction_topic_representation
944+
self.random_state = random_state
945+
self.model = ClusteringTopicModel(
946+
encoder=encoder,
947+
vectorizer=vectorizer,
948+
dimensionality_reduction=dimensionality_reduction,
949+
clustering=clustering,
950+
n_reduce_to=n_reduce_to,
951+
random_state=random_state,
952+
feature_importance=feature_importance,
953+
reduction_method=reduction_method,
954+
reduction_distance_metric=reduction_distance_metric,
955+
reduction_topic_representation=reduction_topic_representation,
956+
)
957+
super().__init__(
958+
self.model,
959+
window_size=self.window_size,
960+
step_size=self.step_size,
961+
pooling=self.pooling,
962+
)

0 commit comments

Comments
 (0)