push_and_pull

Distance-Based Classification (DBC) for multilabel text classification using sentence embeddings.

Overview

This repository implements a distance-based classification pipeline for text data using sentence-transformers embeddings and similarity-based label scoring (a minimal sketch of the idea follows the feature list below). It supports:

  • Cosine and Euclidean similarity computation
  • Threshold-based multilabel classification
  • Per-label and global threshold optimization
  • Keyword-enhanced label representations
  • Prototype centroid generation for label vectors
  • Data sampling and evaluation logging
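
As an illustration of the core idea only (the full dbc/dbc.py pipeline adds sampling, normalization, threshold search, and reporting), a minimal distance-based classifier can be sketched as follows; the model name, labels, texts, and threshold below are placeholders:

# Minimal sketch of distance-based classification with sentence embeddings
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("BAAI/bge-base-en-v1.5")        # any model listed in dbc/models.txt
label_set = ["label1", "label2", "label3"]                  # placeholder label names
texts = ["Example text..."]                                 # placeholder documents

label_emb = model.encode(label_set, convert_to_tensor=True, normalize_embeddings=True)
text_emb = model.encode(texts, convert_to_tensor=True, normalize_embeddings=True)

sims = util.cos_sim(text_emb, label_emb)                    # (n_texts, n_labels) cosine similarities
threshold = 0.5                                             # fixed global threshold (--threshold)
predictions = [[label_set[j] for j in range(len(label_set)) if sims[i, j] >= threshold]
               for i in range(len(texts))]
print(predictions)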

Requirements

Setup with uv (recommended)

CUDA 12.8:

uv venv && uv sync --extra cu128

CUDA 12.1:

uv venv && uv sync --extra cu121

Setup with pip

Install PyTorch first for your CUDA version, then the rest:

# CUDA 12.8
pip install "torch>=2.8.0" --index-url https://download.pytorch.org/whl/cu128
# CUDA 12.1
pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121

pip install -r requirements.txt
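
To verify that the intended CUDA build of PyTorch was installed, an optional sanity check:

python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"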

Repository structure

  • requirements.txt - Python dependencies for running the project.

DBC - Distance-Based Classification

This directory contains all scripts necessary to run the distance-based classification pipeline.

  • dbc.py - Main script and command-line entrypoint for running the DBC pipeline.
  • data_utils.py - Data loading, conversion, sampling, and dataset utilities.
  • model_utils.py - Sentence transformer loading and encoding utilities.
  • similarity_utils.py - Similarity computation and normalization logic.
  • threshold_utils.py - Threshold search and optimization utilities.
  • evaluation_utils.py - Evaluation metrics and reporting utilities.
  • losses.py - Loss functions (if applicable for training/experiments).
  • upper_bound/ - Additional code and utilities related to upper bound experiments.
  • models.txt - Default model list file used by dbc.py.

Usage:

Run the main script (dbc.py) from the repository root:

python dbc/dbc.py \
  --data_path path/to/data.json \
  --data_split test \                    # train; validation; test
  --report_path results \
  --model_names_file dbc/models.txt \
  --batch_size 8 \
  --threshold 0.5

To use calibrated label-specific thresholds, first run the following command:

python dbc/dbc.py \
  --data_path path/to/data.json \
  --data_split validation \             # train; validation; test
  --report_path results \
  --model_names_file dbc/models.txt \
  --batch_size 8 \
  --iterate_over_thresholds_per_label   # or: --iterate_over_thresholds
                                        # or: --threshold_per_label --tpl_strategy median

This creates a file with the optimal thresholds for each label in the results/ directory. Then, run the following command to load the label-specific thresholds:

python dbc/dbc.py \
  --data_path path/to/data.json \
  --data_split validation \             # train; validation; test
  --report_path results \
  --model_names_file dbc/models.txt \
  --batch_size 8 \
  --load_thresholds_per_label \         # or: --load_thresholds
  --threshold_file_path path/to/file_with_thresholds

Upper-bound - Fine-tuning Small Language Models (SLMs) for Multi-Label Classification

This directory contains all scripts for fine-tuning BERT-like models for multi-label classification.

  • main.py - Main script for fine-tuning.
  • model.py - Contains the model class.
  • test.py - Contains the inference pipeline.
  • trainer.py - Contains the training loop.
  • utils.py - Contains various utility functions.

Usage:

Quick test (1 epoch):

python upper_bound/main.py \
  --data_path data/reuters21578.json \
  --epochs 1 \
  --seeds 0

BERT (default):

python upper_bound/main.py \
  --data_path path/to/data.json \
  --seeds 0

RoBERTa:

python upper_bound/main.py \
  --data_path path/to/data.json \
  --model_name roberta-base \
  --epochs 1 \
  --seeds 0

Fine-tuning sentence transformers

python dbc/finetune_st.py \
  --data_path path/to/data.json \
  --data_split validation \              # train; validation; test
  --model_name BAAI/bge-base-en-v1.5 \   # avsolatorio/GIST-large-Embedding-v0
  --cl standard \                        # standard; pairwise
  --n_epochs 1 \
  --output_dir outputs/models/
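
For orientation only, a contrastive fine-tuning run with the sentence-transformers library might look roughly like the sketch below. The actual losses behind --cl standard and --cl pairwise are defined in dbc/finetune_st.py; pairing each text with its gold labels as positives is an assumption made here for illustration.

# Illustrative contrastive fine-tuning sketch (not the finetune_st.py internals)
import json
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("BAAI/bge-base-en-v1.5")
split = json.load(open("path/to/data.json"))["data"]["validation"]

# assumption: (text, gold label) pairs act as positives, with in-batch negatives
examples = [InputExample(texts=[item["text"], label])
            for item in split for label in item["labels"]]
loader = DataLoader(examples, shuffle=True, batch_size=8)
loss = losses.MultipleNegativesRankingLoss(model)

model.fit(train_objectives=[(loader, loss)], epochs=1, output_path="outputs/models/")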

Data format

Input data is expected as a JSON file with the following structure:

{
  "meta": {
    "name": "dataset-name",
    "label_set": ["label1", "label2", "..."]
  },
  "data": {
    "train": [
      {
        "id": "1",
        "text": "Example text...",
        "labels": ["label1", "label2"]
      }
    ],
    "validation": [ ... ],
    "test": [ ... ]
  }
}

The --data_split argument selects which split to use (train, validation, test).
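
If your corpus is not already in this shape, a short conversion script along these lines can produce it; the records below are placeholders and the field names follow the structure shown above:

# Build a dataset file in the expected JSON layout
import json

records = [{"id": "1", "text": "Example text...", "labels": ["label1", "label2"]}]  # placeholder records
label_set = sorted({label for rec in records for label in rec["labels"]})

dataset = {
    "meta": {"name": "dataset-name", "label_set": label_set},
    "data": {"train": records, "validation": [], "test": []},
}

with open("path/to/data.json", "w") as f:
    json.dump(dataset, f, indent=2)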

Options for dbc/dbc.py

  • --data_path - path to the input JSON dataset.
  • --data_split - dataset split name (train, validation, test).
  • --model_path - model name or local model path.
  • --model_names_file - file containing model paths, one per line. Example dbc/models.txt:
    BAAI/bge-base-en-v1.5
    avsolatorio/GIST-large-Embedding-v0
    
  • --report_path - output directory for results and reports.
  • --batch_size - batch size for encoding.
  • --similarity_metric - cosine or euclidean.
  • --normalize_sims - normalize similarity scores using min-max normalization.
  • --threshold - float as a fixed similarity threshold for classification.
  • --threshold_per_label - compute per-label thresholds using unsupervised statistics (e.g., mean or median; see the sketch after this list).
  • --iterate_over_thresholds - search for the best global threshold.
  • --iterate_over_thresholds_per_label - search for the best label-specific thresholds.
  • --do_bayesian_optimization - use Bayesian optimization for threshold search.
  • --average_keywords - average label name embeddings with keyword embeddings.
  • --log_file - file path to save logs.
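
To make the per-label threshold options concrete, the underlying idea can be sketched as below. The exact calibration logic behind --threshold_per_label, --tpl_strategy, and --iterate_over_thresholds_per_label lives in dbc/threshold_utils.py, so treat this purely as an illustration:

# Illustrative per-label thresholds from a similarity matrix (not the threshold_utils.py implementation)
import numpy as np

sims = np.random.rand(100, 3)                    # placeholder (n_texts, n_labels) similarity scores

# one unsupervised threshold per label, here the median of that label's scores (--tpl_strategy median)
thresholds_per_label = np.median(sims, axis=0)

# a label is assigned whenever its score clears its own threshold (multilabel decision)
predictions = sims >= thresholds_per_label       # boolean (n_texts, n_labels) matrix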

Options for upper_bound/main.py

Weights & Biases settings

  • --project_name - W&B project name (default: DBC-upperbound-baseline).
  • --group_name - W&B group name for grouping runs (default: None).

Data settings

  • --data_path - path to the JSON dataset file.
  • --text_col - name of the text column in the dataset (default: text).
  • --label_col - name of the label column in the dataset (default: labels).
  • --sample_data - sample a smaller training subset using stratified label sampling.
  • --training_data_sample - number of training examples to keep, or 0 to use the full set (default: 0).
  • --sample_seed - random seed for data sampling (default: 42).

Model and training settings

  • --model_name - pretrained model name or path (default: bert-base-cased).
  • --train_batch_size - batch size for training and evaluation (default: 8).
  • --seeds - list of random seeds for repeated experiments (default: [0, 1, 2, 3, 4]).
  • --learning_rate - initial learning rate for AdamW optimizer (default: 5e-5).
  • --epochs - number of training epochs (default: 10).
  • --max_length - maximum token length for tokenizer inputs (default: 512).
  • --dropout - dropout probability for the classifier head (default: 0.1).
  • --accumulate_grad - number of gradient accumulation steps (default: 1).
  • --remove_tanh - disable the tanh activation after the dense layer (see the classifier-head sketch after this list).
  • --early_stopping_patience - early stopping patience on the validation metric (default: 10).
  • --best_metric_name - evaluation metric name used for model checkpointing (default: eval/loss).
  • --max_grad_norm - gradient clipping norm (default: 1.0).
  • --threshold - classification threshold for sigmoid outputs (default: 0.5).
  • --device - device used for model training and evaluation (default: cuda).
  • --output_dir - base directory for training outputs and checkpoints (default: outputs/).
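
For reference, a classifier head of the kind these options describe (dense layer, optional tanh, dropout, sigmoid outputs thresholded at 0.5) can be sketched as follows; the actual architecture lives in upper_bound/model.py, and the label count here is a placeholder:

# Illustrative multi-label head over a BERT-like encoder (not the exact upper_bound/model.py class)
import torch
from torch import nn
from transformers import AutoModel

class MultiLabelClassifier(nn.Module):
    def __init__(self, model_name="bert-base-cased", n_labels=10, dropout=0.1, remove_tanh=False):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.activation = nn.Identity() if remove_tanh else nn.Tanh()   # --remove_tanh
        self.dropout = nn.Dropout(dropout)                              # --dropout
        self.classifier = nn.Linear(hidden, n_labels)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.last_hidden_state[:, 0]                            # [CLS] representation
        # train with nn.BCEWithLogitsLoss; at inference, sigmoid + --threshold (default 0.5)
        return self.classifier(self.dropout(self.activation(self.dense(pooled))))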

Keyword-assisted label encoding

Provide a JSON keywords file and enable keyword averaging:

python dbc/dbc.py \
  --data_path path/to/data.json \
  --keywords_path path/to/keywords.json \
  --average_keywords
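
The keywords file schema is defined by the loading code in dbc/; one plausible layout, shown here only as an assumption, maps each label name from the dataset's label_set to a list of keywords:

{
  "label1": ["keyword1", "keyword2"],
  "label2": ["keyword3"]
}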

Prototype centroids

Use top similar texts to refine label centroids:

python dbc/dbc.py \
  --use_prototype_centroids \
  --n_sentences_for_centroid 5
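
Conceptually, prototype centroids replace each raw label-name embedding with the mean embedding of the texts most similar to that label. The rough sketch below illustrates the idea with placeholder arrays; the actual selection logic is implemented in the DBC scripts:

# Illustrative prototype-centroid refinement (not the exact dbc.py implementation)
import numpy as np

rng = np.random.default_rng(0)
text_emb = rng.normal(size=(100, 768))           # placeholder text embeddings (n_texts, d)
label_emb = rng.normal(size=(3, 768))            # placeholder label-name embeddings (n_labels, d)

n_sentences_for_centroid = 5                     # --n_sentences_for_centroid
sims = text_emb @ label_emb.T                    # (n_texts, n_labels) similarity scores
centroids = np.empty_like(label_emb)
for j in range(label_emb.shape[0]):
    top_idx = np.argsort(-sims[:, j])[:n_sentences_for_centroid]   # texts most similar to label j
    centroids[j] = text_emb[top_idx].mean(axis=0)                  # their centroid becomes the label vector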

Output

Results and reports are stored under the --report_path directory in a dataset-specific subfolder.

Notes

  • dbc/models.txt should list one model path or name per line.
  • If you use Hugging Face models that require authentication, ensure use_auth_token=True or provide credentials.
  • The script automatically selects cuda when available, otherwise it falls back to cpu.
