Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 142 additions & 20 deletions src/snakemake_storage_plugin_cached_http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
#
# SPDX-License-Identifier: MIT

import asyncio
import base64
import hashlib
import json
import shutil
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime
from logging import Logger
from pathlib import Path
from posixpath import basename, dirname, join, normpath, relpath
from urllib.parse import urlparse
from urllib.parse import quote, urlparse

import httpx
import platformdirs
Expand Down Expand Up @@ -92,10 +95,11 @@ class StorageProviderSettings(SettingsBase):

@dataclass
class FileMetadata:
"""Metadata for a file in a Zenodo or data.pypsa.org record."""
"""Metadata for a file in a Zenodo, data.pypsa.org, or GCS record."""

checksum: str | None
size: int
mtime: float = 0 # modification time (Unix timestamp), used for GCS
redirect: str | None = None # used to indicate data.pypsa.org redirection


Expand Down Expand Up @@ -144,6 +148,7 @@ def __post_init__(self):
# Cache for record metadata to avoid repeated API calls
self._zenodo_record_cache: dict[str, dict[str, FileMetadata]] = {}
self._pypsa_manifest_cache: dict[str, dict[str, FileMetadata]] = {}
self._gcs_metadata_cache: dict[str, FileMetadata] = {}

@override
def use_rate_limiter(self) -> bool:
Expand Down Expand Up @@ -173,6 +178,11 @@ def example_queries(cls) -> list[ExampleQuery]:
description="A data pypsa file URL",
type=QueryType.INPUT,
),
ExampleQuery(
query="https://storage.googleapis.com/open-tyndp-data-store/CBA_projects.zip",
description="A Google Cloud Storage file URL",
type=QueryType.INPUT,
),
]

@override
Expand All @@ -185,7 +195,7 @@ def is_valid_query(cls, query: str) -> StorageQueryValidationResult:
return StorageQueryValidationResult(
query=query,
valid=False,
reason="Only zenodo.org and data.pypsa.org URLs are handled by this plugin",
reason="Only zenodo.org, data.pypsa.org, and storage.googleapis.com URLs are handled by this plugin",
)

@override
Expand Down Expand Up @@ -288,9 +298,24 @@ async def get_metadata(self, path: str, netloc: str) -> FileMetadata | None:
return await self.get_zenodo_metadata(path, netloc)
elif netloc == "data.pypsa.org":
return await self.get_pypsa_metadata(path, netloc)
elif netloc == "storage.googleapis.com":
return await self.get_gcs_metadata(path, netloc)

raise WorkflowError(
"Cached-http storage plugin is only implemented for zenodo.org and data.pypsa.org urls"
"Cached-http storage plugin is only implemented for zenodo.org, data.pypsa.org, and storage.googleapis.com urls"
)

@staticmethod
def is_immutable(netloc: str):
if netloc in ("zenodo.org", "sandbox.zenodo.org"):
return True
elif netloc == "data.pypsa.org":
return True
elif netloc == "storage.googleapis.com":
return False

raise WorkflowError(
"Cached-http storage plugin is only implemented for zenodo.org, data.pypsa.org, and storage.googleapis.com urls"
)

async def get_zenodo_metadata(self, path: str, netloc: str) -> FileMetadata | None:
Expand Down Expand Up @@ -407,6 +432,73 @@ async def get_pypsa_metadata(self, path: str, netloc: str) -> FileMetadata | Non
filename = relpath(path, base_path)
return metadata.get(filename)

async def get_gcs_metadata(self, path: str, netloc: str) -> FileMetadata | None:
    """
    Retrieve and cache file metadata from Google Cloud Storage.

    Uses the GCS JSON API to fetch object metadata including MD5 hash.
    URL format: https://storage.googleapis.com/{bucket}/{object-path}
    API endpoint: https://storage.googleapis.com/storage/v1/b/{bucket}/o/{encoded-object}

    Args:
        path: Server path (bucket/object-path); a leading slash is tolerated
        netloc: Network location (storage.googleapis.com)

    Returns:
        FileMetadata for the requested file, or None if not found (HTTP 404)

    Raises:
        WorkflowError: If the path does not contain both a bucket and an
            object path, or if the API answers with a status other than
            200 or 404
    """
    # Local import: the file only imports `datetime` from the datetime module
    from datetime import timezone

    # Check cache first (only successful lookups are ever stored)
    cached = self._gcs_metadata_cache.get(path)
    if cached is not None:
        return cached

    # Split "{bucket}/{object-path}". Stripping leading slashes keeps a path
    # of the form "/{bucket}/{object-path}" from yielding an empty bucket.
    bucket, _, object_path = path.lstrip("/").partition("/")

    # An empty object path would address the object *listing* endpoint, whose
    # 200 response would otherwise be mistaken for an existing 0-byte file.
    if not bucket or not object_path:
        raise WorkflowError(
            f"Invalid GCS URL format: http(s)://{netloc}/{path}. "
            f"Expected format: https://storage.googleapis.com/{{bucket}}/{{object-path}}"
        )

    # URL-encode the object path for the API request (slashes must be encoded)
    encoded_object = quote(object_path, safe="")

    # GCS JSON API endpoint for object metadata
    api_url = f"https://{netloc}/storage/v1/b/{bucket}/o/{encoded_object}"

    async with self.httpr("get", api_url) as response:
        if response.status_code == 404:
            return None
        if response.status_code != 200:
            raise WorkflowError(
                f"Failed to fetch GCS object metadata: HTTP {response.status_code} ({api_url})"
            )

        data = json.loads(await response.aread())

    # GCS returns MD5 as base64-encoded bytes; convert to a hex digest
    md5_base64: str | None = data.get("md5Hash")
    checksum: str | None = None
    if md5_base64:
        checksum = f"md5:{base64.b64decode(md5_base64).hex()}"

    size: int = int(data.get("size", 0))

    # "updated" is an RFC 3339 timestamp, typically ending in "Z".
    # datetime.fromisoformat() only accepts the "Z" suffix from Python 3.11
    # on, so normalise it to an explicit UTC offset first.
    updated: str | None = data.get("updated")
    mtime: float = 0.0
    if updated:
        if updated.endswith(("Z", "z")):
            updated = f"{updated[:-1]}+00:00"
        parsed = datetime.fromisoformat(updated)
        if parsed.tzinfo is None:
            # Without an offset, .timestamp() would assume local time;
            # treat such values as UTC instead.
            parsed = parsed.replace(tzinfo=timezone.utc)
        mtime = parsed.timestamp()

    metadata = FileMetadata(checksum=checksum, size=size, mtime=mtime)

    # Store in cache
    self._gcs_metadata_cache[path] = metadata

    return metadata


# Implementation of storage object
class StorageObject(StorageObjectRead):
Expand Down Expand Up @@ -441,15 +533,19 @@ async def managed_exists(self) -> bool:

if self.provider.cache:
cached = self.provider.cache.get(str(self.query))
if cached is not None:
if cached is not None and self.provider.is_immutable(self.netloc):
return True

metadata = await self.provider.get_metadata(self.path, self.netloc)
return metadata is not None

@override
async def managed_mtime(self) -> float:
return 0
if self.provider.settings.skip_remote_checks:
return 0

metadata = await self.provider.get_metadata(self.path, self.netloc)
return metadata.mtime if metadata is not None else 0

@override
async def managed_size(self) -> int:
Expand All @@ -458,11 +554,20 @@ async def managed_size(self) -> int:

if self.provider.cache:
cached = self.provider.cache.get(str(self.query))
if cached is not None:
if cached is not None and self.provider.is_immutable(self.netloc):
return cached.stat().st_size
else:
cached = None

metadata = await self.provider.get_metadata(self.path, self.netloc)
return metadata.size if metadata is not None else 0
if metadata is None:
return 0

if cached is not None:
if cached.stat().st_mtime >= metadata.mtime:
return cached.stat().st_size

return metadata.size

@override
async def inventory(self, cache: IOCacheStorageInterface) -> None:
Expand All @@ -483,17 +588,31 @@ async def inventory(self, cache: IOCacheStorageInterface) -> None:

if self.provider.cache:
cached = self.provider.cache.get(str(self.query))
if cached is not None:
if cached is not None and self.provider.is_immutable(self.netloc):
cache.exists_in_storage[key] = True
cache.mtime[key] = Mtime(storage=0)
cache.mtime[key] = Mtime(storage=cached.stat().st_mtime)
cache.size[key] = cached.stat().st_size
return
else:
cached = None

metadata = await self.provider.get_metadata(self.path, self.netloc)
exists = metadata is not None
cache.exists_in_storage[key] = exists
cache.mtime[key] = Mtime(storage=0)
cache.size[key] = metadata.size if exists else 0
if metadata is None:
cache.exists_in_storage[key] = False
cache.mtime[key] = Mtime(storage=0)
cache.size[key] = 0
return

if cached is not None:
if cached.stat().st_mtime >= metadata.mtime:
cache.exists_in_storage[key] = True
cache.mtime[key] = Mtime(storage=cached.stat().st_mtime)
cache.size[key] = cached.stat().st_size
return

cache.exists_in_storage[key] = True
cache.mtime[key] = Mtime(storage=metadata.mtime)
cache.size[key] = metadata.size

@override
def cleanup(self):
Expand Down Expand Up @@ -558,17 +677,20 @@ async def managed_retrieve(self):
if metadata is not None and metadata.redirect is not None:
query = f"https://{self.netloc}/{metadata.redirect}"

# If already in cache, just copy
# If already in cache, check if still valid
if self.provider.cache:
cached = self.provider.cache.get(query)
if cached is not None:
logger.info(f"Retrieved {filename} from cache ({query})")
shutil.copy2(cached, local_path)
return
if self.provider.is_immutable(self.netloc) or (
metadata is not None and cached.stat().st_mtime >= metadata.mtime
):
logger.info(f"Retrieved {filename} from cache ({query})")
shutil.copy2(cached, local_path)
return

try:
# Download from Zenodo or data.pypsa.org using a get request, rate limit errors are detected and
# raise WorkflowError to trigger a retry
# Download using a get request, rate limit errors are detected and raise
# WorkflowError to trigger a retry
async with self.provider.httpr("get", query) as response:
if response.status_code != 200:
raise WorkflowError(
Expand Down
1 change: 1 addition & 0 deletions src/snakemake_storage_plugin_cached_http/monkeypatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def is_pypsa_or_zenodo_url(url: str) -> bool:
"zenodo.org",
"sandbox.zenodo.org",
"data.pypsa.org",
"storage.googleapis.com",
) and parsed.scheme in (
"http",
"https",
Expand Down
Loading