feat(retrieval): add ColBERT late-interaction engine (PyLate + GTE-ModernColBERT)#5

Open
chicham wants to merge 2 commits into main from feat/retrieval-colbert

@chicham (Collaborator) commented Jun 3, 2026

What

Adds a third retrieval engine — ColBERT late interaction — to retrieval/retrieval.py, alongside the existing BM25 and SPLADE engines. Same frozen-@beartype-dataclass contract (index()/search(), corpus-position → ids.json mapping, shared docid space), so its run.trec scores against the same qrels.
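As a rough illustration of that shared contract, the sketch below turns per-query (corpus position, score) hits into run.trec lines through a position-to-docid list. The ids.json layout (a JSON array indexed by corpus position), the helper name to_trec_lines, the sample docids, and the run tag are all assumptions made for this example; they are not taken from retrieval/retrieval.py.

```python
import json


def to_trec_lines(ranked, qids, ids, tag="colbert"):
    """Format hits as standard TREC run lines: qid Q0 docid rank score tag.

    ranked: one list of (corpus_position, score) pairs per query.
    ids:    list mapping corpus position -> docid (assumed ids.json layout).
    """
    lines = []
    for qid, hits in zip(qids, ranked):
        # Rank by descending score; ranks start at 1 in TREC runs.
        ordered = sorted(hits, key=lambda h: h[1], reverse=True)
        for rank, (pos, score) in enumerate(ordered, start=1):
            lines.append(f"{qid} Q0 {ids[pos]} {rank} {score:.6f} {tag}")
    return lines


# Hypothetical ids.json content: index in the array == corpus position.
ids = json.loads('["reportA_p1", "reportA_p2", "reportB_p1"]')
ranked = [[(0, 9.25), (2, 12.5)]]
for line in to_trec_lines(ranked, ["q1"], ids):
    print(line)
```

Because every engine emits positions into the same docid list, runs from BM25, SPLADE, and ColBERT can be evaluated against one qrels file without any per-engine remapping.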

Model

No finance- or KPI-specialized ColBERT exists on the HF Hub (searched finance/financial/10-K/SEC/KPI across models and the PyLate tag; the only domain fine-tune is a legal one). The default is therefore the strongest general-domain late-interaction model: lightonai/GTE-ModernColBERT-v1 — ModernBERT-based with an 8k context, so it indexes whole long OCR'd pages instead of truncating at the 512-token limit of a classic BERT ColBERT.

Implementation

  • ColbertEngine: a PyLate PLAID index + a shared models.ColBERT encoder. JVM-free (faiss/FastPlaid, C++), consistent with the existing engines' "no JVM" ethos.
  • search() reopens the on-disk index (override=False) — it never re-indexes, matching the bm25/splade load() pattern.
  • Registered in ENGINES + the Method literal; added pylate to the PEP-723 deps.
  • Every index/search parameter is a CLI flag under the colbert subgroup: --model --device --doc_length --query_length --nbits --kmeans_niters --batch_size --show_progress --index_name. The remaining arguments (is_query, documents_ids, override, k) are structural rather than tunable, so they are not exposed as flags.
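For readers new to late interaction, here is a minimal sketch of the MaxSim score that ColBERT-style retrieval is built on: each query token vector is matched to its most similar document token vector, and those maxima are summed. This is plain Python on toy vectors and is not how PyLate or PLAID compute it internally (PLAID adds centroid pruning and residual compression); the function names are made up for the example.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim(query_vecs, doc_vecs):
    """Late-interaction score: sum over query tokens of the best-matching
    document token similarity."""
    return sum(max(dot(q, d) for d in doc_vecs) for q in query_vecs)


# Toy example: 2 query tokens, 2 document tokens, 2 dimensions.
query = [[1.0, 0.0], [0.0, 1.0]]
doc = [[1.0, 0.0], [0.5, 0.5]]
print(maxsim(query, doc))  # 1.0 (first token) + 0.5 (second token) = 1.5
```

The score depends on every document token vector, which is why the index has to keep one vector per token rather than one per page.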

Verification

Smoke-tested end-to-end on a 2-report / 190-page sample:

  • index → 190 pages, FastPlaid backend, centroids trained
  • search reopens from disk in a fresh process
  • single query, multi-query --queries_file (order preserved), and --report deep-pool filter (correctly returns only that report's pages)
  • all flags exposed via --help under the subgroup

Notes

  • --doc_length defaults to 2048, well under GTE-ModernColBERT's 8192 ceiling, so there is no silent clamp. Because ColBERT stores one vector per token, it directly trades recall against index size and latency.
  • PLAID search-time knobs (ndocs, ncells, centroid_score_threshold) are left at library defaults — can be promoted to flags in a follow-up if wanted.
  • Pre-existing (not in this PR): repeated --query keeps only the last value; use --queries_file for multiple queries.
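To make the --doc_length trade-off concrete, the following back-of-the-envelope sketch estimates an upper bound on the compressed token storage for the 190-page sample. It assumes 128-dimensional token embeddings, nbits=4 residual compression, and a 4-byte centroid code per token; these values are illustrative assumptions (the PR does not state its defaults), and the estimate ignores the centroid table and index metadata.

```python
def estimate_plaid_bytes(n_docs, tokens_per_doc, dim=128, nbits=4, code_bytes=4):
    """Rough upper bound on per-token storage in a PLAID-style index.

    Each token is assumed to cost one centroid code (code_bytes) plus a
    residual quantized to `nbits` bits per dimension.
    """
    residual_bytes = dim * nbits // 8
    return n_docs * tokens_per_doc * (residual_bytes + code_bytes)


# Upper bound: pages shorter than doc_length produce fewer token vectors.
for doc_length in (512, 2048, 8192):
    size = estimate_plaid_bytes(190, doc_length)
    print(f"doc_length={doc_length}: <= {size / 1e6:.1f} MB")
```

Under these assumptions storage grows linearly with the token cap, so raising --doc_length from 2048 to 8192 can cost up to four times the space on pages long enough to fill it.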

🤖 Generated with Claude Code

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces ColBERT-based late-interaction retrieval using the pylate library, adding a new ColbertEngine alongside the existing BM25 and SPLADE engines. The feedback identifies a critical runtime issue where models.ColBERT is initialized with unsupported arguments (document_length and query_length), which will cause a TypeError. The reviewer suggests removing these arguments from the constructor and instead passing max_length directly to the .encode() method during both indexing and searching.

Comment thread retrieval/retrieval.py
Comment on lines +515 to +524
def search(self, index_dir: Path, texts: list[str], k: Positive) -> Ranked:
    from pylate import retrieve

    embeddings = self._encoder().encode(
        texts, batch_size=self.batch_size, is_query=True,
        show_progress_bar=self.show_progress)
    index = self._index(index_dir, override=False)
    ranked = retrieve.ColBERT(index=index).retrieve(queries_embeddings=embeddings, k=k)
    return [[(int(h["id"]), float(h["score"])) for h in ranked[q]]
            for q in range(len(texts))]

high

Pass max_length=self.query_length to the .encode() method to enforce the query token limit during retrieval.

Suggested change
  def search(self, index_dir: Path, texts: list[str], k: Positive) -> Ranked:
      from pylate import retrieve
      embeddings = self._encoder().encode(
          texts, batch_size=self.batch_size, is_query=True,
-         show_progress_bar=self.show_progress)
+         show_progress_bar=self.show_progress, max_length=self.query_length)
      index = self._index(index_dir, override=False)
      ranked = retrieve.ColBERT(index=index).retrieve(queries_embeddings=embeddings, k=k)
      return [[(int(h["id"]), float(h["score"])) for h in ranked[q]]
              for q in range(len(texts))]

@chicham chicham requested a review from CharlesMoslonka June 3, 2026 13:56