From ad7130f8be5e380dea5f73f7b4bee9f4bdc49bb0 Mon Sep 17 00:00:00 2001 From: Alexander Serdyukov Date: Wed, 1 Apr 2026 02:41:58 +0400 Subject: [PATCH 1/3] Initial version of HiCT_JVM API-based library for Python --- README.md | 24 + doc/jvm_api_v1.md | 85 ++++ hict_jvm_api/__init__.py | 39 ++ hict_jvm_api/client.py | 534 +++++++++++++++++++++ hict_jvm_api/dataloader.py | 119 +++++ hict_jvm_api/exceptions.py | 20 + hict_jvm_api/models.py | 175 +++++++ hict_jvm_api/py.typed | 0 hict_jvm_api/units.py | 226 +++++++++ notebooks/jvm_api_pytorch_dataloader.ipynb | 75 +++ notebooks/jvm_api_quickstart.ipynb | 83 ++++ pyproject.toml | 36 ++ pytest.ini | 3 + requirements.txt | 2 + run_jvm_api_optional_data_tests.sh | 14 + run_jvm_api_tests.sh | 7 + run_tests.sh | 3 +- tests_jvm_api/test_client.py | 165 +++++++ tests_jvm_api/test_dataloader.py | 53 ++ tests_jvm_api/test_optional_integration.py | 93 ++++ tests_jvm_api/test_units.py | 82 ++++ 21 files changed, 1837 insertions(+), 1 deletion(-) create mode 100644 doc/jvm_api_v1.md create mode 100644 hict_jvm_api/__init__.py create mode 100644 hict_jvm_api/client.py create mode 100644 hict_jvm_api/dataloader.py create mode 100644 hict_jvm_api/exceptions.py create mode 100644 hict_jvm_api/models.py create mode 100644 hict_jvm_api/py.typed create mode 100644 hict_jvm_api/units.py create mode 100644 notebooks/jvm_api_pytorch_dataloader.ipynb create mode 100644 notebooks/jvm_api_quickstart.ipynb create mode 100644 pyproject.toml create mode 100644 pytest.ini create mode 100755 run_jvm_api_optional_data_tests.sh create mode 100755 run_jvm_api_tests.sh create mode 100644 tests_jvm_api/test_client.py create mode 100644 tests_jvm_api/test_dataloader.py create mode 100644 tests_jvm_api/test_optional_integration.py create mode 100644 tests_jvm_api/test_units.py diff --git a/README.md b/README.md index 9a55643..b3139a0 100644 --- a/README.md +++ b/README.md @@ -28,3 +28,27 @@ This library has ContactMatrixFacet as the main interaction point. 
It hides all ## Building from source You can run `rebuild.sh` script in source directory which will perform static type-checking of module using mypy (it may produce error messages), build library from source and reinstall it, deleting current version. + +## JVM API client (v1) + +This branch also contains `hict_jvm_api`, a Python client for controlling a running `HiCT_JVM` backend. + +### Key capabilities +* Open/attach/close sessions in HiCT_JVM; +* Fetch Hi-C map regions as numpy RGBA arrays (`PNG_BY_PIXELS`) for ML pipelines; +* Run scaffolding operations via API (reverse/move/split/group/ungroup/debris); +* Run converter jobs (single and batch) and monitor status; +* Link FASTA, export FASTA selections/assembly, import/export AGP; +* Convert coordinates between BP/BINS/PIXELS with hidden-contig awareness. + +### Quick links +* API docs: [`doc/jvm_api_v1.md`](./doc/jvm_api_v1.md) +* Notebooks: + * [`notebooks/jvm_api_quickstart.ipynb`](./notebooks/jvm_api_quickstart.ipynb) + * [`notebooks/jvm_api_pytorch_dataloader.ipynb`](./notebooks/jvm_api_pytorch_dataloader.ipynb) + +### Tests +* Unit tests (mocked HTTP transport): + * `./run_jvm_api_tests.sh` +* Optional integration tests against a real running HiCT_JVM: + * `./run_jvm_api_optional_data_tests.sh` diff --git a/doc/jvm_api_v1.md b/doc/jvm_api_v1.md new file mode 100644 index 0000000..e397473 --- /dev/null +++ b/doc/jvm_api_v1.md @@ -0,0 +1,85 @@ +# HiCT JVM API v1 Python Library + +`hict_jvm_api` is a Python package that uses a running `HiCT_JVM` server as the execution backend. + +## Design goals + +- Keep heavy matrix/assembly logic in JVM. +- Expose a Python API suitable for bioinformatics scripting and ML pipelines. +- Keep I/O efficient through pooled HTTP connections and direct region extraction (`PNG_BY_PIXELS`). 
+ +## Main classes + +- `hict_jvm_api.client.HiCTJVMClient` + - session management (`open_file`, `attach_session`, `close_session`) + - map region fetch (`fetch_region_pixels`, `fetch_tile_png`, `fetch_tile_with_ranges`) + - scaffolding operations (`reverse_selection_range`, `move_selection_range`, `split_contig_at_bin`, etc.) + - conversion jobs (`start_conversion_job`, `start_batch_conversion_jobs`, polling helpers) + - FASTA/AGP operations (`link_fasta`, `export_fasta_for_selection`, `load_agp`) +- `hict_jvm_api.units.UnitConverter` + - fast local conversion between `BP`, `BINS`, `PIXELS` at a selected resolution, + respecting hidden contigs via `contigPresenceAtResolution`. +- `hict_jvm_api.dataloader.HiCTRegionDataset` + - PyTorch-friendly random-access dataset fetching regions from a live session. + +## Installation + +From repository root: + +```bash +pip install -e . +``` + +With PyTorch extras: + +```bash +pip install -e '.[torch]' +``` + +## Quick start + +```python +from hict_jvm_api import HiCTJVMClient, Unit + +client = HiCTJVMClient("http://localhost:5001") +open_resp = client.open_file("build/quad/combined_ind2_4DN.hict.hdf5") + +resolution = open_resp.resolutions[0] # coarse level +img = client.fetch_region_pixels( + start_row_px=0, + start_col_px=0, + rows=256, + cols=256, + bp_resolution=resolution, +) + +# Convert BP -> visible pixel coordinate +px = client.convert_units(1_000_000, from_unit=Unit.BP, to_unit=Unit.PIXELS, bp_resolution=resolution) +``` + +## Testing + +- Unit tests (mocked transport): + +```bash +./run_jvm_api_tests.sh +``` + +- Optional integration tests against a real server and optional files: + +```bash +export HICT_JVM_API_BASE_URL=http://localhost:5001 +export HICT_DATASET_FILE=build/quad/combined_ind2_4DN.hict.hdf5 +# Optional: +# export HICT_FASTA_FILE=build/quad/quad_combined_ind2.fasta +# export HICT_AGP_FILE=build/quad/some.agp +# export HICT_JVM_API_ALLOW_MUTATION=true +./run_jvm_api_optional_data_tests.sh +``` + +## 
Notebook examples + +See: + +- `notebooks/jvm_api_quickstart.ipynb` +- `notebooks/jvm_api_pytorch_dataloader.ipynb` diff --git a/hict_jvm_api/__init__.py b/hict_jvm_api/__init__.py new file mode 100644 index 0000000..bede2f7 --- /dev/null +++ b/hict_jvm_api/__init__.py @@ -0,0 +1,39 @@ +"""HiCT JVM API client package. + +This package provides a Python interface for communicating with a running +HiCT_JVM API server. It includes: + +- :class:`hict_jvm_api.client.HiCTJVMClient` for server operations +- :class:`hict_jvm_api.units.UnitConverter` for fast BP/BINS/PIXELS conversion +- :class:`hict_jvm_api.dataloader.HiCTRegionDataset` for PyTorch data loading +""" + +from .client import HiCTJVMClient +from .dataloader import HiCTRegionDataset +from .exceptions import HiCTAPIError, HiCTClientStateError, HiddenCoordinateError +from .models import ( + AssemblyInfo, + ContigDescriptor, + OpenFileResponse, + ScaffoldDescriptor, + TileRanges, + TileWithRanges, + Unit, +) +from .units import UnitConverter + +__all__ = [ + "HiCTJVMClient", + "HiCTRegionDataset", + "HiCTAPIError", + "HiCTClientStateError", + "HiddenCoordinateError", + "AssemblyInfo", + "ContigDescriptor", + "OpenFileResponse", + "ScaffoldDescriptor", + "TileRanges", + "TileWithRanges", + "Unit", + "UnitConverter", +] diff --git a/hict_jvm_api/client.py b/hict_jvm_api/client.py new file mode 100644 index 0000000..d8fe6a3 --- /dev/null +++ b/hict_jvm_api/client.py @@ -0,0 +1,534 @@ +"""High-level Python client for HiCT_JVM HTTP API.""" + +from __future__ import annotations + +import base64 +import time +from io import BytesIO +from typing import Any, Dict, Mapping, Optional, Sequence, Union + +import numpy as np +import requests +from PIL import Image + +from .exceptions import HiCTAPIError, HiCTClientStateError +from .models import OpenFileResponse, SecondarySourceStatus, TileRanges, TileWithRanges, Unit +from .units import UnitConverter + + +class HiCTJVMClient: + """Client for communicating with a running 
HiCT_JVM API server. + + The client uses a persistent HTTP session (`requests.Session`) for connection + reuse and better throughput in region-heavy workloads. + """ + + def __init__( + self, + base_url: str, + timeout: float = 30.0, + session: Optional[Any] = None, + ): + if not base_url: + raise ValueError("base_url must be provided") + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self._own_session = session is None + self._http = session if session is not None else requests.Session() + self._open_file_response: Optional[OpenFileResponse] = None + + def close(self) -> None: + """Close the underlying HTTP session if owned by this client.""" + if self._own_session and hasattr(self._http, "close"): + self._http.close() + + def __enter__(self) -> "HiCTJVMClient": + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + self.close() + + def _url(self, path: str) -> str: + if not path.startswith("/"): + path = "/" + path + return self.base_url + path + + @staticmethod + def _raise_if_error_payload(payload: Any, status_code: Optional[int] = None) -> None: + if isinstance(payload, Mapping): + error = payload.get("error") + if isinstance(error, str) and error.strip(): + raise HiCTAPIError(error, status_code=status_code, payload=payload) + + def _request( + self, + method: str, + path: str, + *, + params: Optional[Mapping[str, Any]] = None, + payload: Optional[Mapping[str, Any]] = None, + ) -> requests.Response: + response = self._http.request( + method, + self._url(path), + params=dict(params or {}), + json=dict(payload or {}), + timeout=self.timeout, + ) + if response.status_code >= 400: + message = response.text + parsed_payload: Any = None + try: + parsed_payload = response.json() + if isinstance(parsed_payload, Mapping) and isinstance(parsed_payload.get("error"), str): + message = str(parsed_payload["error"]) + except ValueError: + parsed_payload = response.text + raise HiCTAPIError(message, status_code=response.status_code, 
payload=parsed_payload) + return response + + def _request_json( + self, + method: str, + path: str, + *, + params: Optional[Mapping[str, Any]] = None, + payload: Optional[Mapping[str, Any]] = None, + ) -> Any: + response = self._request(method, path, params=params, payload=payload) + try: + data = response.json() + except ValueError as exc: + raise HiCTAPIError( + "Expected JSON response from %s %s" % (method, path), + status_code=response.status_code, + payload=response.text, + ) from exc + self._raise_if_error_payload(data, status_code=response.status_code) + return data + + def _request_text(self, method: str, path: str, payload: Optional[Mapping[str, Any]] = None) -> str: + response = self._request(method, path, payload=payload) + return response.text + + @staticmethod + def _decode_png_bytes(png_bytes: bytes) -> np.ndarray: + with Image.open(BytesIO(png_bytes)) as image: + rgba = image.convert("RGBA") + return np.asarray(rgba, dtype=np.uint8) + + @staticmethod + def _decode_data_url_png(image_data_url: str) -> np.ndarray: + prefix = "data:image/png;base64," + if not image_data_url.startswith(prefix): + raise HiCTAPIError("Unexpected tile image payload format", payload=image_data_url) + encoded = image_data_url[len(prefix) :] + return HiCTJVMClient._decode_png_bytes(base64.b64decode(encoded)) + + def require_open_file_response(self) -> OpenFileResponse: + """Return cached open-file metadata or raise when session is not attached.""" + if self._open_file_response is None: + raise HiCTClientStateError( + "No opened session metadata is cached. Call open_file(...) or attach_session() first." 
+ ) + return self._open_file_response + + # ------------------------- + # Session and file control + # ------------------------- + def version(self) -> Mapping[str, Any]: + """Get HiCT_JVM version payload from ``GET /version``.""" + payload = self._request_json("GET", "/version") + return payload + + def list_files(self) -> Sequence[str]: + payload = self._request_json("POST", "/list_files", payload={}) + return [str(item) for item in payload] + + def list_files_detailed(self) -> Sequence[Mapping[str, Any]]: + payload = self._request_json("POST", "/list_files_detailed", payload={}) + return list(payload) + + def list_coolers(self) -> Sequence[str]: + payload = self._request_json("POST", "/list_coolers", payload={}) + return [str(item) for item in payload] + + def list_fasta_files(self) -> Sequence[str]: + payload = self._request_json("POST", "/list_fasta_files", payload={}) + return [str(item) for item in payload] + + def list_agp_files(self) -> Sequence[str]: + payload = self._request_json("POST", "/list_agp_files", payload={}) + return [str(item) for item in payload] + + def open_file(self, filename: str, fasta_filename: Optional[str] = None) -> OpenFileResponse: + body: Dict[str, Any] = {"filename": filename} + if fasta_filename: + body["fastaFilename"] = fasta_filename + payload = self._request_json("POST", "/open", payload=body) + parsed = OpenFileResponse.from_json(payload) + self._open_file_response = parsed + return parsed + + def open_progress(self) -> Mapping[str, Any]: + return self._request_json("POST", "/open_progress", payload={}) + + def attach_session(self) -> Mapping[str, Any]: + payload = self._request_json("POST", "/attach", payload={}) + open_response_raw = payload.get("openFileResponse") if isinstance(payload, Mapping) else None + if isinstance(open_response_raw, Mapping): + self._open_file_response = OpenFileResponse.from_json(open_response_raw) + return payload + + def close_session(self) -> Mapping[str, Any]: + payload = 
self._request_json("POST", "/close", payload={}) + self._open_file_response = None + return payload + + # ------------------------- + # Secondary source control + # ------------------------- + def get_secondary_status(self) -> SecondarySourceStatus: + payload = self._request_json("POST", "/secondary/status", payload={}) + return SecondarySourceStatus.from_json(payload) + + def open_secondary_source(self, filename: str, allow_mismatch: bool = False) -> Mapping[str, Any]: + return self._request_json( + "POST", + "/secondary/open", + payload={"filename": filename, "allowMismatch": bool(allow_mismatch)}, + ) + + def close_secondary_source(self) -> SecondarySourceStatus: + payload = self._request_json("POST", "/secondary/close", payload={}) + return SecondarySourceStatus.from_json(payload) + + def set_assembly_source(self, assembly_source: str) -> Mapping[str, Any]: + return self._request_json( + "POST", "/secondary/set_assembly_source", payload={"assemblySource": assembly_source} + ) + + # ------------------------- + # Map tile and region fetch + # ------------------------- + def fetch_tile_with_ranges( + self, + row: int, + col: int, + *, + bp_resolution: Optional[int] = None, + level: Optional[int] = None, + version: int = 0, + tile_size: int = 256, + ) -> TileWithRanges: + if bp_resolution is None and level is None: + raise ValueError("Either bp_resolution or level must be specified") + params: Dict[str, Any] = { + "row": int(row), + "col": int(col), + "version": int(version), + "tile_size": int(tile_size), + "format": "JSON_PNG_WITH_RANGES", + } + if bp_resolution is not None: + params["bpResolution"] = int(bp_resolution) + else: + params["level"] = int(level) + payload = self._request_json("GET", "/get_tile", params=params) + image_data_url = str(payload["image"]) + return TileWithRanges( + image=self._decode_data_url_png(image_data_url), + ranges=TileRanges.from_json(dict(payload.get("ranges", {}))), + image_data_url=image_data_url, + ) + + def fetch_tile_png( + 
self, + row: int, + col: int, + *, + bp_resolution: Optional[int] = None, + level: Optional[int] = None, + version: int = 0, + tile_size: int = 256, + ) -> np.ndarray: + if bp_resolution is None and level is None: + raise ValueError("Either bp_resolution or level must be specified") + params: Dict[str, Any] = { + "row": int(row), + "col": int(col), + "version": int(version), + "tile_size": int(tile_size), + "format": "PNG", + } + if bp_resolution is not None: + params["bpResolution"] = int(bp_resolution) + else: + params["level"] = int(level) + response = self._request("GET", "/get_tile", params=params) + return self._decode_png_bytes(response.content) + + def fetch_region_pixels( + self, + *, + start_row_px: int, + start_col_px: int, + rows: int, + cols: int, + bp_resolution: int, + version: int = 0, + ) -> np.ndarray: + params = { + "row": int(start_row_px), + "col": int(start_col_px), + "rows": int(rows), + "cols": int(cols), + "bpResolution": int(bp_resolution), + "version": int(version), + "format": "PNG_BY_PIXELS", + } + response = self._request("GET", "/get_tile", params=params) + return self._decode_png_bytes(response.content) + + def fetch_region_bins( + self, + *, + start_row_bin: int, + start_col_bin: int, + rows: int, + cols: int, + bp_resolution: int, + version: int = 0, + ) -> np.ndarray: + converter = self.unit_converter(bp_resolution) + row_px = converter.convert(start_row_bin, Unit.BINS, Unit.PIXELS) + col_px = converter.convert(start_col_bin, Unit.BINS, Unit.PIXELS) + return self.fetch_region_pixels( + start_row_px=row_px, + start_col_px=col_px, + rows=rows, + cols=cols, + bp_resolution=bp_resolution, + version=version, + ) + + def fetch_region_bp( + self, + *, + start_row_bp: int, + start_col_bp: int, + rows: int, + cols: int, + bp_resolution: int, + version: int = 0, + ) -> np.ndarray: + converter = self.unit_converter(bp_resolution) + row_px = converter.convert(start_row_bp, Unit.BP, Unit.PIXELS) + col_px = converter.convert(start_col_bp, 
Unit.BP, Unit.PIXELS) + return self.fetch_region_pixels( + start_row_px=row_px, + start_col_px=col_px, + rows=rows, + cols=cols, + bp_resolution=bp_resolution, + version=version, + ) + + # ------------------------- + # Unit conversion utilities + # ------------------------- + def unit_converter(self, bp_resolution: int) -> UnitConverter: + return UnitConverter.from_open_file_response(self.require_open_file_response(), bp_resolution) + + def convert_units( + self, + value: int, + *, + from_unit: Union[Unit, str], + to_unit: Union[Unit, str], + bp_resolution: int, + clamp: bool = False, + ) -> int: + return self.unit_converter(bp_resolution).convert(value, from_unit, to_unit, clamp=clamp) + + # ------------------------- + # Visualization options/pipeline + # ------------------------- + def get_visualization_options(self) -> Mapping[str, Any]: + return self._request_json("POST", "/get_visualization_options", payload={}) + + def set_visualization_options(self, options: Mapping[str, Any]) -> Mapping[str, Any]: + return self._request_json("POST", "/set_visualization_options", payload=dict(options)) + + def get_render_pipeline(self) -> Mapping[str, Any]: + return self._request_json("POST", "/render_pipeline/get", payload={}) + + def set_render_pipeline(self, config: Mapping[str, Any]) -> Mapping[str, Any]: + return self._request_json("POST", "/render_pipeline/set", payload=dict(config)) + + def reset_render_pipeline(self) -> Mapping[str, Any]: + return self._request_json("POST", "/render_pipeline/reset", payload={}) + + # ------------------------- + # Scaffolding operations + # ------------------------- + def reverse_selection_range(self, start_bp: int, end_bp: int) -> Mapping[str, Any]: + return self._request_json( + "POST", "/reverse_selection_range", payload={"startBP": int(start_bp), "endBP": int(end_bp)} + ) + + def move_selection_range(self, start_bp: int, end_bp: int, target_start_bp: int) -> Mapping[str, Any]: + return self._request_json( + "POST", + 
"/move_selection_range", + payload={"startBP": int(start_bp), "endBP": int(end_bp), "targetStartBP": int(target_start_bp)}, + ) + + def split_contig_at_bin(self, split_px: int, bp_resolution: int) -> Mapping[str, Any]: + return self._request_json( + "POST", + "/split_contig_at_bin", + payload={"splitPx": int(split_px), "bpResolution": int(bp_resolution)}, + ) + + def group_contigs_into_scaffold(self, start_bp: int, end_bp: int) -> Mapping[str, Any]: + return self._request_json( + "POST", "/group_contigs_into_scaffold", payload={"startBP": int(start_bp), "endBP": int(end_bp)} + ) + + def ungroup_contigs_from_scaffold(self, start_bp: int, end_bp: int) -> Mapping[str, Any]: + return self._request_json( + "POST", "/ungroup_contigs_from_scaffold", payload={"startBP": int(start_bp), "endBP": int(end_bp)} + ) + + def move_selection_to_debris(self, start_bp: int, end_bp: int) -> Mapping[str, Any]: + return self._request_json( + "POST", "/move_selection_to_debris", payload={"startBP": int(start_bp), "endBP": int(end_bp)} + ) + + # ------------------------- + # FASTA / AGP operations + # ------------------------- + def link_fasta(self, fasta_filename: str, allow_mismatch: bool = False) -> Mapping[str, Any]: + return self._request_json( + "POST", + "/link_fasta", + payload={"fastaFilename": fasta_filename, "allowMismatch": bool(allow_mismatch)}, + ) + + def export_fasta_for_assembly(self) -> str: + return self._request_text("POST", "/get_fasta_for_assembly", payload={}) + + def export_fasta_for_selection( + self, from_bp_x: int, from_bp_y: int, to_bp_x: int, to_bp_y: int + ) -> str: + return self._request_text( + "POST", + "/get_fasta_for_selection", + payload={ + "fromBpX": int(from_bp_x), + "fromBpY": int(from_bp_y), + "toBpX": int(to_bp_x), + "toBpY": int(to_bp_y), + }, + ) + + def export_agp_for_assembly(self, default_spacer_length: int = 1000) -> str: + return self._request_text( + "POST", "/get_agp_for_assembly", payload={"defaultSpacerLength": 
int(default_spacer_length)} + ) + + def load_agp(self, agp_filename: str) -> Mapping[str, Any]: + return self._request_json("POST", "/load_agp", payload={"agpFilename": agp_filename}) + + # ------------------------- + # Conversion operations + # ------------------------- + def start_conversion_job( + self, + *, + filename: str, + direction: str = "mcool-to-hict", + overwrite: bool = False, + resolutions: Optional[str] = None, + compression: Optional[int] = None, + compression_algorithm: Optional[str] = None, + chunk_size: Optional[int] = None, + parallelism: Optional[int] = None, + ) -> Mapping[str, Any]: + payload: Dict[str, Any] = { + "filename": filename, + "direction": direction, + "overwrite": bool(overwrite), + } + if resolutions is not None: + payload["resolutions"] = resolutions + if compression is not None: + payload["compression"] = int(compression) + if compression_algorithm is not None: + payload["compressionAlgorithm"] = compression_algorithm + if chunk_size is not None: + payload["chunkSize"] = int(chunk_size) + if parallelism is not None: + payload["parallelism"] = int(parallelism) + return self._request_json("POST", "/convert/jobs", payload=payload) + + def start_batch_conversion_jobs( + self, + *, + files: Sequence[str], + parallel_jobs: int, + parallelism: int, + overwrite: bool = False, + resolutions: Optional[str] = None, + compression: Optional[int] = None, + compression_algorithm: Optional[str] = None, + chunk_size: Optional[int] = None, + ) -> Mapping[str, Any]: + payload: Dict[str, Any] = { + "files": list(files), + "parallelJobs": int(parallel_jobs), + "parallelism": int(parallelism), + "overwrite": bool(overwrite), + } + if resolutions is not None: + payload["resolutions"] = resolutions + if compression is not None: + payload["compression"] = int(compression) + if compression_algorithm is not None: + payload["compressionAlgorithm"] = compression_algorithm + if chunk_size is not None: + payload["chunkSize"] = int(chunk_size) + return 
self._request_json("POST", "/convert/jobs/batch", payload=payload) + + def list_conversion_jobs(self) -> Sequence[Mapping[str, Any]]: + return list(self._request_json("POST", "/convert/jobs/list", payload={})) + + def get_conversion_job(self, job_id: str) -> Mapping[str, Any]: + return self._request_json("GET", "/convert/jobs/%s" % job_id) + + def stop_conversion_job(self, job_id: str) -> Mapping[str, Any]: + return self._request_json("POST", "/convert/jobs/%s/stop" % job_id, payload={}) + + def wait_for_conversion_job( + self, + job_id: str, + *, + poll_interval_sec: float = 1.0, + timeout_sec: float = 600.0, + ) -> Mapping[str, Any]: + deadline = time.monotonic() + timeout_sec + while True: + payload = self.get_conversion_job(job_id) + status = str(payload.get("status", "")) + if status in {"finished", "failed", "cancelled"}: + return payload + if time.monotonic() >= deadline: + raise TimeoutError("Timed out waiting for conversion job %s" % job_id) + time.sleep(poll_interval_sec) + + # ------------------------- + # Misc diagnostics + # ------------------------- + def worker_diagnostics(self) -> Mapping[str, Any]: + return self._request_json("POST", "/diagnostics/workers", payload={}) diff --git a/hict_jvm_api/dataloader.py b/hict_jvm_api/dataloader.py new file mode 100644 index 0000000..7050006 --- /dev/null +++ b/hict_jvm_api/dataloader.py @@ -0,0 +1,119 @@ +"""PyTorch-friendly datasets for HiCT_JVM region sampling.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, List, Optional, Sequence, Tuple + +import numpy as np + +from .client import HiCTJVMClient + + +@dataclass(frozen=True) +class RegionRequest: + """A single region request in visible pixel coordinates.""" + + row_px: int + col_px: int + + +class HiCTRegionDataset: + """Random-access dataset that fetches regions from a running HiCT_JVM session. + + Parameters + ---------- + client: + Active :class:`hict_jvm_api.client.HiCTJVMClient` instance. 
+ bp_resolution: + Resolution used for tile/region extraction. + window_px: + Square crop size in pixels. + num_samples: + Number of samples generated when ``regions`` is not provided. + regions: + Explicit pixel coordinates to fetch. If set, ``num_samples`` is ignored. + seed: + RNG seed used for synthetic sampling. + return_torch: + Return ``torch.Tensor`` instead of ``numpy.ndarray``. + channel_first: + If ``return_torch=True``, output layout is ``(C,H,W)`` when enabled, + otherwise ``(H,W,C)``. + transform: + Optional callable applied to each sample. + """ + + def __init__( + self, + *, + client: HiCTJVMClient, + bp_resolution: int, + window_px: int, + num_samples: int = 1024, + regions: Optional[Sequence[Tuple[int, int]]] = None, + seed: int = 0, + return_torch: bool = False, + channel_first: bool = True, + transform: Optional[Callable[[np.ndarray], object]] = None, + ): + if window_px <= 0: + raise ValueError("window_px must be positive") + if num_samples <= 0 and not regions: + raise ValueError("num_samples must be positive when regions are not provided") + + self.client = client + self.bp_resolution = int(bp_resolution) + self.window_px = int(window_px) + self.return_torch = bool(return_torch) + self.channel_first = bool(channel_first) + self.transform = transform + + converter = self.client.unit_converter(self.bp_resolution) + self._visible_size = int(converter.total_visible_pixels) + max_coord = max(0, self._visible_size - self.window_px) + + if regions is not None: + self._regions = [RegionRequest(int(r), int(c)) for r, c in regions] + else: + rng = np.random.default_rng(seed) + self._regions = [ + RegionRequest( + row_px=int(rng.integers(0, max_coord + 1)), + col_px=int(rng.integers(0, max_coord + 1)), + ) + for _ in range(int(num_samples)) + ] + + def __len__(self) -> int: + return len(self._regions) + + def __getitem__(self, index: int): + region = self._regions[index] + image = self.client.fetch_region_pixels( + start_row_px=region.row_px, + 
start_col_px=region.col_px, + rows=self.window_px, + cols=self.window_px, + bp_resolution=self.bp_resolution, + ) + + if self.transform is not None: + image = self.transform(image) + return image + + if not self.return_torch: + return image + + try: + import torch + except ImportError as exc: + raise RuntimeError( + "return_torch=True requires torch to be installed. " + "Install optional dependency: pip install 'hict-jvm-api[torch]'" + ) from exc + + tensor = torch.from_numpy(np.ascontiguousarray(image)) + if self.channel_first: + tensor = tensor.permute(2, 0, 1) + return tensor.float() / 255.0 diff --git a/hict_jvm_api/exceptions.py b/hict_jvm_api/exceptions.py new file mode 100644 index 0000000..11de443 --- /dev/null +++ b/hict_jvm_api/exceptions.py @@ -0,0 +1,20 @@ +"""Exceptions for the HiCT JVM API client.""" + +from typing import Optional + + +class HiCTAPIError(RuntimeError): + """Raised when the HiCT_JVM API returns an HTTP or API-level error.""" + + def __init__(self, message: str, status_code: Optional[int] = None, payload: Optional[object] = None): + super().__init__(message) + self.status_code = status_code + self.payload = payload + + +class HiCTClientStateError(RuntimeError): + """Raised when a client operation requires an opened session but none is available.""" + + +class HiddenCoordinateError(ValueError): + """Raised when converting to PIXELS for coordinates located in hidden contigs.""" diff --git a/hict_jvm_api/models.py b/hict_jvm_api/models.py new file mode 100644 index 0000000..44fcb93 --- /dev/null +++ b/hict_jvm_api/models.py @@ -0,0 +1,175 @@ +"""Typed models used by :mod:`hict_jvm_api`. + +These dataclasses mirror the payloads returned by HiCT_JVM endpoints. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Mapping, Optional, Sequence + +import numpy as np + + +class Unit(str, Enum): + """Length units supported by HiCT queries and conversions.""" + + BP = "BP" + BINS = "BINS" + PIXELS = "PIXELS" + + +@dataclass(frozen=True) +class ContigDescriptor: + """Contig metadata for the current assembly state.""" + + contig_id: int + contig_name: str + contig_original_name: str + contig_source_name: str + contig_offset_in_source: int + contig_direction: int + contig_length_bp: int + contig_length_bins: Dict[int, int] + contig_presence_at_resolution: Dict[int, int] + + @staticmethod + def from_json(payload: Mapping[str, object]) -> "ContigDescriptor": + length_bins_raw = payload.get("contigLengthBins", {}) + presence_raw = payload.get("contigPresenceAtResolution", {}) + return ContigDescriptor( + contig_id=int(payload.get("contigId", 0)), + contig_name=str(payload.get("contigName", "")), + contig_original_name=str(payload.get("contigOriginalName", "")), + contig_source_name=str(payload.get("contigSourceName", "")), + contig_offset_in_source=int(payload.get("contigOffsetInSource", 0)), + contig_direction=int(payload.get("contigDirection", 0)), + contig_length_bp=int(payload.get("contigLengthBp", 0)), + contig_length_bins={int(k): int(v) for k, v in dict(length_bins_raw).items()}, + contig_presence_at_resolution={int(k): int(v) for k, v in dict(presence_raw).items()}, + ) + + +@dataclass(frozen=True) +class ScaffoldDescriptor: + """Scaffold metadata for the current assembly state.""" + + scaffold_id: int + scaffold_name: str + scaffold_original_name: str + spacer_length: int + scaffold_borders_bp: Optional[Dict[str, int]] + + @staticmethod + def from_json(payload: Mapping[str, object]) -> "ScaffoldDescriptor": + borders = payload.get("scaffoldBordersBP") + normalized_borders: Optional[Dict[str, int]] + if isinstance(borders, Mapping): + 
normalized_borders = {str(k): int(v) for k, v in borders.items()} + else: + normalized_borders = None + return ScaffoldDescriptor( + scaffold_id=int(payload.get("scaffoldId", 0)), + scaffold_name=str(payload.get("scaffoldName", "")), + scaffold_original_name=str(payload.get("scaffoldOriginalName", "")), + spacer_length=int(payload.get("spacerLength", 0)), + scaffold_borders_bp=normalized_borders, + ) + + +@dataclass(frozen=True) +class AssemblyInfo: + """Assembly descriptors returned by HiCT_JVM.""" + + contigs: List[ContigDescriptor] + scaffolds: List[ScaffoldDescriptor] + + @staticmethod + def from_json(payload: Mapping[str, object]) -> "AssemblyInfo": + contigs_raw = payload.get("contigDescriptors", []) + scaffolds_raw = payload.get("scaffoldDescriptors", []) + return AssemblyInfo( + contigs=[ContigDescriptor.from_json(item) for item in list(contigs_raw)], + scaffolds=[ScaffoldDescriptor.from_json(item) for item in list(scaffolds_raw)], + ) + + +@dataclass(frozen=True) +class OpenFileResponse: + """Response returned by ``/open`` and embedded in ``/attach``.""" + + status: str + dtype: str + resolutions: List[int] + pixel_resolutions: List[float] + tile_size: int + assembly_info: AssemblyInfo + matrix_sizes_bins: List[int] + + @staticmethod + def from_json(payload: Mapping[str, object]) -> "OpenFileResponse": + return OpenFileResponse( + status=str(payload.get("status", "")), + dtype=str(payload.get("dtype", "")), + resolutions=[int(v) for v in list(payload.get("resolutions", []))], + pixel_resolutions=[float(v) for v in list(payload.get("pixelResolutions", []))], + tile_size=int(payload.get("tileSize", 256)), + assembly_info=AssemblyInfo.from_json(dict(payload.get("assemblyInfo", {}))), + matrix_sizes_bins=[int(v) for v in list(payload.get("matrixSizesBins", []))], + ) + + def matrix_size_for_resolution(self, bp_resolution: int) -> int: + """Return matrix size in bins for a given bp resolution.""" + if bp_resolution not in self.resolutions: + raise 
KeyError(f"Resolution {bp_resolution} is not present in opened dataset") + idx = self.resolutions.index(bp_resolution) + return self.matrix_sizes_bins[idx] + + +@dataclass(frozen=True) +class TileRanges: + """Per-resolution signal ranges returned together with tile image.""" + + lower_bounds: Dict[int, float] + upper_bounds: Dict[int, float] + + @staticmethod + def from_json(payload: Mapping[str, object]) -> "TileRanges": + lower = payload.get("lowerBounds", {}) + upper = payload.get("upperBounds", {}) + return TileRanges( + lower_bounds={int(k): float(v) for k, v in dict(lower).items()}, + upper_bounds={int(k): float(v) for k, v in dict(upper).items()}, + ) + + +@dataclass(frozen=True) +class TileWithRanges: + """Tile image together with signal ranges metadata.""" + + image: np.ndarray + ranges: TileRanges + image_data_url: str + + +@dataclass(frozen=True) +class SecondarySourceStatus: + """Secondary source status returned by ``/secondary/status`` and related endpoints.""" + + attached: bool + filename: str + assembly_source: str + compatibility: Optional[Mapping[str, object]] + warnings: Sequence[str] + + @staticmethod + def from_json(payload: Mapping[str, object]) -> "SecondarySourceStatus": + warnings = payload.get("warnings", []) + return SecondarySourceStatus( + attached=bool(payload.get("attached", False)), + filename=str(payload.get("filename", "")), + assembly_source=str(payload.get("assemblySource", "PRIMARY")), + compatibility=payload.get("compatibility") if isinstance(payload.get("compatibility"), Mapping) else None, + warnings=[str(item) for item in list(warnings)], + ) diff --git a/hict_jvm_api/py.typed b/hict_jvm_api/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/hict_jvm_api/units.py b/hict_jvm_api/units.py new file mode 100644 index 0000000..0d04ebb --- /dev/null +++ b/hict_jvm_api/units.py @@ -0,0 +1,226 @@ +"""Coordinate conversion utilities for HiCT units. 
+ +The converter uses assembly descriptors returned by HiCT_JVM (`/open` or `/attach`) +to translate between: + +- base pairs (BP) +- bins at a specific bp resolution (BINS) +- visible pixels (PIXELS), i.e. bins belonging only to contigs marked as SHOWN + at this resolution. +""" + +from __future__ import annotations + +from bisect import bisect_right +from dataclasses import dataclass +from typing import Dict, Iterable, List, Sequence, Tuple, Union + +from .exceptions import HiddenCoordinateError +from .models import AssemblyInfo, OpenFileResponse, Unit + + +@dataclass(frozen=True) +class _Span: + bp_start: int + bp_end: int + bin_start: int + bin_end: int + visible_start: int + visible_end: int + shown: bool + + +class UnitConverter: + """Fast converter between BP/BINS/PIXELS for a given assembly and resolution. + + Parameters + ---------- + assembly_info: + Assembly information from :class:`hict_jvm_api.models.OpenFileResponse`. + bp_resolution: + Resolution in bp used for BINS and PIXELS conversion. + """ + + def __init__(self, assembly_info: AssemblyInfo, bp_resolution: int): + if bp_resolution <= 0: + raise ValueError("bp_resolution must be positive") + self.bp_resolution = int(bp_resolution) + self._spans: List[_Span] = [] + self._bp_ends: List[int] = [] + self._bin_ends: List[int] = [] + self._shown_spans: List[_Span] = [] + self._visible_ends: List[int] = [] + + bp_cursor = 0 + bin_cursor = 0 + visible_cursor = 0 + + for contig in assembly_info.contigs: + bins = contig.contig_length_bins.get(self.bp_resolution) + if bins is None: + # Fall back to ceil(length_bp / bp_resolution) for robustness. 
+ bins = max(1, (int(contig.contig_length_bp) + self.bp_resolution - 1) // self.bp_resolution) + bins = int(bins) + if bins <= 0: + continue + bp_len = int(contig.contig_length_bp) + shown = int(contig.contig_presence_at_resolution.get(self.bp_resolution, 1)) == 1 + + span = _Span( + bp_start=bp_cursor, + bp_end=bp_cursor + bp_len, + bin_start=bin_cursor, + bin_end=bin_cursor + bins, + visible_start=visible_cursor, + visible_end=visible_cursor + (bins if shown else 0), + shown=shown, + ) + self._spans.append(span) + self._bp_ends.append(span.bp_end) + self._bin_ends.append(span.bin_end) + if shown: + self._shown_spans.append(span) + self._visible_ends.append(span.visible_end) + + bp_cursor = span.bp_end + bin_cursor = span.bin_end + visible_cursor = span.visible_end + + self.total_bp = bp_cursor + self.total_bins = bin_cursor + self.total_visible_pixels = visible_cursor + + @classmethod + def from_open_file_response(cls, response: OpenFileResponse, bp_resolution: int) -> "UnitConverter": + """Build a converter from an opened-session response.""" + return cls(response.assembly_info, bp_resolution) + + @staticmethod + def _normalize_unit(unit: Union[Unit, str]) -> Unit: + if isinstance(unit, Unit): + return unit + token = str(unit).strip().upper() + if token in {"BP", "BASE_PAIRS", "BASEPAIR", "BASEPAIRS"}: + return Unit.BP + if token in {"BIN", "BINS"}: + return Unit.BINS + if token in {"PX", "PIXEL", "PIXELS"}: + return Unit.PIXELS + raise ValueError("Unsupported unit: %s" % unit) + + @staticmethod + def _guard_range(value: int, size: int, unit: str, clamp: bool) -> int: + if size <= 0: + raise ValueError("Cannot convert with empty assembly") + if clamp: + return min(max(int(value), 0), size - 1) + ivalue = int(value) + if ivalue < 0 or ivalue >= size: + raise ValueError("%s coordinate %d is out of range [0, %d)" % (unit, ivalue, size)) + return ivalue + + def _span_by_bp(self, bp: int) -> _Span: + idx = bisect_right(self._bp_ends, bp) + if idx >= len(self._spans): 
+ raise ValueError("BP coordinate %d is outside assembly" % bp) + return self._spans[idx] + + def _span_by_bin(self, bin_coord: int) -> _Span: + idx = bisect_right(self._bin_ends, bin_coord) + if idx >= len(self._spans): + raise ValueError("BIN coordinate %d is outside assembly" % bin_coord) + return self._spans[idx] + + def _shown_span_by_pixel(self, pixel: int) -> _Span: + idx = bisect_right(self._visible_ends, pixel) + if idx >= len(self._shown_spans): + raise ValueError("PIXEL coordinate %d is outside visible range" % pixel) + return self._shown_spans[idx] + + def bp_to_bins(self, bp: int, clamp: bool = False) -> int: + """Convert BP coordinate to BINS coordinate.""" + bp = self._guard_range(bp, self.total_bp, "BP", clamp) + span = self._span_by_bp(bp) + offset_bp = bp - span.bp_start + candidate = span.bin_start + (offset_bp // self.bp_resolution) + return min(candidate, span.bin_end - 1) + + def bins_to_bp(self, bins: int, clamp: bool = False) -> int: + """Convert BINS coordinate to BP coordinate.""" + bins = self._guard_range(bins, self.total_bins, "BINS", clamp) + span = self._span_by_bin(bins) + offset_bin = bins - span.bin_start + candidate = span.bp_start + offset_bin * self.bp_resolution + return min(candidate, span.bp_end - 1) + + def bins_to_pixels(self, bins: int, clamp: bool = False) -> int: + """Convert BINS coordinate to PIXELS coordinate. + + Raises + ------ + HiddenCoordinateError + If `bins` belongs to a hidden contig at this resolution. 
+ """ + bins = self._guard_range(bins, self.total_bins, "BINS", clamp) + span = self._span_by_bin(bins) + if not span.shown: + raise HiddenCoordinateError( + "BIN coordinate %d belongs to hidden contig at %dbp resolution" % (bins, self.bp_resolution) + ) + return span.visible_start + (bins - span.bin_start) + + def pixels_to_bins(self, pixels: int, clamp: bool = False) -> int: + """Convert PIXELS coordinate to BINS coordinate.""" + pixels = self._guard_range(pixels, self.total_visible_pixels, "PIXELS", clamp) + span = self._shown_span_by_pixel(pixels) + return span.bin_start + (pixels - span.visible_start) + + def bp_to_pixels(self, bp: int, clamp: bool = False) -> int: + """Convert BP coordinate to PIXELS coordinate.""" + return self.bins_to_pixels(self.bp_to_bins(bp, clamp=clamp), clamp=clamp) + + def pixels_to_bp(self, pixels: int, clamp: bool = False) -> int: + """Convert PIXELS coordinate to BP coordinate.""" + return self.bins_to_bp(self.pixels_to_bins(pixels, clamp=clamp), clamp=clamp) + + def convert(self, value: int, from_unit: Union[Unit, str], to_unit: Union[Unit, str], clamp: bool = False) -> int: + """Convert a single coordinate between units.""" + src = self._normalize_unit(from_unit) + dst = self._normalize_unit(to_unit) + if src == dst: + return int(value) + + if src == Unit.BP: + if dst == Unit.BINS: + return self.bp_to_bins(value, clamp=clamp) + return self.bp_to_pixels(value, clamp=clamp) + + if src == Unit.BINS: + if dst == Unit.BP: + return self.bins_to_bp(value, clamp=clamp) + return self.bins_to_pixels(value, clamp=clamp) + + # src == PIXELS + if dst == Unit.BINS: + return self.pixels_to_bins(value, clamp=clamp) + return self.pixels_to_bp(value, clamp=clamp) + + def convert_interval( + self, + start: int, + end_exclusive: int, + from_unit: Union[Unit, str], + to_unit: Union[Unit, str], + clamp: bool = False, + ) -> Tuple[int, int]: + """Convert a half-open interval ``[start, end_exclusive)`` between units.""" + if end_exclusive <= start: + 
raise ValueError("end_exclusive must be larger than start") + dst_start = self.convert(start, from_unit, to_unit, clamp=clamp) + dst_end_inclusive = self.convert(end_exclusive - 1, from_unit, to_unit, clamp=clamp) + return dst_start, dst_end_inclusive + 1 + + @property + def spans(self) -> Sequence[_Span]: + """Read-only sequence of conversion spans (mainly for debugging/tests).""" + return tuple(self._spans) diff --git a/notebooks/jvm_api_pytorch_dataloader.ipynb b/notebooks/jvm_api_pytorch_dataloader.ipynb new file mode 100644 index 0000000..c856487 --- /dev/null +++ b/notebooks/jvm_api_pytorch_dataloader.ipynb @@ -0,0 +1,75 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HiCT JVM API + PyTorch DataLoader\\n", + "\\n", + "This notebook shows how to sample matrix regions for model training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\\n", + "from torch.utils.data import DataLoader\\n", + "from hict_jvm_api import HiCTJVMClient, HiCTRegionDataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "client = HiCTJVMClient(\"http://localhost:5001\")\\n", + "opened = client.open_file(\"build/quad/combined_ind2_4DN.hict.hdf5\")\\n", + "resolution = opened.resolutions[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = HiCTRegionDataset(\\n", + " client=client,\\n", + " bp_resolution=resolution,\\n", + " window_px=128,\\n", + " num_samples=32,\\n", + " return_torch=True,\\n", + " channel_first=True,\\n", + ")\\n", + "loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch = next(iter(loader))\\n", + "batch.shape" + ] + } + ], + "metadata": { + "kernelspec": { + 
"display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/jvm_api_quickstart.ipynb b/notebooks/jvm_api_quickstart.ipynb new file mode 100644 index 0000000..d43e77b --- /dev/null +++ b/notebooks/jvm_api_quickstart.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HiCT JVM API Quickstart\\n", + "\\n", + "This notebook demonstrates basic usage of `hict_jvm_api` with a running HiCT_JVM server." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from hict_jvm_api import HiCTJVMClient, Unit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "BASE_URL = \"http://localhost:5001\"\\n", + "DATASET = \"build/quad/combined_ind2_4DN.hict.hdf5\"\\n", + "\\n", + "client = HiCTJVMClient(BASE_URL)\\n", + "opened = client.open_file(DATASET)\\n", + "opened" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "resolution = opened.resolutions[0] # coarse\\n", + "tile = client.fetch_region_pixels(\\n", + " start_row_px=0,\\n", + " start_col_px=0,\\n", + " rows=256,\\n", + " cols=256,\\n", + " bp_resolution=resolution,\\n", + ")\\n", + "tile.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Unit conversion\\n", + "bp_value = 1_000_000\\n", + "px_value = client.convert_units(\\n", + " bp_value,\\n", + " from_unit=Unit.BP,\\n", + " to_unit=Unit.PIXELS,\\n", + " bp_resolution=resolution,\\n", + ")\\n", + "bp_value, px_value" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3" + } + }, + "nbformat": 
4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..aa059c9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=68", "wheel>=0.41"] +build-backend = "setuptools.build_meta" + +[project] +name = "hict-jvm-api" +version = "0.1.0" +description = "Python client for controlling HiCT_JVM and fetching Hi-C map regions" +readme = "README.md" +requires-python = ">=3.9" +license = { text = "MIT" } +authors = [ + { name = "CT Lab ITMO University team" } +] +dependencies = [ + "requests>=2.28", + "numpy>=1.23", + "Pillow>=9.5", +] + +[project.optional-dependencies] +torch = [ + "torch>=2.0" +] +dev = [ + "pytest>=7", + "mypy>=1.8", + "hypothesis>=6.61", +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["hict*", "hict_jvm_api*"] + +[tool.setuptools.package-data] +"*" = ["py.typed"] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..0f71cca --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + integration: optional tests that require a running HiCT_JVM server and data files diff --git a/requirements.txt b/requirements.txt index 1100e0b..c06e45c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,8 @@ recordclass>=0.17.2 frozendict>=2.3.4 scipy>=1.8.1 numpy>=1.23.2 +requests>=2.28 +Pillow>=9.5 cachetools>=5.2.0 bio>=1.3.9 readerwriterlock>=1.0.9 diff --git a/run_jvm_api_optional_data_tests.sh b/run_jvm_api_optional_data_tests.sh new file mode 100755 index 0000000..252c5e3 --- /dev/null +++ b/run_jvm_api_optional_data_tests.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +export PYTHONPATH="$SCRIPT_DIR:${PYTHONPATH:-}" + +# Required env vars for integration tests: +# HICT_JVM_API_BASE_URL +# Optional env vars: +# HICT_DATASET_FILE +# HICT_FASTA_FILE +# HICT_AGP_FILE +# HICT_JVM_API_ALLOW_MUTATION=true +pytest -vv 
"$SCRIPT_DIR/tests_jvm_api/test_optional_integration.py" -m integration diff --git a/run_jvm_api_tests.sh b/run_jvm_api_tests.sh new file mode 100755 index 0000000..3ce2e98 --- /dev/null +++ b/run_jvm_api_tests.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +export PYTHONPATH="$SCRIPT_DIR:${PYTHONPATH:-}" + +pytest -vv "$SCRIPT_DIR/tests_jvm_api" -m "not integration" diff --git a/run_tests.sh b/run_tests.sh index 80e2b35..aa2c08a 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -27,4 +27,5 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) HICT_DIR="$SCRIPT_DIR/../HiCT_Library/" export PYTHONPATH="$PYTHONPATH:$HICT_DIR" TESTS_DIR="$HICT_DIR/tests/" -pytest -vvv -x -n16 $TESTS_DIR +JVM_API_TESTS_DIR="$HICT_DIR/tests_jvm_api/" +pytest -vvv -x -n16 "$TESTS_DIR" "$JVM_API_TESTS_DIR" -m "not integration" diff --git a/tests_jvm_api/test_client.py b/tests_jvm_api/test_client.py new file mode 100644 index 0000000..7f876db --- /dev/null +++ b/tests_jvm_api/test_client.py @@ -0,0 +1,165 @@ +import base64 +from io import BytesIO +from typing import Any, Dict, List, Mapping +from urllib.parse import parse_qs, urlparse + +import numpy as np +from PIL import Image + +from hict_jvm_api.client import HiCTJVMClient +from hict_jvm_api.exceptions import HiCTAPIError + + +class _FakeResponse: + def __init__(self, status_code: int, *, json_payload: Any = None, text: str = "", content: bytes = b""): + self.status_code = status_code + self._json_payload = json_payload + self.text = text if text else ("" if json_payload is not None else content.decode("utf-8", errors="ignore")) + self.content = content + + def json(self) -> Any: + if self._json_payload is None: + raise ValueError("No JSON payload") + return self._json_payload + + +class _FakeSession: + def __init__(self, handler): + self._handler = handler + + def request(self, method: str, url: str, *, params: 
Mapping[str, Any], json: Mapping[str, Any], timeout: float): + parsed = urlparse(url) + query = {k: v[0] for k, v in parse_qs(parsed.query).items()} + merged_params = dict(query) + merged_params.update({k: str(v) for k, v in dict(params).items()}) + return self._handler(method.upper(), parsed.path, merged_params, dict(json)) + + +def _png_bytes(rgba: tuple[int, int, int, int], width: int = 2, height: int = 2) -> bytes: + arr = np.zeros((height, width, 4), dtype=np.uint8) + arr[:, :] = np.array(rgba, dtype=np.uint8) + image = Image.fromarray(arr, mode="RGBA") + bio = BytesIO() + image.save(bio, format="PNG") + return bio.getvalue() + + +def _data_url_from_png_bytes(content: bytes) -> str: + return "data:image/png;base64," + base64.b64encode(content).decode("ascii") + + +def test_open_and_fetch_tile_with_ranges_decodes_image() -> None: + png = _png_bytes((10, 20, 30, 255)) + data_url = _data_url_from_png_bytes(png) + + def handler(method: str, path: str, params: Mapping[str, str], payload: Mapping[str, Any]) -> _FakeResponse: + if method == "POST" and path == "/open": + assert payload["filename"] == "build/quad/combined_ind2_4DN.hict.hdf5" + return _FakeResponse( + 200, + json_payload={ + "status": "Opened", + "dtype": "uint8", + "resolutions": [50000, 1000], + "pixelResolutions": [50.0, 1.0], + "tileSize": 256, + "assemblyInfo": { + "contigDescriptors": [ + { + "contigId": 1, + "contigName": "ctg1", + "contigOriginalName": "ctg1", + "contigSourceName": "ctg1", + "contigOffsetInSource": 0, + "contigDirection": 0, + "contigLengthBp": 200, + "contigLengthBins": {"50000": 1, "1000": 1}, + "contigPresenceAtResolution": {"50000": 1, "1000": 1}, + } + ], + "scaffoldDescriptors": [], + }, + "matrixSizesBins": [1, 1], + }, + ) + + if method == "GET" and path == "/get_tile": + assert params["format"] == "JSON_PNG_WITH_RANGES" + return _FakeResponse( + 200, + json_payload={ + "image": data_url, + "ranges": { + "lowerBounds": {"1": 0.0}, + "upperBounds": {"1": 1.0}, + }, + }, + 
) + + raise AssertionError(f"Unexpected request: {method} {path}") + + client = HiCTJVMClient("http://test", session=_FakeSession(handler)) + + opened = client.open_file("build/quad/combined_ind2_4DN.hict.hdf5") + assert opened.resolutions == [50000, 1000] + + tile = client.fetch_tile_with_ranges(row=0, col=0, bp_resolution=50000) + assert tile.image.shape == (2, 2, 4) + assert int(tile.image[0, 0, 0]) == 10 + assert tile.ranges.upper_bounds[1] == 1.0 + + + +def test_fetch_region_pixels_uses_png_by_pixels() -> None: + png = _png_bytes((3, 4, 5, 255), width=4, height=3) + + def handler(method: str, path: str, params: Mapping[str, str], payload: Mapping[str, Any]) -> _FakeResponse: + if method == "GET" and path == "/get_tile": + assert params["format"] == "PNG_BY_PIXELS" + assert params["rows"] == "3" + assert params["cols"] == "4" + return _FakeResponse(200, content=png) + raise AssertionError(f"Unexpected request: {method} {path}") + + client = HiCTJVMClient("http://test", session=_FakeSession(handler)) + + image = client.fetch_region_pixels( + start_row_px=10, + start_col_px=20, + rows=3, + cols=4, + bp_resolution=50000, + ) + assert image.shape == (3, 4, 4) + assert int(image[0, 0, 1]) == 4 + + + +def test_start_conversion_job_sends_overwrite() -> None: + seen_payloads: List[Dict[str, Any]] = [] + + def handler(method: str, path: str, params: Mapping[str, str], payload: Mapping[str, Any]) -> _FakeResponse: + if method == "POST" and path == "/convert/jobs": + seen_payloads.append(dict(payload)) + return _FakeResponse(200, json_payload={"status": "submitted", "jobId": "job-1"}) + raise AssertionError(f"Unexpected request: {method} {path}") + + client = HiCTJVMClient("http://test", session=_FakeSession(handler)) + + response = client.start_conversion_job(filename="x.mcool", overwrite=True) + assert response["jobId"] == "job-1" + assert seen_payloads[-1]["overwrite"] is True + + + +def test_api_error_payload_raises_exception() -> None: + def handler(method: str, 
path: str, params: Mapping[str, str], payload: Mapping[str, Any]) -> _FakeResponse: + return _FakeResponse(200, json_payload={"error": "Something failed"}) + + client = HiCTJVMClient("http://test", session=_FakeSession(handler)) + + try: + client.list_files() + assert False, "Expected HiCTAPIError" + except HiCTAPIError as err: + assert "Something failed" in str(err) diff --git a/tests_jvm_api/test_dataloader.py b/tests_jvm_api/test_dataloader.py new file mode 100644 index 0000000..0afcb0a --- /dev/null +++ b/tests_jvm_api/test_dataloader.py @@ -0,0 +1,53 @@ +import numpy as np + +from hict_jvm_api.dataloader import HiCTRegionDataset + + +class _DummyConverter: + def __init__(self, total_visible_pixels: int = 128): + self.total_visible_pixels = total_visible_pixels + + +class _DummyClient: + def __init__(self) -> None: + self.requests = [] + + def unit_converter(self, bp_resolution: int): + assert bp_resolution == 50000 + return _DummyConverter(total_visible_pixels=128) + + def fetch_region_pixels(self, *, start_row_px: int, start_col_px: int, rows: int, cols: int, bp_resolution: int): + self.requests.append((start_row_px, start_col_px, rows, cols, bp_resolution)) + arr = np.zeros((rows, cols, 4), dtype=np.uint8) + arr[..., 0] = start_row_px % 255 + arr[..., 1] = start_col_px % 255 + arr[..., 2] = 100 + arr[..., 3] = 255 + return arr + + +def test_dataset_random_sampling_numpy_output() -> None: + client = _DummyClient() + ds = HiCTRegionDataset(client=client, bp_resolution=50000, window_px=16, num_samples=5, seed=123) + + assert len(ds) == 5 + sample = ds[0] + assert sample.shape == (16, 16, 4) + assert sample.dtype == np.uint8 + assert len(client.requests) == 1 + + +def test_dataset_explicit_regions_and_transform() -> None: + client = _DummyClient() + ds = HiCTRegionDataset( + client=client, + bp_resolution=50000, + window_px=8, + regions=[(10, 20), (30, 40)], + transform=lambda image: int(image[0, 0, 0]) + int(image[0, 0, 1]), + ) + + assert len(ds) == 2 + value 
= ds[1] + assert value == 70 + assert client.requests[-1][:2] == (30, 40) diff --git a/tests_jvm_api/test_optional_integration.py b/tests_jvm_api/test_optional_integration.py new file mode 100644 index 0000000..ad09095 --- /dev/null +++ b/tests_jvm_api/test_optional_integration.py @@ -0,0 +1,93 @@ +import os + +import pytest + +from hict_jvm_api.client import HiCTJVMClient +from hict_jvm_api.models import Unit + +BASE_URL = os.getenv("HICT_JVM_API_BASE_URL", "").strip() +DATASET_FILE = os.getenv("HICT_DATASET_FILE", "").strip() +FASTA_FILE = os.getenv("HICT_FASTA_FILE", "").strip() +AGP_FILE = os.getenv("HICT_AGP_FILE", "").strip() +ALLOW_MUTATION = os.getenv("HICT_JVM_API_ALLOW_MUTATION", "false").strip().lower() in {"1", "true", "yes"} + +pytestmark = pytest.mark.integration + + +@pytest.fixture(scope="module") +def client(): + if not BASE_URL: + pytest.skip("Set HICT_JVM_API_BASE_URL to run optional integration tests") + c = HiCTJVMClient(BASE_URL) + if DATASET_FILE: + c.open_file(DATASET_FILE) + yield c + try: + c.close_session() + except Exception: + pass + c.close() + + +def test_open_and_fetch_region(client: HiCTJVMClient) -> None: + if not DATASET_FILE: + pytest.skip("Set HICT_DATASET_FILE to run dataset integration tests") + opened = client.require_open_file_response() + coarse_resolution = opened.resolutions[0] + + image = client.fetch_region_pixels( + start_row_px=0, + start_col_px=0, + rows=32, + cols=32, + bp_resolution=coarse_resolution, + ) + assert image.shape == (32, 32, 4) + + converter = client.unit_converter(coarse_resolution) + assert converter.convert(0, Unit.BP, Unit.BINS) == 0 + + +def test_scaffolding_reverse_twice(client: HiCTJVMClient) -> None: + if not DATASET_FILE: + pytest.skip("Set HICT_DATASET_FILE to run dataset integration tests") + if not ALLOW_MUTATION: + pytest.skip("Set HICT_JVM_API_ALLOW_MUTATION=true to run mutating scaffolding tests") + + opened = client.require_open_file_response() + coarse_resolution = 
opened.resolutions[0] + total_bp = client.unit_converter(coarse_resolution).total_bp + if total_bp < 4: + pytest.skip("Dataset is too small for reverse operation test") + + start_bp = 0 + end_bp = min(total_bp - 1, max(2, total_bp // 100)) + + first = client.reverse_selection_range(start_bp, end_bp) + second = client.reverse_selection_range(start_bp, end_bp) + assert first["version"] < second["version"] + + +def test_fasta_link_and_export(client: HiCTJVMClient) -> None: + if not DATASET_FILE: + pytest.skip("Set HICT_DATASET_FILE to run dataset integration tests") + if not FASTA_FILE: + pytest.skip("Set HICT_FASTA_FILE to test FASTA integration") + + link_result = client.link_fasta(FASTA_FILE, allow_mismatch=True) + assert "linked" in link_result or "warnings" in link_result + + fasta_text = client.export_fasta_for_selection(0, 0, 1_000, 1_000) + assert isinstance(fasta_text, str) + + +def test_agp_load_and_export(client: HiCTJVMClient) -> None: + if not DATASET_FILE: + pytest.skip("Set HICT_DATASET_FILE to run dataset integration tests") + if not AGP_FILE: + pytest.skip("Set HICT_AGP_FILE to test AGP integration") + + _ = client.load_agp(AGP_FILE) + agp_text = client.export_agp_for_assembly(default_spacer_length=1000) + assert isinstance(agp_text, str) + assert len(agp_text) > 0 diff --git a/tests_jvm_api/test_units.py b/tests_jvm_api/test_units.py new file mode 100644 index 0000000..2c9c39d --- /dev/null +++ b/tests_jvm_api/test_units.py @@ -0,0 +1,82 @@ +from hict_jvm_api.exceptions import HiddenCoordinateError +from hict_jvm_api.models import AssemblyInfo, ContigDescriptor, OpenFileResponse +from hict_jvm_api.units import UnitConverter + + +def _open_response_with_hidden_middle_contig() -> OpenFileResponse: + contigs = [ + ContigDescriptor( + contig_id=1, + contig_name="c1", + contig_original_name="c1", + contig_source_name="c1", + contig_offset_in_source=0, + contig_direction=0, + contig_length_bp=100, + contig_length_bins={10: 10}, + 
contig_presence_at_resolution={10: 1}, + ), + ContigDescriptor( + contig_id=2, + contig_name="c2", + contig_original_name="c2", + contig_source_name="c2", + contig_offset_in_source=0, + contig_direction=0, + contig_length_bp=50, + contig_length_bins={10: 5}, + contig_presence_at_resolution={10: 0}, + ), + ContigDescriptor( + contig_id=3, + contig_name="c3", + contig_original_name="c3", + contig_source_name="c3", + contig_offset_in_source=0, + contig_direction=0, + contig_length_bp=100, + contig_length_bins={10: 10}, + contig_presence_at_resolution={10: 1}, + ), + ] + return OpenFileResponse( + status="Opened", + dtype="uint8", + resolutions=[10], + pixel_resolutions=[1.0], + tile_size=256, + assembly_info=AssemblyInfo(contigs=contigs, scaffolds=[]), + matrix_sizes_bins=[25], + ) + + +def test_converter_total_lengths_and_visibility() -> None: + converter = UnitConverter.from_open_file_response(_open_response_with_hidden_middle_contig(), bp_resolution=10) + assert converter.total_bp == 250 + assert converter.total_bins == 25 + assert converter.total_visible_pixels == 20 + + +def test_converter_bp_bin_pixel_roundtrip_on_shown_contig() -> None: + converter = UnitConverter.from_open_file_response(_open_response_with_hidden_middle_contig(), bp_resolution=10) + + # c3 starts at bp=150 and bin=15, but visible pixels continue from c1 (10 pixels). + assert converter.bp_to_bins(200) == 20 + assert converter.bins_to_bp(20) == 200 + assert converter.bins_to_pixels(20) == 15 + assert converter.pixels_to_bins(15) == 20 + assert converter.pixels_to_bp(15) == 200 + + +def test_converter_hidden_contig_raises_for_pixel_conversion() -> None: + converter = UnitConverter.from_open_file_response(_open_response_with_hidden_middle_contig(), bp_resolution=10) + + # bp=120 belongs to hidden contig c2. 
+ hidden_bin = converter.bp_to_bins(120) + assert hidden_bin == 12 + + try: + converter.bins_to_pixels(hidden_bin) + assert False, "Expected HiddenCoordinateError" + except HiddenCoordinateError: + pass From 864ea0b663051fd17aa2d685ea138f810b419d4d Mon Sep 17 00:00:00 2001 From: Alexander Serdyukov Date: Fri, 3 Apr 2026 00:14:37 +0400 Subject: [PATCH 2/3] Intermediate API/stability fixes --- .github/workflows/autotests.yml | 72 ++++++++++--------------- README.md | 35 ++++++++++-- doc/jvm_api_v1.md | 15 ++++-- hict/__init__.py | 90 +++++++++++++++++++------------ hict/api/__init__.py | 53 ++++++++---------- hict_jvm_api/client.py | 73 +++++++++++++++++++++---- pyproject.toml | 6 +-- run_tests.sh | 11 +++- setup.py | 56 ++----------------- tests_jvm_api/test_client.py | 11 ++++ tests_jvm_api/test_hict_facade.py | 5 ++ 11 files changed, 246 insertions(+), 181 deletions(-) create mode 100644 tests_jvm_api/test_hict_facade.py diff --git a/.github/workflows/autotests.yml b/.github/workflows/autotests.yml index 0d7a628..11ef1e0 100644 --- a/.github/workflows/autotests.yml +++ b/.github/workflows/autotests.yml @@ -1,50 +1,34 @@ -name: Generate latest builds +name: Python CI Tests + on: push: - branches: ["master"] + branches: ["master", "dev*", "jvm-api-v1"] pull_request: - branches: ["master", "dev*"] + branches: ["master", "dev*", "jvm-api-v1"] + workflow_dispatch: jobs: - run_pytest: - name: HiCT Library autotests - runs-on: [ "ubuntu-latest" ] - + tests: + runs-on: ubuntu-latest steps: - - name: Checkout sources - uses: actions/checkout@v3 - with: - submodules: recursive - - name: Setup Python - uses: actions/setup-python@v4.3.1 - with: - # Version range or exact version of Python or PyPy to use, using SemVer's version range syntax. Reads from .python-version if unset. - python-version: '>=3.9 <3.11' - # Used to specify a package manager for caching in the default directory. Supported values: pip, pipenv, poetry. 
- cache: pip - # The target architecture (x86, x64) of the Python or PyPy interpreter. - architecture: x64 - # Set this option if you want the action to update environment variables. - update-environment: true - - name: Install HDF5 library - uses: awalsh128/cache-apt-pkgs-action@latest - with: - packages: libhdf5-dev - - name: Install dependencies - run: | - pip install -r requirements.txt - pip install -r requirements-dev.txt - continue-on-error: true - - name: Install dependencies - run: | - pip install pylint - - name: Analysing the code with pylint - run: | - pylint $(git ls-files '*.py') - continue-on-error: true - - name: Analysing the code with mypy - run: | - mypy -p hict - continue-on-error: true - - name: Launch PyTest - run: pytest -v . + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r requirements-dev.txt + pip install -e . + + - name: Run mypy (JVM API package) + run: mypy hict_jvm_api + + - name: Run tests (JVM API-first suite) + run: ./run_tests.sh diff --git a/README.md b/README.md index b3139a0..6abd3e9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ -# HiCT library for interactive manual scaffolding using Hi-C contact maps +# HiCT Python library (JVM API-first) -**Note**: this version is preliminary but provides an overview of essential implementation details for HiCT model. +This repository now provides a JVM-backed Python API as the primary and maintained interface. +Heavy operations are executed in `HiCT_JVM`; Python acts as a fast typed client layer. 
## Overview @@ -24,14 +25,16 @@ It is recommended to use virtual environments provided by `venv` module to simpl This library uses HiCT format for the HiC data and you can convert Cooler's `.cool` or `.mcool` files to it using [HiCT utils](https://github.com/ctlab/HiCT_Utils) ## Documentation -This library has ContactMatrixFacet as the main interaction point. It hides all the internal methods, exposing only simple ones. Documentation for this module could be found at [doc directory](https://github.com/ctlab/HiCT/blob/master/doc/hict.api.ContactMatrixFacet.html) (download this file and open it using your web browser). +- JVM API client docs: [`doc/jvm_api_v1.md`](./doc/jvm_api_v1.md) +- Legacy `ContactMatrixFacet` docs (compatibility only): + [`doc/hict.api.ContactMatrixFacet.html`](./doc/hict.api.ContactMatrixFacet.html) ## Building from source You can run `rebuild.sh` script in source directory which will perform static type-checking of module using mypy (it may produce error messages), build library from source and reinstall it, deleting current version. ## JVM API client (v1) -This branch also contains `hict_jvm_api`, a Python client for controlling a running `HiCT_JVM` backend. +Use `hict.HiCTClient` (alias of `hict_jvm_api.HiCTJVMClient`) as the default entry point. ### Key capabilities * Open/attach/close sessions in HiCT_JVM; @@ -41,6 +44,30 @@ This branch also contains `hict_jvm_api`, a Python client for controlling a runn * Link FASTA, export FASTA selections/assembly, import/export AGP; * Convert coordinates between BP/BINS/PIXELS with hidden-contig awareness. +### Install + +```bash +pip install -e . 
+``` + +### Quick start + +```python +from hict import HiCTClient, Unit + +client = HiCTClient("http://localhost:5000") +session = client.open_file("build/quad/combined_ind2_4DN.hict.hdf5") +resolution = session.resolutions[0] +tile = client.fetch_region_pixels( + start_row_px=0, + start_col_px=0, + rows=256, + cols=256, + bp_resolution=resolution, +) +px = client.convert_units(1_000_000, from_unit=Unit.BP, to_unit=Unit.PIXELS, bp_resolution=resolution) +``` + ### Quick links * API docs: [`doc/jvm_api_v1.md`](./doc/jvm_api_v1.md) * Notebooks: diff --git a/doc/jvm_api_v1.md b/doc/jvm_api_v1.md index e397473..06ae535 100644 --- a/doc/jvm_api_v1.md +++ b/doc/jvm_api_v1.md @@ -1,6 +1,7 @@ # HiCT JVM API v1 Python Library -`hict_jvm_api` is a Python package that uses a running `HiCT_JVM` server as the execution backend. +`hict` now defaults to a JVM-backed API client (`hict.HiCTClient`) that uses a running +`HiCT_JVM` server as the execution backend. ## Design goals @@ -10,7 +11,7 @@ ## Main classes -- `hict_jvm_api.client.HiCTJVMClient` +- `hict.HiCTClient` / `hict_jvm_api.client.HiCTJVMClient` - session management (`open_file`, `attach_session`, `close_session`) - map region fetch (`fetch_region_pixels`, `fetch_tile_png`, `fetch_tile_with_ranges`) - scaffolding operations (`reverse_selection_range`, `move_selection_range`, `split_contig_at_bin`, etc.) 
@@ -39,9 +40,9 @@ pip install -e '.[torch]' ## Quick start ```python -from hict_jvm_api import HiCTJVMClient, Unit +from hict import HiCTClient, Unit -client = HiCTJVMClient("http://localhost:5001") +client = HiCTClient("http://localhost:5001") open_resp = client.open_file("build/quad/combined_ind2_4DN.hict.hdf5") resolution = open_resp.resolutions[0] # coarse level @@ -83,3 +84,9 @@ See: - `notebooks/jvm_api_quickstart.ipynb` - `notebooks/jvm_api_pytorch_dataloader.ipynb` + +## OpenAPI docs endpoint + +When `HiCT_JVM` is running, interactive API documentation is available at: + +- `http://localhost:5000/api/v1/` diff --git a/hict/__init__.py b/hict/__init__.py index dc7dde6..e5f2144 100644 --- a/hict/__init__.py +++ b/hict/__init__.py @@ -1,33 +1,57 @@ -# MIT License -# -# Copyright (c) 2021-2026. Aleksandr Serdiukov, Anton Zamyatin, Aleksandr Sinitsyn, Vitalii Dravgelis and Computer Technologies Laboratory ITMO University team. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -# MIT License -# -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# +"""HiCT Python API facade. + +This package now defaults to the JVM-backed API client for production usage. +Legacy pure-Python modules remain importable from their original subpackages, +but new code should use :class:`HiCTClient`. +""" + +from __future__ import annotations + +import warnings + +from hict_jvm_api.client import HiCTJVMClient +from hict_jvm_api.dataloader import HiCTRegionDataset +from hict_jvm_api.exceptions import HiCTAPIError, HiCTClientStateError, HiddenCoordinateError +from hict_jvm_api.models import ( + AssemblyInfo, + ContigDescriptor, + OpenFileResponse, + ScaffoldDescriptor, + TileRanges, + TileWithRanges, + Unit, +) +from hict_jvm_api.units import UnitConverter + +# Backward-compatible alias for the primary entrypoint. 
+HiCTClient = HiCTJVMClient + +__all__ = [ + "HiCTClient", + "HiCTJVMClient", + "HiCTRegionDataset", + "HiCTAPIError", + "HiCTClientStateError", + "HiddenCoordinateError", + "AssemblyInfo", + "ContigDescriptor", + "OpenFileResponse", + "ScaffoldDescriptor", + "TileRanges", + "TileWithRanges", + "Unit", + "UnitConverter", +] + + +def _warn_legacy_import(path: str) -> None: + warnings.warn( + ( + f"The legacy pure-Python API module '{path}' is kept for compatibility " + "but is no longer the recommended API. Use 'hict.HiCTClient' " + "(JVM-backed) for maintained functionality." + ), + category=DeprecationWarning, + stacklevel=2, + ) + diff --git a/hict/api/__init__.py b/hict/api/__init__.py index dc7dde6..9226cc8 100644 --- a/hict/api/__init__.py +++ b/hict/api/__init__.py @@ -1,33 +1,22 @@ -# MIT License -# -# Copyright (c) 2021-2026. Aleksandr Serdiukov, Anton Zamyatin, Aleksandr Sinitsyn, Vitalii Dravgelis and Computer Technologies Laboratory ITMO University team. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. +"""Legacy pure-Python API namespace. -# MIT License -# -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# +This module is retained for backward compatibility. +Prefer `hict.HiCTClient` for the maintained JVM-backed API. +""" + +from __future__ import annotations + +import warnings + +warnings.warn( + ( + "hict.api is deprecated and kept for compatibility only. " + "Use hict.HiCTClient (JVM-backed API) for new development." 
+ ), + category=DeprecationWarning, + stacklevel=2, +) + +from .ContactMatrixFacet import ContactMatrixFacet + +__all__ = ["ContactMatrixFacet"] diff --git a/hict_jvm_api/client.py b/hict_jvm_api/client.py index d8fe6a3..2151a2c 100644 --- a/hict_jvm_api/client.py +++ b/hict_jvm_api/client.py @@ -10,6 +10,8 @@ import numpy as np import requests from PIL import Image +from requests.adapters import HTTPAdapter +from urllib3.util import Retry from .exceptions import HiCTAPIError, HiCTClientStateError from .models import OpenFileResponse, SecondarySourceStatus, TileRanges, TileWithRanges, Unit @@ -27,16 +29,61 @@ def __init__( self, base_url: str, timeout: float = 30.0, - session: Optional[Any] = None, + session: Optional[requests.Session] = None, + *, + pool_maxsize: int = 64, + max_retries: int = 2, + retry_backoff_sec: float = 0.2, ): if not base_url: raise ValueError("base_url must be provided") self.base_url = base_url.rstrip("/") self.timeout = timeout self._own_session = session is None - self._http = session if session is not None else requests.Session() + if session is not None: + self._http = session + else: + self._http = self._create_default_session( + pool_maxsize=int(max(1, pool_maxsize)), + max_retries=int(max(0, max_retries)), + retry_backoff_sec=float(max(0.0, retry_backoff_sec)), + ) self._open_file_response: Optional[OpenFileResponse] = None + @staticmethod + def _create_default_session( + *, + pool_maxsize: int, + max_retries: int, + retry_backoff_sec: float, + ) -> requests.Session: + session = requests.Session() + retries = Retry( + total=max_retries, + connect=max_retries, + read=max_retries, + status=max_retries, + backoff_factor=retry_backoff_sec, + status_forcelist=(429, 500, 502, 503, 504), + allowed_methods=frozenset({"GET", "POST"}), + raise_on_status=False, + ) + adapter = HTTPAdapter( + max_retries=retries, + pool_connections=pool_maxsize, + pool_maxsize=pool_maxsize, + ) + session.mount("http://", adapter) + session.mount("https://", 
adapter) + session.headers.update( + { + "Accept": "application/json", + "Accept-Encoding": "gzip, deflate", + "User-Agent": "hict-python/1 (jvm-api)", + } + ) + return session + def close(self) -> None: """Close the underlying HTTP session if owned by this client.""" if self._own_session and hasattr(self._http, "close"): @@ -68,13 +115,17 @@ def _request( params: Optional[Mapping[str, Any]] = None, payload: Optional[Mapping[str, Any]] = None, ) -> requests.Response: - response = self._http.request( - method, - self._url(path), - params=dict(params or {}), - json=dict(payload or {}), - timeout=self.timeout, - ) + method_upper = method.upper() + resolved_params = dict(params or {}) + resolved_payload = dict(payload or {}) + request_kwargs: Dict[str, Any] = { + "method": method_upper, + "url": self._url(path), + "params": resolved_params, + "json": resolved_payload, + "timeout": self.timeout, + } + response = self._http.request(**request_kwargs) if response.status_code >= 400: message = response.text parsed_payload: Any = None @@ -141,6 +192,10 @@ def version(self) -> Mapping[str, Any]: payload = self._request_json("GET", "/version") return payload + def openapi_spec_yaml(self) -> str: + """Get backend OpenAPI v1 YAML from ``GET /api/v1/openapi.yaml``.""" + return self._request_text("GET", "/api/v1/openapi.yaml") + def list_files(self) -> Sequence[str]: payload = self._request_json("POST", "/list_files", payload={}) return [str(item) for item in payload] diff --git a/pyproject.toml b/pyproject.toml index aa059c9..7cbc80b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,9 +3,9 @@ requires = ["setuptools>=68", "wheel>=0.41"] build-backend = "setuptools.build_meta" [project] -name = "hict-jvm-api" -version = "0.1.0" -description = "Python client for controlling HiCT_JVM and fetching Hi-C map regions" +name = "hict" +version = "1.0.0b1" +description = "HiCT Python client powered by HiCT_JVM API" readme = "README.md" requires-python = ">=3.9" license = { text = 
"MIT" } diff --git a/run_tests.sh b/run_tests.sh index aa2c08a..cacb031 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -28,4 +28,13 @@ HICT_DIR="$SCRIPT_DIR/../HiCT_Library/" export PYTHONPATH="$PYTHONPATH:$HICT_DIR" TESTS_DIR="$HICT_DIR/tests/" JVM_API_TESTS_DIR="$HICT_DIR/tests_jvm_api/" -pytest -vvv -x -n16 "$TESTS_DIR" "$JVM_API_TESTS_DIR" -m "not integration" +if pytest --help 2>/dev/null | grep -q -- "-n NUM"; then + PYTEST_PARALLEL_ARGS=(-n auto) +else + PYTEST_PARALLEL_ARGS=() +fi +if [[ "${HICT_RUN_LEGACY_TESTS:-0}" == "1" ]]; then + pytest -vvv -x "${PYTEST_PARALLEL_ARGS[@]}" "$TESTS_DIR" "$JVM_API_TESTS_DIR" -m "not integration" +else + pytest -vvv -x "${PYTEST_PARALLEL_ARGS[@]}" "$JVM_API_TESTS_DIR" -m "not integration" +fi diff --git a/setup.py b/setup.py index 403b830..d80b34a 100644 --- a/setup.py +++ b/setup.py @@ -1,54 +1,8 @@ -# MIT License -# -# Copyright (c) 2021-2026. Aleksandr Serdiukov, Anton Zamyatin, Aleksandr Sinitsyn, Vitalii Dravgelis and Computer Technologies Laboratory ITMO University team. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. +"""Compatibility setup.py shim. -from typing import List -from setuptools import find_packages, setup +The canonical package metadata is defined in ``pyproject.toml``. +""" +from setuptools import setup -requirements: List[str] = [] -with open("requirements.txt", mode="rt", encoding="utf-8") as f: - requirements = f.readlines() - -setup( - name='hict', - version='0.1.3rc1', - packages=list(set(['hict', 'hict.api', 'hict.core', 'hict.util']).union(find_packages())), - url='https://genome.ifmo.ru', - license='', - author='Alexander Serdiukov, Anton Zamyatin and CT Lab ITMO University team', - author_email='', - description='HiCT is a model for efficient interaction with Hi-C contact matrices that actively uses Split-Merge tree structures.', - setup_requires=[ - 'setuptools>=63.2.0', - 'wheel>=0.37.1', - ], - install_requires=list(set([]).union(requirements)), - tests_require=[ - 'cooler >=0.8.11, <0.9', - 'pytest >=7.2, <8', - 'pytest-quickcheck >=0.8.6, <1', - 'hypothesis >=6.61, <7', - 'mypy >=0.971, <1', - 'types-cachetools >=5.2.0, <6 ', - ], - test_suite='tests' -) +setup() diff --git a/tests_jvm_api/test_client.py b/tests_jvm_api/test_client.py index 7f876db..46f7dea 100644 --- a/tests_jvm_api/test_client.py +++ b/tests_jvm_api/test_client.py @@ -163,3 +163,14 @@ def handler(method: str, path: str, params: Mapping[str, str], payload: Mapping[ assert False, "Expected HiCTAPIError" except HiCTAPIError as err: assert "Something failed" in str(err) + + +def test_openapi_spec_yaml_fetch() -> None: + def handler(method: str, path: str, params: Mapping[str, str], payload: Mapping[str, Any]) -> _FakeResponse: + if method == "GET" and path == "/api/v1/openapi.yaml": + return _FakeResponse(200, 
text="openapi: 3.0.3\npaths: {}\n") + raise AssertionError(f"Unexpected request: {method} {path}") + + client = HiCTJVMClient("http://test", session=_FakeSession(handler)) + spec = client.openapi_spec_yaml() + assert "openapi: 3.0.3" in spec diff --git a/tests_jvm_api/test_hict_facade.py b/tests_jvm_api/test_hict_facade.py new file mode 100644 index 0000000..e693b65 --- /dev/null +++ b/tests_jvm_api/test_hict_facade.py @@ -0,0 +1,5 @@ +from hict import HiCTClient, HiCTJVMClient + + +def test_hict_facade_aliases_primary_client(): + assert HiCTClient is HiCTJVMClient From 213191169dcb8da574703068f307ca257953b0ad Mon Sep 17 00:00:00 2001 From: Alexander Serdyukov Date: Fri, 3 Apr 2026 02:47:11 +0400 Subject: [PATCH 3/3] Upgrade OpenAPI --- README.md | 11 ++ doc/jvm_api_v1.md | 15 +++ hict/__init__.py | 4 +- hict_jvm_api/__init__.py | 6 +- hict_jvm_api/client.py | 122 +++++++++++++++++++++ hict_jvm_api/dataloader.py | 102 ++++++++++++++++- tests_jvm_api/test_client.py | 68 +++++++++++- tests_jvm_api/test_dataloader.py | 39 ++++++- tests_jvm_api/test_optional_integration.py | 19 ++++ 9 files changed, 378 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 6abd3e9..0aeef0a 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ Use `hict.HiCTClient` (alias of `hict_jvm_api.HiCTJVMClient`) as the default ent ### Key capabilities * Open/attach/close sessions in HiCT_JVM; * Fetch Hi-C map regions as numpy RGBA arrays (`PNG_BY_PIXELS`) for ML pipelines; +* Fetch numeric submatrices directly as dense arrays/tensors (`/matrix/query`); * Run scaffolding operations via API (reverse/move/split/group/ungroup/debris); * Run converter jobs (single and batch) and monitor status; * Link FASTA, export FASTA selections/assembly, import/export AGP; @@ -66,6 +67,16 @@ tile = client.fetch_region_pixels( bp_resolution=resolution, ) px = client.convert_units(1_000_000, from_unit=Unit.BP, to_unit=Unit.PIXELS, bp_resolution=resolution) +signal = client.fetch_region_signal( 
+ start_row=0, + start_col=0, + rows=256, + cols=256, + bp_resolution=resolution, + unit=Unit.PIXELS, + signal_mode="TRADITIONAL_NORMALIZED", + dtype="float32", +) ``` ### Quick links diff --git a/doc/jvm_api_v1.md b/doc/jvm_api_v1.md index 06ae535..9938a2c 100644 --- a/doc/jvm_api_v1.md +++ b/doc/jvm_api_v1.md @@ -14,6 +14,8 @@ - `hict.HiCTClient` / `hict_jvm_api.client.HiCTJVMClient` - session management (`open_file`, `attach_session`, `close_session`) - map region fetch (`fetch_region_pixels`, `fetch_tile_png`, `fetch_tile_with_ranges`) + - numeric matrix fetch (`fetch_region_signal`, `fetch_region_signal_torch`) with + `RAW_COUNTS`, `COOLER_WEIGHTED`, `TRADITIONAL_NORMALIZED`, `PIPELINE_SIGNAL` - scaffolding operations (`reverse_selection_range`, `move_selection_range`, `split_contig_at_bin`, etc.) - conversion jobs (`start_conversion_job`, `start_batch_conversion_jobs`, polling helpers) - FASTA/AGP operations (`link_fasta`, `export_fasta_for_selection`, `load_agp`) @@ -22,6 +24,8 @@ respecting hidden contigs via `contigPresenceAtResolution`. - `hict_jvm_api.dataloader.HiCTRegionDataset` - PyTorch-friendly random-access dataset fetching regions from a live session. +- `hict_jvm_api.dataloader.HiCTSignalDataset` + - PyTorch/NumPy-friendly dataset fetching scalar matrix windows from `/matrix/query`. 
## Installation @@ -56,6 +60,17 @@ img = client.fetch_region_pixels( # Convert BP -> visible pixel coordinate px = client.convert_units(1_000_000, from_unit=Unit.BP, to_unit=Unit.PIXELS, bp_resolution=resolution) + +signal = client.fetch_region_signal( + start_row=0, + start_col=0, + rows=256, + cols=256, + bp_resolution=resolution, + unit=Unit.PIXELS, + signal_mode="TRADITIONAL_NORMALIZED", + dtype="float32", +) ``` ## Testing diff --git a/hict/__init__.py b/hict/__init__.py index e5f2144..9e28e25 100644 --- a/hict/__init__.py +++ b/hict/__init__.py @@ -10,7 +10,7 @@ import warnings from hict_jvm_api.client import HiCTJVMClient -from hict_jvm_api.dataloader import HiCTRegionDataset +from hict_jvm_api.dataloader import HiCTRegionDataset, HiCTSignalDataset from hict_jvm_api.exceptions import HiCTAPIError, HiCTClientStateError, HiddenCoordinateError from hict_jvm_api.models import ( AssemblyInfo, @@ -30,6 +30,7 @@ "HiCTClient", "HiCTJVMClient", "HiCTRegionDataset", + "HiCTSignalDataset", "HiCTAPIError", "HiCTClientStateError", "HiddenCoordinateError", @@ -54,4 +55,3 @@ def _warn_legacy_import(path: str) -> None: category=DeprecationWarning, stacklevel=2, ) - diff --git a/hict_jvm_api/__init__.py b/hict_jvm_api/__init__.py index bede2f7..b86cb7e 100644 --- a/hict_jvm_api/__init__.py +++ b/hict_jvm_api/__init__.py @@ -5,11 +5,12 @@ - :class:`hict_jvm_api.client.HiCTJVMClient` for server operations - :class:`hict_jvm_api.units.UnitConverter` for fast BP/BINS/PIXELS conversion -- :class:`hict_jvm_api.dataloader.HiCTRegionDataset` for PyTorch data loading +- :class:`hict_jvm_api.dataloader.HiCTRegionDataset` for rendered RGBA region loading +- :class:`hict_jvm_api.dataloader.HiCTSignalDataset` for numeric signal matrix loading """ from .client import HiCTJVMClient -from .dataloader import HiCTRegionDataset +from .dataloader import HiCTRegionDataset, HiCTSignalDataset from .exceptions import HiCTAPIError, HiCTClientStateError, HiddenCoordinateError from .models import ( 
AssemblyInfo, @@ -25,6 +26,7 @@ __all__ = [ "HiCTJVMClient", "HiCTRegionDataset", + "HiCTSignalDataset", "HiCTAPIError", "HiCTClientStateError", "HiddenCoordinateError", diff --git a/hict_jvm_api/client.py b/hict_jvm_api/client.py index 2151a2c..70cd598 100644 --- a/hict_jvm_api/client.py +++ b/hict_jvm_api/client.py @@ -162,6 +162,51 @@ def _request_text(self, method: str, path: str, payload: Optional[Mapping[str, A response = self._request(method, path, payload=payload) return response.text + @staticmethod + def _header_int(response: requests.Response, header_name: str, fallback: int = 0) -> int: + raw = response.headers.get(header_name) + if raw is None: + return int(fallback) + try: + return int(raw) + except (TypeError, ValueError): + return int(fallback) + + @staticmethod + def _normalize_unit(unit: Union[Unit, str]) -> str: + if isinstance(unit, Unit): + return unit.value + value = str(unit).strip().upper() + if value in {"PX", "PIXEL", "PIXELS"}: + return Unit.PIXELS.value + if value in {"BIN", "BINS"}: + return Unit.BINS.value + if value in {"BP", "BASE_PAIRS", "BASEPAIR", "BASEPAIRS"}: + return Unit.BP.value + raise ValueError(f"Unsupported unit: {unit}") + + @staticmethod + def _matrix_format_for_dtype(dtype: str) -> str: + normalized = str(dtype).strip().lower() + if normalized in {"float32", "f32"}: + return "BINARY_FLOAT32" + if normalized in {"float64", "double", "f64"}: + return "BINARY_FLOAT64" + if normalized in {"int64", "i64", "long"}: + return "BINARY_INT64" + raise ValueError(f"Unsupported matrix dtype '{dtype}'. 
Use float32, float64 or int64.") + + @staticmethod + def _numpy_dtype_for_binary_format(binary_format: str) -> np.dtype: + normalized = binary_format.strip().upper() + if normalized == "BINARY_FLOAT32": + return np.dtype("<f4") + if normalized == "BINARY_FLOAT64": + return np.dtype("<f8") + if normalized == "BINARY_INT64": + return np.dtype("<i8") + raise ValueError(f"Unsupported binary format: {binary_format}") + @staticmethod def _decode_png_rgba(png_bytes: bytes) -> np.ndarray: with Image.open(BytesIO(png_bytes)) as image: @@ -389,6 +434,83 @@ def fetch_region_bp( version=version, ) + def fetch_region_signal( + self, + *, + start_row: int, + start_col: int, + rows: int, + cols: int, + bp_resolution: int, + unit: Union[Unit, str] = Unit.PIXELS, + signal_mode: str = "TRADITIONAL_NORMALIZED", + dtype: str = "float32", + ) -> np.ndarray: + """Fetch numeric matrix region as a dense NumPy array. + + This method uses ``POST /matrix/query`` with binary payload transfer + for performance-sensitive ML/data-loader workflows. + """ + if rows <= 0 or cols <= 0: + raise ValueError("rows and cols must be positive") + + normalized_unit = self._normalize_unit(unit) + binary_format = self._matrix_format_for_dtype(dtype) + payload: Dict[str, Any] = { + "bpResolution": int(bp_resolution), + "unit": normalized_unit, + "startRow": int(start_row), + "startCol": int(start_col), + "endRow": int(start_row) + int(rows), + "endCol": int(start_col) + int(cols), + "signalMode": str(signal_mode).strip().upper(), + "format": binary_format, + } + response = self._request("POST", "/matrix/query", payload=payload) + response_dtype = self._numpy_dtype_for_binary_format(binary_format) + rows_out = self._header_int(response, "x-hict-rows", rows) + cols_out = self._header_int(response, "x-hict-cols", cols) + data = np.frombuffer(response.content, dtype=response_dtype) + expected = int(rows_out) * int(cols_out) + if data.size != expected: + raise HiCTAPIError( + f"Unexpected matrix payload size: got {data.size}, expected {expected} " + f"for shape ({rows_out}, {cols_out})" + ) + return data.reshape((rows_out, cols_out)) + + def fetch_region_signal_torch( + self, + *, + start_row: int, + start_col: int, + rows: int, + cols: int, + bp_resolution: 
int, + unit: Union[Unit, str] = Unit.PIXELS, + signal_mode: str = "TRADITIONAL_NORMALIZED", + dtype: str = "float32", + ): + """Fetch numeric matrix region as a ``torch.Tensor``.""" + matrix = self.fetch_region_signal( + start_row=start_row, + start_col=start_col, + rows=rows, + cols=cols, + bp_resolution=bp_resolution, + unit=unit, + signal_mode=signal_mode, + dtype=dtype, + ) + try: + import torch + except ImportError as exc: + raise RuntimeError( + "fetch_region_signal_torch requires torch to be installed. " + "Install optional dependency: pip install 'hict[torch]'" + ) from exc + return torch.from_numpy(np.ascontiguousarray(matrix)) + # ------------------------- # Unit conversion utilities # ------------------------- diff --git a/hict_jvm_api/dataloader.py b/hict_jvm_api/dataloader.py index 7050006..b4b4c17 100644 --- a/hict_jvm_api/dataloader.py +++ b/hict_jvm_api/dataloader.py @@ -3,11 +3,12 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Callable, List, Optional, Sequence, Tuple +from typing import Callable, List, Optional, Sequence, Tuple, Union import numpy as np from .client import HiCTJVMClient +from .models import Unit @dataclass(frozen=True) @@ -110,10 +111,107 @@ def __getitem__(self, index: int): except ImportError as exc: raise RuntimeError( "return_torch=True requires torch to be installed. " - "Install optional dependency: pip install 'hict-jvm-api[torch]'" + "Install optional dependency: pip install 'hict[torch]'" ) from exc tensor = torch.from_numpy(np.ascontiguousarray(image)) if self.channel_first: tensor = tensor.permute(2, 0, 1) return tensor.float() / 255.0 + + +class HiCTSignalDataset: + """Random-access dataset fetching numeric Hi-C signal matrices. + + Regions are sampled in a selected coordinate unit and fetched through + ``POST /matrix/query`` using the JVM-backed client. 
+ """ + + def __init__( + self, + *, + client: HiCTJVMClient, + bp_resolution: int, + window_size: int, + unit: Union[Unit, str] = Unit.PIXELS, + signal_mode: str = "TRADITIONAL_NORMALIZED", + dtype: str = "float32", + num_samples: int = 1024, + regions: Optional[Sequence[Tuple[int, int]]] = None, + seed: int = 0, + return_torch: bool = False, + transform: Optional[Callable[[np.ndarray], object]] = None, + ): + if window_size <= 0: + raise ValueError("window_size must be positive") + if num_samples <= 0 and not regions: + raise ValueError("num_samples must be positive when regions are not provided") + + self.client = client + self.bp_resolution = int(bp_resolution) + self.window_size = int(window_size) + self.unit = unit + self.signal_mode = str(signal_mode) + self.dtype = str(dtype) + self.return_torch = bool(return_torch) + self.transform = transform + + converter = self.client.unit_converter(self.bp_resolution) + if isinstance(unit, Unit): + normalized = unit + else: + normalized = Unit(str(unit).upper()) + + if normalized == Unit.PIXELS: + total_size = int(converter.total_visible_pixels) + elif normalized == Unit.BINS: + total_size = int(converter.total_bins) + else: + total_size = int(converter.total_bp) + + max_coord = max(0, total_size - self.window_size) + + if regions is not None: + self._regions = [RegionRequest(int(r), int(c)) for r, c in regions] + else: + rng = np.random.default_rng(seed) + self._regions = [ + RegionRequest( + row_px=int(rng.integers(0, max_coord + 1)), + col_px=int(rng.integers(0, max_coord + 1)), + ) + for _ in range(int(num_samples)) + ] + + def __len__(self) -> int: + return len(self._regions) + + def __getitem__(self, index: int): + region = self._regions[index] + matrix = self.client.fetch_region_signal( + start_row=region.row_px, + start_col=region.col_px, + rows=self.window_size, + cols=self.window_size, + bp_resolution=self.bp_resolution, + unit=self.unit, + signal_mode=self.signal_mode, + dtype=self.dtype, + ) + + if 
self.transform is not None: + matrix = self.transform(matrix) + return matrix + + if not self.return_torch: + return matrix + + try: + import torch + except ImportError as exc: + raise RuntimeError( + "return_torch=True requires torch to be installed. " + "Install optional dependency: pip install 'hict[torch]'" + ) from exc + + return torch.from_numpy(np.ascontiguousarray(matrix)) diff --git a/tests_jvm_api/test_client.py b/tests_jvm_api/test_client.py index 46f7dea..b451ed9 100644 --- a/tests_jvm_api/test_client.py +++ b/tests_jvm_api/test_client.py @@ -11,11 +11,20 @@ class _FakeResponse: - def __init__(self, status_code: int, *, json_payload: Any = None, text: str = "", content: bytes = b""): + def __init__( + self, + status_code: int, + *, + json_payload: Any = None, + text: str = "", + content: bytes = b"", + headers: Mapping[str, str] | None = None, + ): self.status_code = status_code self._json_payload = json_payload self.text = text if text else ("" if json_payload is not None else content.decode("utf-8", errors="ignore")) self.content = content + self.headers = dict(headers or {}) def json(self) -> Any: if self._json_payload is None: @@ -174,3 +183,60 @@ def handler(method: str, path: str, params: Mapping[str, str], payload: Mapping[ client = HiCTJVMClient("http://test", session=_FakeSession(handler)) spec = client.openapi_spec_yaml() assert "openapi: 3.0.3" in spec + + +def test_fetch_region_signal_binary_float32() -> None: + expected = np.array([[1.5, 2.5], [3.5, 4.5]], dtype="<f4") + payload_bytes = expected.tobytes() + + def handler(method: str, path: str, params: Mapping[str, str], payload: Mapping[str, Any]) -> _FakeResponse: + if method == "POST" and path == "/matrix/query": + assert payload["format"] == "BINARY_FLOAT32" + assert payload["signalMode"] == "TRADITIONAL_NORMALIZED" + return _FakeResponse( + 200, + content=payload_bytes, + headers={"x-hict-rows": "2", "x-hict-cols": "2"}, + ) + raise AssertionError(f"Unexpected request: {method} {path}") + + client = HiCTJVMClient("http://test", session=_FakeSession(handler)) + matrix = client.fetch_region_signal( + start_row=0, + 
start_col=0, + rows=2, + cols=2, + bp_resolution=50000, + ) + assert matrix.shape == (2, 2) + np.testing.assert_allclose(matrix, expected.astype(np.float32)) + + +def test_fetch_region_signal_binary_int64() -> None: + expected = np.array([[1, 2], [3, 4]], dtype="<i8") + payload_bytes = expected.tobytes() + + def handler(method: str, path: str, params: Mapping[str, str], payload: Mapping[str, Any]) -> _FakeResponse: + if method == "POST" and path == "/matrix/query": + assert payload["format"] == "BINARY_INT64" + assert payload["signalMode"] == "RAW_COUNTS" + return _FakeResponse( + 200, + content=payload_bytes, + headers={"x-hict-rows": "2", "x-hict-cols": "2"}, + ) + raise AssertionError(f"Unexpected request: {method} {path}") + + client = HiCTJVMClient("http://test", session=_FakeSession(handler)) + matrix = client.fetch_region_signal( + start_row=0, + start_col=0, + rows=2, + cols=2, + bp_resolution=50000, + signal_mode="RAW_COUNTS", + dtype="int64", + ) + assert matrix.shape == (2, 2) + assert matrix.dtype == np.int64 + np.testing.assert_array_equal(matrix, expected.astype(np.int64)) diff --git a/tests_jvm_api/test_dataloader.py b/tests_jvm_api/test_dataloader.py index 0afcb0a..1ea2f87 100644 --- a/tests_jvm_api/test_dataloader.py +++ b/tests_jvm_api/test_dataloader.py @@ -1,11 +1,13 @@ import numpy as np -from hict_jvm_api.dataloader import HiCTRegionDataset +from hict_jvm_api.dataloader import HiCTRegionDataset, HiCTSignalDataset class _DummyConverter: def __init__(self, total_visible_pixels: int = 128): self.total_visible_pixels = total_visible_pixels + self.total_bins = total_visible_pixels + self.total_bp = total_visible_pixels * 1000 class _DummyClient: @@ -25,6 +27,23 @@ def fetch_region_pixels(self, *, start_row_px: int, start_col_px: int, rows: int arr[..., 3] = 255 return arr + def fetch_region_signal( + self, + *, + start_row: int, + start_col: int, + rows: int, + cols: int, + bp_resolution: int, + unit, + signal_mode: str, + dtype: str, + ): + self.requests.append((start_row, start_col, rows, cols, bp_resolution, str(unit), signal_mode, dtype)) + arr = np.zeros((rows, cols), 
dtype=np.float32) + arr[:, :] = float(start_row + start_col) + return arr + def test_dataset_random_sampling_numpy_output() -> None: client = _DummyClient() @@ -51,3 +70,21 @@ def test_dataset_explicit_regions_and_transform() -> None: value = ds[1] assert value == 70 assert client.requests[-1][:2] == (30, 40) + + +def test_signal_dataset_random_sampling_numpy_output() -> None: + client = _DummyClient() + ds = HiCTSignalDataset( + client=client, + bp_resolution=50000, + window_size=12, + signal_mode="COOLER_WEIGHTED", + dtype="float32", + num_samples=3, + seed=42, + ) + + assert len(ds) == 3 + sample = ds[0] + assert sample.shape == (12, 12) + assert sample.dtype == np.float32 diff --git a/tests_jvm_api/test_optional_integration.py b/tests_jvm_api/test_optional_integration.py index ad09095..5446f4b 100644 --- a/tests_jvm_api/test_optional_integration.py +++ b/tests_jvm_api/test_optional_integration.py @@ -48,6 +48,25 @@ def test_open_and_fetch_region(client: HiCTJVMClient) -> None: assert converter.convert(0, Unit.BP, Unit.BINS) == 0 +def test_open_and_fetch_numeric_signal_region(client: HiCTJVMClient) -> None: + if not DATASET_FILE: + pytest.skip("Set HICT_DATASET_FILE to run dataset integration tests") + opened = client.require_open_file_response() + coarse_resolution = opened.resolutions[0] + + matrix = client.fetch_region_signal( + start_row=0, + start_col=0, + rows=32, + cols=32, + bp_resolution=coarse_resolution, + unit=Unit.PIXELS, + signal_mode="TRADITIONAL_NORMALIZED", + dtype="float32", + ) + assert matrix.shape == (32, 32) + + def test_scaffolding_reverse_twice(client: HiCTJVMClient) -> None: if not DATASET_FILE: pytest.skip("Set HICT_DATASET_FILE to run dataset integration tests")