Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/classifai/indexers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
VectorStoreSearchOutput,
)
from .main import VectorStore
from .types import metric_settings

__all__ = [
"VectorStore",
Expand All @@ -18,4 +19,5 @@
"VectorStoreReverseSearchOutput",
"VectorStoreSearchInput",
"VectorStoreSearchOutput",
"metric_settings",
]
132 changes: 98 additions & 34 deletions src/classifai/indexers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import shutil
import time
import uuid
from typing import get_args

import numpy as np
import polars as pl
from tqdm.autonotebook import tqdm

from ..vectorisers import VectoriserBase
Comment thread
rileyok-ons marked this conversation as resolved.
Outdated
from .dataclasses import (
VectorStoreEmbedInput,
VectorStoreEmbedOutput,
Expand All @@ -45,6 +47,7 @@
VectorStoreSearchInput,
VectorStoreSearchOutput,
)
from .types import metric_settings

# Configure logging for your application
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
Expand All @@ -53,13 +56,29 @@
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)


def metricvalid(metric: metric_settings):
    """Validate that the given scoring metric is one of the supported options.

    Args:
        metric (metric_settings): The selected scoring metric for the VectorStore.

    Raises:
        ValueError: If ``metric`` is not one of "cosine", "dotprod", "cosinel2",
            "dotprodl2", "cosinel2squared" or "dotprodl2squared" (i.e. not a
            member of the ``metric_settings`` Literal).
    """
    # get_args unpacks the Literal alias into the tuple of allowed string values.
    valid_metrics = get_args(metric_settings)
    if metric not in valid_metrics:
        raise ValueError(f"The scoring metric input '{metric}' is not in the valid metrics {valid_metrics}")


class VectorStore:
"""A class to model and create 'VectorStore' objects for building and searching vector databases from CSV text files.

Attributes:
file_name (str): the original file with the knowledgebase to build the vector store
data_type (str): the data type of the original file (curently only csv supported)
vectoriser (object): A Vectoriser object from the corresponding ClassifAI Pacakge module
vectoriser (VectoriserBase): A Vectoriser object from the corresponding ClassifAI Pacakge module
scoring_metric(metric_settings): The metric to use for scoring
batch_size (int): the batch size to pass to the vectoriser when embedding
meta_data (dict[str:type]): key-value pairs of metadata to extract from the input file and their correpsonding types
output_dir (str): the path to the output directory where the VectorStore will be saved
Expand All @@ -74,7 +93,8 @@ def __init__( # noqa: PLR0913
self,
file_name,
data_type,
vectoriser,
vectoriser: VectoriserBase,
scoring_metric: metric_settings = "cosine",
batch_size=8,
meta_data=None,
output_dir=None,
Expand All @@ -87,8 +107,9 @@ def __init__( # noqa: PLR0913
Args:
file_name (str): The name of the input CSV file.
data_type (str): The type of input data (currently supports only "csv").
vectoriser (object): The vectoriser object used to transform text into
vectoriser (VectoriserBase): The vectoriser object used to transform text into
vector embeddings.
scoring_metric(metric_settings): The metric to use for scoring
batch_size (int, optional): The batch size for processing the input file and batching to
vectoriser. Defaults to 8.
meta_data (dict, optional): key,value pair metadata column names to extract from the input file and their types.
Expand All @@ -107,6 +128,7 @@ def __init__( # noqa: PLR0913
self.file_name = file_name
self.data_type = data_type
self.vectoriser = vectoriser
self.scoring_metric = scoring_metric
self.batch_size = batch_size
self.meta_data = meta_data if meta_data is not None else {}
self.output_dir = output_dir
Expand All @@ -119,6 +141,9 @@ def __init__( # noqa: PLR0913
if self.data_type not in ["csv"]:
raise ValueError(f"Data type '{self.data_type}' not supported. Choose from ['csv'].")

## validate scoring metric
metricvalid(self.scoring_metric)

if self.output_dir is None:
logging.info("No output directory specified, attempting to use input file name as output folder name.")

Expand Down Expand Up @@ -146,7 +171,7 @@ def __init__( # noqa: PLR0913
os.makedirs(self.output_dir, exist_ok=True)

self._create_vector_store_index()

self._check_norm_vdb()
logging.info("Gathering metadata and saving vector store / metadata...")

self.vector_shape = self.vectors["embeddings"].to_numpy().shape[1]
Expand Down Expand Up @@ -347,7 +372,65 @@ def reverse_search(self, query: VectorStoreReverseSearchInput, n_results=100) ->

return result_df

def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> VectorStoreSearchOutput:
def _check_norm_vdb(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this functionality a lot, but I think it should be the vectoriser's job to output embeddings in the desired form, not the vector store changing them after the fact.
My preference would be to update the Vectorisers' .transform() methods to take an optional (default False) normalise argument, which applies this normalisation if set to True.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree slightly here.

An informed user, who knows their embedding model already outputs normalized embeddings, should then be able to just use the dotproduct metric, which would give them the effects of cosine similarity without having to do the extra norm checks and steps they would need if they set to a cosine metric.

also i think its a good idea to keep the vectorisers pure and not overcomplicate the logic argument logic - whereas the vectorstore responsible for housing, reloading and metric calculations of the vectors probably should be keeping a note on whether the vectors are normalised or not

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An informed user, who knows their embedding model already outputs normalized embeddings, should then be able to just use the dotproduct metric, which would give them the effects of cosine similarity without having to do the extra norm checks and steps they would need if they set to a cosine metric.

I'm not sure I follow what you mean; if a user knows their embedding model already outputs normalised embeddings, they could just not set the normalise flag when creating the Vectoriser.

also i think its a good idea to keep the vectorisers pure and not overcomplicate the logic argument logic

This is an operation that happens directly on the vectors, a step before any use in a vector store or scoring. I think it fits in well with the task of the Vectoriser, and avoids the other issues you discussed - such as any need to duplicate vectors in the vector store and set/read metadata flags about whether the vector store is normalised.

Lets talk about it in our call later 👍

"""Normalise Vdb if using cosine similarity."""
if "cosine" in self.scoring_metric:
embeddings = self.vectors["embeddings"].to_numpy()
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

self.vectors.with_columns(pl.Series("embeddings", embeddings))

def score(
self, query: np.ndarray, n_results: int, query_ids_batch: list[str], query_text_batch: list[str]
) -> tuple[pl.DataFrame, np.ndarray]:
"""Perform Scoring and return Top Values.

Args:
query(np.ndarray): query for search
n_results(int): number of results to return
query_ids_batch(list[str]): ids of query batch
query_text_batch(list[str]): source text of query batch

Returns:
pl.DataFrame: The Polars DataFrame containing the top n most similar results to the query
"""
if self.scoring_metric.startswith("cosine"):
query = query / np.linalg.norm(query, axis=1, keepdims=True)

result = query @ self.vectors["embeddings"].to_numpy().T

# Get the top n_results indices for each query in the batch
idx = np.argpartition(result, -n_results, axis=1)[:, -n_results:]

# Sort top n_results indices by their scores in descending order
idx_sorted = np.zeros_like(idx)
scores = np.zeros_like(idx, dtype=float)

for j in range(idx.shape[0]):
row_scores = result[j, idx[j]]
sorted_indices = np.argsort(row_scores)[::-1]
idx_sorted[j] = idx[j, sorted_indices]
scores[j] = row_scores[sorted_indices]

if "l2" in self.scoring_metric:
scores = 2 * (1 - scores)
if not self.scoring_metric.endswith("squared"):
scores = np.sqrt(scores)

# Build a DataFrame for the current batch results
result_df = pl.DataFrame(
{
"query_id": np.repeat(query_ids_batch, n_results),
"query_text": np.repeat(query_text_batch, n_results),
"rank": np.tile(np.arange(n_results), len(query_text_batch)),
"score": scores.flatten(),
}
)
return result_df, idx_sorted

def search(
self, query: VectorStoreSearchInput, n_results: int = 10, batch_size: int = 8
) -> VectorStoreSearchOutput:
"""Searches the vector store using queries from a VectorStoreSearchInput object and returns
ranked results in VectorStoreSearchOutput object. In batches, converts users text queries into vector embeddings,
computes cosine similarity with stored document vectors, and retrieves the top results.
Expand Down Expand Up @@ -386,35 +469,11 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
# Get the current batch of queries
query_text_batch = query.query.to_list()[i : i + batch_size]
query_ids_batch = query.id.to_list()[i : i + batch_size]

# Convert the current batch of queries to vectors
query_vectors = self.vectoriser.transform(query_text_batch)

# Compute cosine similarity between the query batch and document vectors
cosine = query_vectors @ self.vectors["embeddings"].to_numpy().T

# Get the top n_results indices for each query in the batch
idx = np.argpartition(cosine, -n_results, axis=1)[:, -n_results:]

# Sort top n_results indices by their scores in descending order
idx_sorted = np.zeros_like(idx)
scores = np.zeros_like(idx, dtype=float)

for j in range(idx.shape[0]):
row_scores = cosine[j, idx[j]]
sorted_indices = np.argsort(row_scores)[::-1]
idx_sorted[j] = idx[j, sorted_indices]
scores[j] = row_scores[sorted_indices]

# Build a DataFrame for the current batch results
result_df = pl.DataFrame(
{
"query_id": np.repeat(query_ids_batch, n_results),
"query_text": np.repeat(query_text_batch, n_results),
"rank": np.tile(np.arange(n_results), len(query_text_batch)),
"score": scores.flatten(),
}
)
# perform scoring and return frame and ids
result_df, idx_sorted = self.score(query_vectors, n_results, query_ids_batch, query_text_batch)

# Get the vector store results for the current batch
ranked_docs = self.vectors[idx_sorted.flatten().tolist()].select(["id", "text", *self.meta_data.keys()])
Expand Down Expand Up @@ -461,7 +520,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
return result_df

@classmethod
def from_filespace(cls, folder_path, vectoriser):
def from_filespace(cls, folder_path, vectoriser: VectoriserBase, scoring_metric: metric_settings = "cosine"):
"""Creates a `VectorStore` instance from stored metadata and Parquet files.
This method reads the metadata and vectors from the specified folder,
validates the contents, and initializes a `VectorStore` object with the
Expand All @@ -474,7 +533,8 @@ def from_filespace(cls, folder_path, vectoriser):

Args:
folder_path (str): The folder path containing the metadata and Parquet files.
vectoriser (object): The vectoriser object used to transform text into vector embeddings.
vectoriser (VectoriserBase): The vectoriser object used to transform text into vector embeddings.
scoring_metric(metric_settings): The metric to use for scoring

Returns:
VectorStore: An instance of the `VectorStore` class.
Expand All @@ -491,6 +551,9 @@ def from_filespace(cls, folder_path, vectoriser):
with open(metadata_path, encoding="utf-8") as f:
metadata = json.load(f)

## validate scoring metric
metricvalid(scoring_metric)

# check that the correct keys exist in metadata
required_keys = [
"vectoriser_class",
Expand Down Expand Up @@ -544,12 +607,13 @@ def from_filespace(cls, folder_path, vectoriser):
vector_store.file_name = None
vector_store.data_type = None
vector_store.vectoriser = vectoriser
vector_store.scoring_metric = scoring_metric
vector_store.batch_size = None
vector_store.meta_data = deserialized_column_meta_data
vector_store.vectors = df
vector_store.vector_shape = metadata["vector_shape"]
vector_store.num_vectors = metadata["num_vectors"]
vector_store.vectoriser_class = metadata["vectoriser_class"]
vector_store.hooks = {}

vector_store._check_norm_vdb()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd favour a 'normalise once' approach -

  1. when the VDB is being constructed by _create_vector_store_index(), it checks if the user specified a metric that requires normalised vectors and normalises the created collection and then saves them to the polars df/parquet file.
  2. Then we'd record the 'metric' used in the metadata file
  3. when the parquet is loaded back in with from_filespace() we know to use the appropriate metric already as its stored in the metadata file and theres no need to redo the normalisation

so i'd also take the 'metric_setting' parameter out of the class method from_filespace() and rely just on the metadata file.

this would mean less operations every time we load the vectorstore in, after initial creation - potentially at the cost of losing the magnitude information and not being able to get it back without running the build step again with a different metric

Copy link
Copy Markdown
Contributor Author

@rileyok-ons rileyok-ons Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved by adding normalize meta field, if choosing cosine with un-normed will norm but will warn user

return vector_store
5 changes: 5 additions & 0 deletions src/classifai/indexers/types.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we scrapped all 6 of these and just had ['IP', 'L2'].

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think L2 squared and IP squared should be a downstream postprocessing hook as its just a common scoring operation

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seen this suggested previously, if we want this can sort

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy with that plan 👍

I'd like if we added an example to one of the notebooks showing a way of wrapping one of the Vectorisers to add normalisation though, to tide users over until we properly offer normalisation as an option.

I can add that to this PR tomorrow, if nobody objects.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we scrapped all 6 of these and just had ['IP', 'L2'].

Would you be okay with renaming 'IP'->'dot' for this? I think 'dot' would be more easily understood by users via docstrings without needing to explore documentation etc. to find out / confirm IP = Inner Product

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest we just leave it - if users really really want it they can make their own custom vectoriser that wraps the hugging face vectoriser - but if you really wanted to you could update the custom_vectoriser demo notebook to have a section on this and show how to do it to the hugging face class?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do already have one user group requesting this functionality (and currently using a custom wrapped HF Vectoriser to achieve it), so I think it is worth adding to the docs.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So... they did use a custom vectoriser? 😀

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what the correct answer is, we definitely don't want to be adding a variant of every Vectoriser called VectoriserX_normalised, or a wrapper for each class. Maybe 1 utility wrapper that wraps round all our Vectoriser class imps.... but what is the benefit/tradeoffs of that new class, which we'd have to add more docs and ensure it's compatible forever, versus guiding users in how to do it with our existing custom vectoriser / base class architecture.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that I made for them as a one-off solution as the package doesn't yet offer that - I'm saying it would be useful to have that knowledge made accessible in the documentation for other users

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import Literal, TypeAlias

metric_settings: TypeAlias = Literal[
Comment thread
rileyok-ons marked this conversation as resolved.
Outdated
"cosine", "dotprod", "cosinel2", "dotprodl2", "cosinel2squared", "dotprodl2squared"
Comment thread
rileyok-ons marked this conversation as resolved.
Outdated
]
Loading