Distance-Based Classification (DBC) for multilabel text classification using sentence embeddings.
This repository implements a distance-based classification pipeline for text data using sentence-transformers embeddings and similarity-based label scoring. It supports:
- Cosine and Euclidean similarity computation
- Threshold-based multilabel classification
- Per-label and global threshold optimization
- Keyword-enhanced label representations
- Prototype centroid generation for label vectors
- Data sampling and evaluation logging
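The core idea can be sketched in a few lines. This is a minimal illustration with dummy 2-d embeddings; the actual pipeline encodes texts and labels with sentence-transformers and uses the utilities under dbc/:

```python
import numpy as np

def cosine_scores(text_embs: np.ndarray, label_embs: np.ndarray) -> np.ndarray:
    """Cosine similarity between every text (rows) and every label vector (rows)."""
    t = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    l = label_embs / np.linalg.norm(label_embs, axis=1, keepdims=True)
    return t @ l.T  # shape: (n_texts, n_labels)

def predict_multilabel(scores: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Assign every label whose similarity clears the threshold (multilabel)."""
    return (scores >= threshold).astype(int)

# Toy example: two texts, two labels, 2-d embeddings.
texts = np.array([[1.0, 0.0], [0.6, 0.8]])
labels = np.array([[1.0, 0.0], [0.0, 1.0]])
preds = predict_multilabel(cosine_scores(texts, labels), threshold=0.5)  # [[1, 0], [1, 1]]
```

The second text clears the threshold for both labels, which is exactly the multilabel behaviour: label sets per text, not a single argmax.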
Installation with uv:

CUDA 12.8:
uv venv && uv sync --extra cu128

CUDA 12.1:
uv venv && uv sync --extra cu121

Installation with pip (install PyTorch first for your CUDA version, then the rest):

# CUDA 12.8
pip install "torch>=2.8.0" --index-url https://download.pytorch.org/whl/cu128
# CUDA 12.1
pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt

requirements.txt - Python dependencies for running the project.
This directory contains all scripts necessary to run the distance-based classification pipeline.
- dbc.py - Main script and command-line entrypoint for running the DBC pipeline.
- data_utils.py - Data loading, conversion, sampling, and dataset utilities.
- model_utils.py - Sentence transformer loading and encoding utilities.
- similarity_utils.py - Similarity computation and normalization logic.
- threshold_utils.py - Threshold search and optimization utilities.
- evaluation_utils.py - Evaluation metrics and reporting utilities.
- losses.py - Loss functions (if applicable for training/experiments).
- upper_bound/ - Additional code and utilities related to upper bound experiments.
- models.txt - Default model list file used by dbc.py.
Usage:
Run the main script (dbc.py) from the repository root:
# --data_split options: train, validation, test
python dbc/dbc.py \
  --data_path path/to/data.json \
  --data_split test \
  --report_path results \
  --model_names_file dbc/models.txt \
  --batch_size 8 \
  --threshold 0.5
To use calibrated label-specific thresholds, first run the following command:
# --data_split options: train, validation, test
python dbc/dbc.py \
  --data_path path/to/data.json \
  --data_split validation \
  --report_path results \
  --model_names_file dbc/models.txt \
  --batch_size 8 \
  --iterate_over_thresholds_per_label
# Alternatives: --iterate_over_thresholds (global threshold)
#               or: --threshold_per_label --tpl_strategy median
This creates a file with the optimal thresholds for each label in the results/ directory.
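The per-label search can be pictured as a grid search that maximizes label-wise F1 on the chosen split. This is a hedged sketch; the actual grid and objective live in threshold_utils.py and may differ:

```python
import numpy as np

def best_thresholds_per_label(scores: np.ndarray, y_true: np.ndarray,
                              grid=None) -> np.ndarray:
    """For each label column, pick the grid threshold with the best F1."""
    if grid is None:
        grid = np.linspace(0.0, 1.0, 101)
    best = np.zeros(scores.shape[1])
    for j in range(scores.shape[1]):
        best_f1 = -1.0
        for t in grid:
            pred = scores[:, j] >= t
            tp = np.sum(pred & (y_true[:, j] == 1))
            fp = np.sum(pred & (y_true[:, j] == 0))
            fn = np.sum(~pred & (y_true[:, j] == 1))
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom else 0.0
            if f1 > best_f1:
                best_f1, best[j] = f1, t
    return best
```

Calibrating on validation scores and reusing the resulting vector at test time is what the load step below corresponds to.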
Then, run the following command to load the label-specific thresholds:
python dbc/dbc.py \
  --data_path path/to/data.json \
  --data_split validation \
  --report_path results \
  --model_names_file dbc/models.txt \
  --batch_size 8 \
  --load_thresholds_per_label \
  --threshold_file_path path/to/file_with_thresholds
# Alternative: --load_thresholds (global threshold)

The upper_bound/ directory contains all scripts for fine-tuning BERT-like models for multi-label classification.
- main.py - Main script for fine-tuning.
- model.py - Contains the model class.
- test.py - Contains the inference pipeline.
- trainer.py - Contains the training loop.
- utils.py - Contains various utility functions.
Usage:
Quick test (1 epoch):
python upper_bound/main.py \
  --data_path data/reuters21578.json \
  --epochs 1 \
  --seeds 0

BERT (default):
python upper_bound/main.py \
  --data_path path/to/data.json \
  --seeds 0

RoBERTa:
python upper_bound/main.py \
  --data_path path/to/data.json \
  --model_name roberta-base \
  --epochs 1 \
  --seeds 0

Fine-tuning a sentence transformer:
# --data_split options: train, validation, test
# --model_name alternative: avsolatorio/GIST-large-Embedding-v0
# --cl options: standard, pairwise
python dbc/finetune_st.py \
  --data_path path/to/data.json \
  --data_split validation \
  --model_name BAAI/bge-base-en-v1.5 \
  --cl standard \
  --n_epochs 1 \
  --output_dir outputs/models/

Input data is expected as a JSON file with the following structure:
{
"meta": {
"name": "dataset-name",
"label_set": ["label1", "label2", "..."]
},
"data": {
"train": [
{
"id": "1",
"text": "Example text...",
"labels": ["label1", "label2"]
}
],
"validation": [ ... ],
"test": [ ... ]
}
}

The --data_split argument selects which split to use (train, validation, test).
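Reading one split from that layout is straightforward. A hypothetical helper (field names follow the schema above; this is not code from the repository):

```python
import json

def load_split(path: str, split: str = "train"):
    """Return texts, gold label lists, and the label set for one split."""
    with open(path, encoding="utf-8") as f:
        ds = json.load(f)
    records = ds["data"][split]
    texts = [r["text"] for r in records]
    golds = [r["labels"] for r in records]
    return texts, golds, ds["meta"]["label_set"]
```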
- --data_path - path to the input JSON dataset.
- --data_split - dataset split name (train, validation, test).
- --model_path - model name or local model path.
- --model_names_file - file containing model paths, one per line. Example dbc/models.txt:
  BAAI/bge-base-en-v1.5
  avsolatorio/GIST-large-Embedding-v0
- --report_path - output directory for results and reports.
- --batch_size - batch size for encoding.
- --similarity_metric - cosine or euclidean.
- --normalize_sims - normalize similarity scores using min-max normalization.
- --threshold - float used as a fixed similarity threshold for classification.
- --threshold_per_label - compute per-label thresholds using unsupervised statistical methods (mean, median, ...).
- --iterate_over_thresholds - search for the best global threshold.
- --iterate_over_thresholds_per_label - search for the best label-specific thresholds.
- --do_bayesian_optimization - use Bayesian optimization for threshold search.
- --average_keywords - average label name embeddings with keyword embeddings.
- --log_file - file path to save logs.
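For example, --normalize_sims presumably rescales the similarity scores into [0, 1]. A minimal sketch of min-max normalization (whether the script normalizes globally or per label is not specified here; this sketch normalizes globally):

```python
import numpy as np

def minmax_normalize(scores: np.ndarray) -> np.ndarray:
    """Rescale similarity scores linearly into [0, 1]."""
    lo, hi = scores.min(), scores.max()
    if hi == lo:  # degenerate case: all scores identical
        return np.zeros_like(scores)
    return (scores - lo) / (hi - lo)
```

Normalization makes a single threshold like --threshold 0.5 comparable across metrics whose raw ranges differ (e.g. cosine vs. negated Euclidean distance).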
- --project_name - W&B project name (default: DBC-upperbound-baseline).
- --group_name - W&B group name for grouping runs (default: None).
- --data_path - path to the JSON dataset file.
- --text_col - name of the text column in the dataset (default: text).
- --label_col - name of the label column in the dataset (default: labels).
- --sample_data - sample a smaller training subset using stratified label sampling.
- --training_data_sample - number of training examples to keep, or 0 to use the full set (default: 0).
- --sample_seed - random seed for data sampling (default: 42).
- --model_name - pretrained model name or path (default: bert-base-cased).
- --train_batch_size - batch size for training and evaluation (default: 8).
- --seeds - list of random seeds for repeated experiments (default: [0, 1, 2, 3, 4]).
- --learning_rate - initial learning rate for the AdamW optimizer (default: 5e-5).
- --epochs - number of training epochs (default: 10).
- --max_length - maximum token length for tokenizer inputs (default: 512).
- --dropout - dropout probability for the classifier head (default: 0.1).
- --accumulate_grad - number of gradient accumulation steps (default: 1).
- --remove_tanh - disable the tanh activation after the dense layer.
- --early_stopping_patience - early stopping patience on the validation metric (default: 10).
- --best_metric_name - evaluation metric name used for model checkpointing (default: eval/loss).
- --max_grad_norm - gradient clipping norm (default: 1.0).
- --threshold - classification threshold for sigmoid outputs (default: 0.5).
- --device - device used for model training and evaluation (default: cuda).
- --output_dir - base directory for training outputs and checkpoints (default: outputs/).
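The --threshold option implies the usual multilabel decision rule on sigmoid outputs, sketched here with plain numpy (an illustration of the rule, not the repository's exact code):

```python
import numpy as np

def logits_to_labels(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Sigmoid each logit, then assign every label whose probability clears the threshold."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs >= threshold).astype(int)
```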
Provide a JSON keywords file and enable keyword averaging:
python dbc/dbc.py \
  --data_path path/to/data.json \
  --keywords_path path/to/keywords.json \
  --average_keywords

Use top similar texts to refine label centroids:
python dbc/dbc.py \
  --use_prototype_centroids \
  --n_sentences_for_centroid 5

Results and reports are stored under the --report_path directory in a dataset-specific subfolder.
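Both refinements amount to replacing a label's vector with an average. A hedged sketch of what --average_keywords and --use_prototype_centroids might compute (function names and the exact weighting are assumptions, not the repository's code):

```python
import numpy as np

def average_with_keywords(label_emb: np.ndarray, keyword_embs: np.ndarray) -> np.ndarray:
    """Mean of the label-name embedding and its keyword embeddings."""
    return np.vstack([label_emb[None, :], keyword_embs]).mean(axis=0)

def prototype_centroid(label_emb: np.ndarray, text_embs: np.ndarray, n: int = 5) -> np.ndarray:
    """Average the label vector with its n most cosine-similar text embeddings."""
    sims = (text_embs @ label_emb) / (
        np.linalg.norm(text_embs, axis=1) * np.linalg.norm(label_emb))
    top = np.argsort(sims)[::-1][:n]
    return np.vstack([label_emb[None, :], text_embs[top]]).mean(axis=0)
```

Either way, the refined vector then replaces the plain label-name embedding in the similarity computation.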
- dbc/models.txt should list one model path or name per line.
- If you use Hugging Face models that require authentication, ensure use_auth_token=True or provide credentials.
- The script automatically selects cuda when available, otherwise it falls back to cpu.