|
5 | 5 | import webbrowser |
6 | 6 | from datetime import datetime |
7 | 7 | from pathlib import Path |
8 | | -from typing import Any, Iterable, Literal, Optional, Sequence, Union |
| 8 | +from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union |
9 | 9 |
|
10 | 10 | import numpy as np |
11 | 11 | from rich.console import Console |
|
30 | 30 | npmi, |
31 | 31 | soft_ctf_idf, |
32 | 32 | ) |
| 33 | +from turftopic.late import LateSentenceTransformer, LateWrapper |
33 | 34 | from turftopic.models._hierarchical_clusters import ( |
34 | 35 | VALID_LINKAGE_METHODS, |
35 | 36 | ClusterNode, |
|
43 | 44 | ) |
44 | 45 | from turftopic.types import VALID_DISTANCE_METRICS, DistanceMetric |
45 | 46 | from turftopic.utils import safe_binarize |
| 47 | +from turftopic.vectorizers import PhraseVectorizer |
46 | 48 | from turftopic.vectorizers.default import default_vectorizer |
47 | 49 |
|
48 | 50 | integer_message = """ |
@@ -865,3 +867,96 @@ def __init__( |
865 | 867 | reduction_distance_metric=reduction_distance_metric, |
866 | 868 | reduction_topic_representation=reduction_topic_representation, |
867 | 869 | ) |
| 870 | + |
| 871 | + |
| 872 | +class CTop2Vec(LateWrapper): |
| 873 | + """Convenience function to construct a CTop2Vec model in Turftopic. |
| 874 | + The model is essentially the same as ClusteringTopicModel in a Late Wrapper |
| 875 | + with defaults that resemble CTop2Vec. This includes: |
| 876 | +
|
| 877 | + 1. A late interaction embedding model, with windowed aggregation |
| 878 | + 2. UMAP reduction |
| 879 | + 3. HDBSCAN clustering |
| 880 | + 4. Centroid term importance |
| 881 | + 5. Phrase vectorizer |
| 882 | +
|
| 883 | + ```bash |
| 884 | + pip install turftopic[umap-learn] |
| 885 | + ``` |
| 886 | +
|
| 887 | + ```python |
| 888 | + from turftopic import CTop2Vec |
| 889 | +
|
| 890 | + corpus: list[str] = ["some text", "more text", ...] |
| 891 | +
|
| 892 | + model = CTop2Vec().fit(corpus) |
| 893 | + model.print_topics() |
| 894 | + ``` |
| 895 | + """ |
| 896 | + |
| 897 | + def __init__( |
| 898 | + self, |
| 899 | + encoder: Union[ |
| 900 | + Encoder, str, MultimodalEncoder |
| 901 | + ] = "sentence-transformers/all-MiniLM-L6-v2", |
| 902 | + vectorizer: Optional[CountVectorizer] = None, |
| 903 | + dimensionality_reduction: Optional[TransformerMixin] = None, |
| 904 | + clustering: Optional[ClusterMixin] = None, |
| 905 | + feature_importance: WordImportance = "centroid", |
| 906 | + n_reduce_to: Optional[int] = None, |
| 907 | + reduction_method: LinkageMethod = "smallest", |
| 908 | + reduction_distance_metric: DistanceMetric = "cosine", |
| 909 | + reduction_topic_representation: TopicRepresentation = "centroid", |
| 910 | + window_size: Optional[int] = 50, |
| 911 | + step_size: Optional[int] = 40, |
| 912 | + pooling: Optional[Callable] = np.mean, |
| 913 | + random_state: Optional[int] = None, |
| 914 | + ): |
| 915 | + if dimensionality_reduction is None: |
| 916 | + try: |
| 917 | + from umap import UMAP |
| 918 | + except ModuleNotFoundError as e: |
| 919 | + raise ModuleNotFoundError( |
| 920 | + "UMAP is not installed in your environment, but Top2Vec requires it." |
| 921 | + ) from e |
| 922 | + dimensionality_reduction = UMAP( |
| 923 | + n_neighbors=15, |
| 924 | + n_components=5, |
| 925 | + min_dist=0.0, |
| 926 | + metric="cosine", |
| 927 | + random_state=random_state, |
| 928 | + ) |
| 929 | + if clustering is None: |
| 930 | + clustering = HDBSCAN( |
| 931 | + min_cluster_size=15, |
| 932 | + metric="euclidean", |
| 933 | + cluster_selection_method="eom", |
| 934 | + ) |
| 935 | + self.encoder = encoder |
| 936 | + self.vectorizer = vectorizer |
| 937 | + self.dimensionality_reduction = dimensionality_reduction |
| 938 | + self.clustering = clustering |
| 939 | + self.feature_importance = feature_importance |
| 940 | + self.n_reduce_to = n_reduce_to |
| 941 | + self.reduction_method = reduction_method |
| 942 | + self.reduction_distance_metric = reduction_distance_metric |
| 943 | + self.reduction_topic_representation = reduction_topic_representation |
| 944 | + self.random_state = random_state |
| 945 | + self.model = ClusteringTopicModel( |
| 946 | + encoder=encoder, |
| 947 | + vectorizer=vectorizer, |
| 948 | + dimensionality_reduction=dimensionality_reduction, |
| 949 | + clustering=clustering, |
| 950 | + n_reduce_to=n_reduce_to, |
| 951 | + random_state=random_state, |
| 952 | + feature_importance=feature_importance, |
| 953 | + reduction_method=reduction_method, |
| 954 | + reduction_distance_metric=reduction_distance_metric, |
| 955 | + reduction_topic_representation=reduction_topic_representation, |
| 956 | + ) |
| 957 | + super().__init__( |
| 958 | + self.model, |
| 959 | + window_size=self.window_size, |
| 960 | + step_size=self.step_size, |
| 961 | + pooling=self.pooling, |
| 962 | + ) |
0 commit comments