Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"pydantic>=2.9.2",
"tqdm>=4.67.1",
"pandera[pandas]>=0.27.0",
"fsspec>=2026.3.0",
]
classifiers = [
"Programming Language :: Python :: 3",
Expand Down
106 changes: 77 additions & 29 deletions src/classifai/indexers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
import json
import logging
import os
import shutil
import time
import uuid

import fsspec
import numpy as np
import polars as pl
from tqdm.autonotebook import tqdm
Expand Down Expand Up @@ -122,29 +122,48 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
if not isinstance(file_name, str) or not file_name.strip():
raise DataValidationError("file_name must be a non-empty string.", context={"file_name": file_name})

if not os.path.exists(file_name):
# use fsspec to get the filesystem and path for the input
try:
in_fs, in_path = fsspec.core.url_to_fs(file_name)
except Exception as e:
raise ConfigurationError(
"Failed to read input directory with file loader.",
context={
"file_name": file_name,
"cause": str(e),
"cause_type": type(e).__name__,
},
) from e

# check if the file exists in the filesystem
if not in_fs.exists(in_path):
raise DataValidationError("Input file does not exist.", context={"file_name": file_name})

# check that the user has specified the correct datatype
if data_type not in ["csv"]:
raise DataValidationError(
"Unsupported data_type. Choose from ['csv'].",
context={"data_type": data_type},
)

# check that the vectoriser object is an instance of the VectoriserBase class and has a transform method
if not isinstance(vectoriser, VectoriserBase):
raise ConfigurationError(
"Vectoriser must be an instance of Vectoriser(Base) with a .transform() method.",
context={"vectoriser_type": type(vectoriser).__name__},
)

# check that batch_size is a positive integer
if not isinstance(batch_size, int) or batch_size < 1:
raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": batch_size})

# check that meta_data is a dict if provided
if meta_data is not None and not isinstance(meta_data, dict):
raise DataValidationError(
"meta_data must be a dict or None.", context={"meta_data_type": type(meta_data).__name__}
)

# check that hooks is a dict if provided
if hooks is not None and not isinstance(hooks, dict):
raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__})

Expand All @@ -168,19 +187,27 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
normalized_file_name = os.path.basename(os.path.splitext(self.file_name)[0])
self.output_dir = os.path.join(normalized_file_name)

if os.path.isdir(self.output_dir):
# use fsspec to get the filesystem and path for the output
out_fs, out_path = fsspec.core.url_to_fs(self.output_dir)

# check if the output directory already exists, and handle according to overwrite flag
if out_fs.exists(out_path):
if overwrite:
shutil.rmtree(self.output_dir)
out_fs.rm(out_path, recursive=True)
else:
raise ConfigurationError(
"Output directory already exists. Pass overwrite=True to overwrite the folder.",
context={"output_dir": self.output_dir},
)
os.makedirs(self.output_dir, exist_ok=True)
out_fs.makedirs(out_path, exist_ok=True)
except Exception as e:
raise ConfigurationError(
"Failed to prepare output directory.",
context={"output_dir": self.output_dir},
context={
"output_dir": self.output_dir,
"cause": str(e),
"cause_type": type(e).__name__,
},
) from e

# ---- Build index (wrap every unexpected failure) -> IndexBuildError
Expand Down Expand Up @@ -208,8 +235,13 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
self.vector_shape = self.vectors["embeddings"].to_numpy().shape[1]
self.num_vectors = len(self.vectors)

self.vectors.write_parquet(os.path.join(self.output_dir, "vectors.parquet"))
self._save_metadata(os.path.join(self.output_dir, "metadata.json"))
vectors_out_path = os.path.join(self.output_dir, "vectors.parquet")
self.vectors.write_parquet(
vectors_out_path
) # polars handles fsspec filesystems natively, so this will work with local and remote filesystems supported by fsspec

metadata_out_path = os.path.join(self.output_dir, "metadata.json")
self._save_metadata(metadata_out_path)

logging.info("Vector Store created - files saved to %s", self.output_dir)
except ClassifaiError:
Expand Down Expand Up @@ -248,7 +280,9 @@ def _save_metadata(self, path: str):
"meta_data": serializable_column_meta_data,
}

with open(path, "w", encoding="utf-8") as f:
# inside separate function use fsspec again to write the metadata file to support different filesystems
out_fs, out_path = fsspec.core.url_to_fs(path)
with out_fs.open(out_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=4)

except ClassifaiError:
Expand All @@ -275,7 +309,7 @@ def _create_vector_store_index(self): # noqa: C901
# ---- Reading source data (validation/format issues) -> DataValidationError / IndexBuildError
try:
if self.data_type == "csv":
self.vectors = pl.read_csv(
self.vectors = pl.read_csv( # polars handles fsspec filesystems natively
self.file_name,
columns=["label", "text", *self.meta_data.keys()],
dtypes=self.meta_data | {"label": str, "text": str},
Expand Down Expand Up @@ -768,7 +802,21 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): #
if not isinstance(folder_path, str) or not folder_path.strip():
raise DataValidationError("folder_path must be a non-empty string.", context={"folder_path": folder_path})

if not os.path.isdir(folder_path):
# use fsspec to get the filesystem and path for the input
try:
in_fs, in_path = fsspec.core.url_to_fs(folder_path)
except Exception as e:
raise ConfigurationError(
"Failed to read input directory with file loader.",
context={
"folder_path": folder_path,
"cause": str(e),
"cause_type": type(e).__name__,
},
) from e

# check if the folder exists in the filesystem
if not in_fs.isdir(in_path):
raise DataValidationError(
"folder_path must be an existing directory.", context={"folder_path": folder_path}
)
Expand All @@ -783,41 +831,41 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): #
raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__})

# ---- Load metadata -> IndexBuildError
metadata_path = os.path.join(folder_path, "metadata.json")
if not os.path.exists(metadata_path):
metadata_in_path = os.path.join(in_path, "metadata.json")
if not in_fs.exists(metadata_in_path):
raise DataValidationError(
"Metadata file not found in folder_path.",
context={"folder_path": folder_path, "metadata_path": metadata_path},
context={"folder_path": folder_path, "metadata_path": metadata_in_path},
)

try:
with open(metadata_path, encoding="utf-8") as f:
with in_fs.open(metadata_in_path, encoding="utf-8") as f:
metadata = json.load(f)
except Exception as e:
raise IndexBuildError(
"Failed to read metadata.json.",
context={"metadata_path": metadata_path, "cause_type": type(e).__name__, "cause_message": str(e)},
context={"metadata_path": metadata_in_path, "cause_type": type(e).__name__, "cause_message": str(e)},
) from e

# ---- Validate metadata content -> DataValidationError
if not isinstance(metadata, dict):
raise DataValidationError(
"metadata.json did not contain a JSON object.",
context={"metadata_path": metadata_path, "metadata_type": type(metadata).__name__},
context={"metadata_path": metadata_in_path, "metadata_type": type(metadata).__name__},
)

required_keys = ["vectoriser_class", "vector_shape", "num_vectors", "created_at", "meta_data"]
missing = [k for k in required_keys if k not in metadata]
if missing:
raise DataValidationError(
"Metadata file is missing required keys.",
context={"metadata_path": metadata_path, "missing_keys": missing},
context={"metadata_path": metadata_in_path, "missing_keys": missing},
)

if not isinstance(metadata["meta_data"], dict):
raise DataValidationError(
"metadata.meta_data must be an object/dict.",
context={"metadata_path": metadata_path, "meta_data_type": type(metadata["meta_data"]).__name__},
context={"metadata_path": metadata_in_path, "meta_data_type": type(metadata["meta_data"]).__name__},
)

# ---- Deserialize meta_data types safely -> DataValidationError
Expand All @@ -831,30 +879,30 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): #
raise DataValidationError(
"Unable to deserialize metadata column types from metadata in metadata file.",
context={
"metadata_path": metadata_path,
"metadata_path": metadata_in_path,
"meta_data": metadata["meta_data"],
"cause_type": type(e).__name__,
"cause_message": str(e),
},
) from e

# ---- Load parquet -> IndexBuildError / DataValidationError
vectors_path = os.path.join(folder_path, "vectors.parquet")
if not os.path.exists(vectors_path):
vectors_in_path = os.path.join(folder_path, "vectors.parquet")
if not in_fs.exists(vectors_in_path):
raise DataValidationError(
"Vectors Parquet file not found in folder_path.",
context={"folder_path": folder_path, "vectors_path": vectors_path},
context={"folder_path": folder_path, "vectors_path": vectors_in_path},
)

required_columns = ["label", "text", "embeddings", "uuid", *deserialized_column_meta_data.keys()]

try:
df = pl.read_parquet(vectors_path, columns=required_columns)
df = pl.read_parquet(vectors_in_path, columns=required_columns) # polars handles fsspec path natively
except Exception as e:
raise IndexBuildError(
"Failed to read vectors.parquet.",
context={
"vectors_path": vectors_path,
"vectors_path": vectors_in_path,
"cause_type": type(e).__name__,
"cause_message": str(e),
},
Expand All @@ -863,14 +911,14 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): #
if df.is_empty():
raise DataValidationError(
"Vectors Parquet file is empty.",
context={"vectors_path": vectors_path},
context={"vectors_path": vectors_in_path},
)

missing_cols = [c for c in required_columns if c not in df.columns]
if missing_cols:
raise DataValidationError(
"Vectors Parquet file is missing required columns.",
context={"vectors_path": vectors_path, "missing_columns": missing_cols},
context={"vectors_path": vectors_in_path, "missing_columns": missing_cols},
)

# ---- Validate vectoriser class match -> ConfigurationError
Expand Down Expand Up @@ -902,8 +950,8 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): #
"Failed to initialise VectorStore instance from filespace.",
context={
"folder_path": folder_path,
"metadata_path": metadata_path,
"vectors_path": vectors_path,
"metadata_path": metadata_in_path,
"vectors_path": vectors_in_path,
"cause_type": type(e).__name__,
"cause_message": str(e),
},
Expand Down
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading