From 6e5c110fac44e9a7343cc19a65e4a1ae068fa58b Mon Sep 17 00:00:00 2001 From: frayle-ons <194791647+frayle-ons@users.noreply.github.com> Date: Thu, 14 May 2026 14:36:22 +0100 Subject: [PATCH] reworked filespace interaction to use fsspec for generalised file loading and saving --- pyproject.toml | 1 + src/classifai/indexers/main.py | 106 ++++++++++++++++++++++++--------- uv.lock | 2 + 3 files changed, 80 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bb5b23c..08d73a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "pydantic>=2.9.2", "tqdm>=4.67.1", "pandera[pandas]>=0.27.0", + "fsspec>=2026.3.0", ] classifiers = [ "Programming Language :: Python :: 3", diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index 9a92717..559f457 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -31,10 +31,10 @@ import json import logging import os -import shutil import time import uuid +import fsspec import numpy as np import polars as pl from tqdm.autonotebook import tqdm @@ -122,29 +122,48 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 if not isinstance(file_name, str) or not file_name.strip(): raise DataValidationError("file_name must be a non-empty string.", context={"file_name": file_name}) - if not os.path.exists(file_name): + # use fsspec to get the filesystem and path for the input + try: + in_fs, in_path = fsspec.core.url_to_fs(file_name) + except Exception as e: + raise ConfigurationError( + "Failed to read input directory with file loader.", + context={ + "file_name": file_name, + "cause": str(e), + "cause_type": type(e).__name__, + }, + ) from e + + # check if the file exists in the filesystem + if not in_fs.exists(in_path): raise DataValidationError("Input file does not exist.", context={"file_name": file_name}) + # check that the user has specified the correct datatype if data_type not in ["csv"]: raise DataValidationError( "Unsupported data_type. Choose from ['csv'].", context={"data_type": data_type}, ) + # check that the vectoriser object is an instance of the VectoriserBase class and has a transform method if not isinstance(vectoriser, VectoriserBase): raise ConfigurationError( "Vectoriser must be an instance of Vectoriser(Base) with a .transform() method.", context={"vectoriser_type": type(vectoriser).__name__}, ) + # check that batch_size is a positive integer if not isinstance(batch_size, int) or batch_size < 1: raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": batch_size}) + # check that meta_data is a dict if provided if meta_data is not None and not isinstance(meta_data, dict): raise DataValidationError( "meta_data must be a dict or None.", context={"meta_data_type": type(meta_data).__name__} ) + # check that hooks is a dict if provided if hooks is not None and not isinstance(hooks, dict): raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__}) @@ -168,19 +187,27 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 normalized_file_name = os.path.basename(os.path.splitext(self.file_name)[0]) self.output_dir = os.path.join(normalized_file_name) - if os.path.isdir(self.output_dir): + # use fsspec to get the filesystem and path for the output + out_fs, out_path = fsspec.core.url_to_fs(self.output_dir) + + # check if the output directory already exists, and handle according to overwrite flag + if out_fs.exists(out_path): if overwrite: - shutil.rmtree(self.output_dir) + out_fs.rm(out_path, recursive=True) else: raise ConfigurationError( "Output directory already exists. Pass overwrite=True to overwrite the folder.", context={"output_dir": self.output_dir}, ) - os.makedirs(self.output_dir, exist_ok=True) + out_fs.makedirs(out_path, exist_ok=True) except Exception as e: raise ConfigurationError( "Failed to prepare output directory.", - context={"output_dir": self.output_dir}, + context={ + "output_dir": self.output_dir, + "cause": str(e), + "cause_type": type(e).__name__, + }, ) from e # ---- Build index (wrap every unexpected failure) -> IndexBuildError @@ -208,8 +235,13 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 self.vector_shape = self.vectors["embeddings"].to_numpy().shape[1] self.num_vectors = len(self.vectors) - self.vectors.write_parquet(os.path.join(self.output_dir, "vectors.parquet")) - self._save_metadata(os.path.join(self.output_dir, "metadata.json")) + vectors_out_path = os.path.join(self.output_dir, "vectors.parquet") + self.vectors.write_parquet( + vectors_out_path + ) # polars handles fsspec filesystems natively, so this will work with local and remote filesystems supported by fsspec + + metadata_out_path = os.path.join(self.output_dir, "metadata.json") + self._save_metadata(metadata_out_path) logging.info("Vector Store created - files saved to %s", self.output_dir) except ClassifaiError: @@ -248,7 +280,9 @@ def _save_metadata(self, path: str): "meta_data": serializable_column_meta_data, } - with open(path, "w", encoding="utf-8") as f: + # inside separate function use fsspec again to write the metadata file to support different filesystems + out_fs, out_path = fsspec.core.url_to_fs(path) + with out_fs.open(out_path, "w", encoding="utf-8") as f: json.dump(metadata, f, indent=4) except ClassifaiError: @@ -275,7 +309,7 @@ def _create_vector_store_index(self): # noqa: C901 # ---- Reading source data (validation/format issues) -> DataValidationError / IndexBuildError try: if self.data_type == "csv": - self.vectors = pl.read_csv( + self.vectors = pl.read_csv( # polars handles fsspec filesystems natively self.file_name, columns=["label", "text", *self.meta_data.keys()], dtypes=self.meta_data | {"label": str, "text": str}, @@ -768,7 +802,21 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # if not isinstance(folder_path, str) or not folder_path.strip(): raise DataValidationError("folder_path must be a non-empty string.", context={"folder_path": folder_path}) - if not os.path.isdir(folder_path): + # use fsspec to get the filesystem and path for the input + try: + in_fs, in_path = fsspec.core.url_to_fs(folder_path) + except Exception as e: + raise ConfigurationError( + "Failed to read input directory with file loader.", + context={ + "folder_path": folder_path, + "cause": str(e), + "cause_type": type(e).__name__, + }, + ) from e + + # check if the folder exists in the filesystem + if not in_fs.isdir(in_path): raise DataValidationError( "folder_path must be an existing directory.", context={"folder_path": folder_path} ) @@ -783,27 +831,27 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__}) # ---- Load metadata -> IndexBuildError - metadata_path = os.path.join(folder_path, "metadata.json") - if not os.path.exists(metadata_path): + metadata_in_path = os.path.join(in_path, "metadata.json") + if not in_fs.exists(metadata_in_path): raise DataValidationError( "Metadata file not found in folder_path.", - context={"folder_path": folder_path, "metadata_path": metadata_path}, + context={"folder_path": folder_path, "metadata_path": metadata_in_path}, ) try: - with open(metadata_path, encoding="utf-8") as f: + with in_fs.open(metadata_in_path, encoding="utf-8") as f: metadata = json.load(f) except Exception as e: raise IndexBuildError( "Failed to read metadata.json.", - context={"metadata_path": metadata_path, "cause_type": type(e).__name__, "cause_message": str(e)}, + context={"metadata_path": metadata_in_path, "cause_type": type(e).__name__, "cause_message": str(e)}, ) from e # ---- Validate metadata content -> DataValidationError if not isinstance(metadata, dict): raise DataValidationError( "metadata.json did not contain a JSON object.", - context={"metadata_path": metadata_path, "metadata_type": type(metadata).__name__}, + context={"metadata_path": metadata_in_path, "metadata_type": type(metadata).__name__}, ) required_keys = ["vectoriser_class", "vector_shape", "num_vectors", "created_at", "meta_data"] @@ -811,13 +859,13 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # if missing: raise DataValidationError( "Metadata file is missing required keys.", - context={"metadata_path": metadata_path, "missing_keys": missing}, + context={"metadata_path": metadata_in_path, "missing_keys": missing}, ) if not isinstance(metadata["meta_data"], dict): raise DataValidationError( "metadata.meta_data must be an object/dict.", - context={"metadata_path": metadata_path, "meta_data_type": type(metadata["meta_data"]).__name__}, + context={"metadata_path": metadata_in_path, "meta_data_type": type(metadata["meta_data"]).__name__}, ) # ---- Deserialize meta_data types safely -> DataValidationError @@ -831,7 +879,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # raise DataValidationError( "Unable to deserialize metadata column types from metadata in metadata file.", context={ - "metadata_path": metadata_path, + "metadata_path": metadata_in_path, "meta_data": metadata["meta_data"], "cause_type": type(e).__name__, "cause_message": str(e), @@ -839,22 +887,22 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # ) from e # ---- Load parquet -> IndexBuildError / DataValidationError - vectors_path = os.path.join(folder_path, "vectors.parquet") - if not os.path.exists(vectors_path): + vectors_in_path = os.path.join(folder_path, "vectors.parquet") + if not in_fs.exists(vectors_in_path): raise DataValidationError( "Vectors Parquet file not found in folder_path.", - context={"folder_path": folder_path, "vectors_path": vectors_path}, + context={"folder_path": folder_path, "vectors_path": vectors_in_path}, ) required_columns = ["label", "text", "embeddings", "uuid", *deserialized_column_meta_data.keys()] try: - df = pl.read_parquet(vectors_path, columns=required_columns) + df = pl.read_parquet(vectors_in_path, columns=required_columns) # polars handles fsspec path natively except Exception as e: raise IndexBuildError( "Failed to read vectors.parquet.", context={ - "vectors_path": vectors_path, + "vectors_path": vectors_in_path, "cause_type": type(e).__name__, "cause_message": str(e), }, @@ -863,14 +911,14 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # if df.is_empty(): raise DataValidationError( "Vectors Parquet file is empty.", - context={"vectors_path": vectors_path}, + context={"vectors_path": vectors_in_path}, ) missing_cols = [c for c in required_columns if c not in df.columns] if missing_cols: raise DataValidationError( "Vectors Parquet file is missing required columns.", - context={"vectors_path": vectors_path, "missing_columns": missing_cols}, + context={"vectors_path": vectors_in_path, "missing_columns": missing_cols}, ) # ---- Validate vectoriser class match -> ConfigurationError @@ -902,8 +950,8 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # "Failed to initialise VectorStore instance from filespace.", context={ "folder_path": folder_path, - "metadata_path": metadata_path, - "vectors_path": vectors_path, + "metadata_path": metadata_in_path, + "vectors_path": vectors_in_path, "cause_type": type(e).__name__, "cause_message": str(e), }, diff --git a/uv.lock b/uv.lock index 868b834..98bfe02 100644 --- a/uv.lock +++ b/uv.lock @@ -387,6 +387,7 @@ version = "1.0.0" source = { editable = "." } dependencies = [ { name = "fastapi", extra = ["standard"] }, + { name = "fsspec" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -442,6 +443,7 @@ test = [ requires-dist = [ { name = "classifai", extras = ["huggingface", "gcp", "ollama"], marker = "extra == 'all'" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.11" }, + { name = "fsspec", specifier = ">=2026.3.0" }, { name = "google-genai", marker = "extra == 'gcp'", specifier = ">=1.68.0" }, { name = "numpy", specifier = ">=2.2.4" }, { name = "ollama", marker = "extra == 'ollama'", specifier = ">=0.5.1" },