Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions imednet/endpoints/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ class ListGetEndpointMixin(Generic[T]):
_pop_study_filter: bool = False
_missing_study_exception: type[Exception] = ValueError

def _extract_special_params(self, filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract special parameters from filters.

Subclasses can override this to extract specific parameters from the
filters dictionary (e.g., 'includeInactive') and return them as a
dictionary of query parameters to be added to the request.
The extracted keys should be removed from the filters dictionary.

Args:
filters: The dictionary of filters passed to the list method.

Returns:
A dictionary of parameters to add to the request query string.
"""
return {}

def _parse_item(self, item: Any) -> T:
"""
Parse a single item into the model type.
Expand Down Expand Up @@ -96,6 +113,9 @@ def _prepare_list_params(
) -> tuple[Optional[str], Any, Dict[str, Any], Dict[str, Any]]:
# This method handles filter normalization and cache retrieval preparation
filters = self._auto_filter(filters) # type: ignore[attr-defined]

extracted_params = self._extract_special_params(filters)

if study_key:
filters["studyKey"] = study_key

Expand Down Expand Up @@ -123,6 +143,8 @@ def _prepare_list_params(
params["filter"] = build_filter_string(filters)
if extra_params:
params.update(extra_params)
if extracted_params:
params.update(extracted_params)

return study, cache, params, other_filters

Expand Down
24 changes: 7 additions & 17 deletions imednet/endpoints/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,10 @@ async def async_create(
response = await client.post(path, json=records_data, headers=headers)
return Job.from_json(response.json())

def _list_impl(
self,
client: Any,
paginator_cls: type[Any],
*,
study_key: Optional[str] = None,
record_data_filter: Optional[str] = None,
**filters: Any,
) -> Any:
extra = {"recordDataFilter": record_data_filter} if record_data_filter else None
return super()._list_impl(
client,
paginator_cls,
study_key=study_key,
extra_params=extra,
**filters,
)
def _extract_special_params(self, filters: Dict[str, Any]) -> Dict[str, Any]:
params = {}
if "record_data_filter" in filters:
val = filters.pop("record_data_filter")
if val:
params["recordDataFilter"] = val
return params
12 changes: 7 additions & 5 deletions imednet/endpoints/subjects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@ class SubjectsEndpoint(ListGetEndpoint[Subject]):
MODEL = Subject
_id_param = "subjectKey"

def _filter_by_site(self, subjects: List[Subject], site_id: str | int) -> List[Subject]:
# TUI Logic: Strict string comparison to handle int/str mismatch
target_site = str(site_id)
return [s for s in subjects if str(s.site_id) == target_site]

def list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]:
"""
List subjects filtered by a specific site ID.

Migrated from TUI logic to core SDK to support filtering.
"""
all_subjects = self.list(study_key)
# TUI Logic: Strict string comparison to handle int/str mismatch
target_site = str(site_id)
return [s for s in all_subjects if str(s.site_id) == target_site]
return self._filter_by_site(all_subjects, site_id)

async def async_list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]:
"""Asynchronously list subjects filtered by a specific site ID."""
all_subjects = await self.async_list(study_key)
target_site = str(site_id)
return [s for s in all_subjects if str(s.site_id) == target_site]
return self._filter_by_site(all_subjects, site_id)
32 changes: 7 additions & 25 deletions imednet/endpoints/users.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""Endpoint for managing users in a study."""

from typing import Any, Awaitable, Dict, List, Optional, Union
from typing import Any, Dict

from imednet.core.paginator import AsyncPaginator, Paginator
from imednet.core.protocols import AsyncRequestorProtocol, RequestorProtocol
from imednet.endpoints._mixins import ListGetEndpoint
from imednet.models.users import User

Expand All @@ -20,25 +18,9 @@ class UsersEndpoint(ListGetEndpoint[User]):
_id_param = "userId"
_pop_study_filter = True

def _list_impl(
self,
client: RequestorProtocol | AsyncRequestorProtocol,
paginator_cls: Union[type[Paginator], type[AsyncPaginator]],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
include_inactive: bool = False,
**filters: Any,
) -> List[User] | Awaitable[List[User]]:
params = extra_params or {}
params["includeInactive"] = str(include_inactive).lower()

return super()._list_impl(
client,
paginator_cls,
study_key=study_key,
refresh=refresh,
extra_params=params,
**filters,
)
def _extract_special_params(self, filters: Dict[str, Any]) -> Dict[str, Any]:
params = {}
if "include_inactive" in filters:
val = filters.pop("include_inactive")
params["includeInactive"] = str(val).lower()
return params
7 changes: 4 additions & 3 deletions tests/unit/test_core_paginator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, cast

from imednet.core.paginator import Paginator
from imednet.core.protocols import RequestorProtocol


class DummyClient:
Expand All @@ -16,7 +17,7 @@ def get(self, path: str, params: Dict[str, Any] | None = None):

def test_single_page_iteration() -> None:
client = DummyClient([{"data": [1, 2]}])
paginator = Paginator(client, "/p")
paginator = Paginator(cast(RequestorProtocol, client), "/p")
assert list(paginator) == [1, 2]
assert client.calls[0]["params"]["page"] == 0

Expand All @@ -28,7 +29,7 @@ def test_multiple_page_iteration() -> None:
{"data": [2], "pagination": {"totalPages": 2}},
]
)
paginator = Paginator(client, "/p", params={"a": 1}, page_size=10)
paginator = Paginator(cast(RequestorProtocol, client), "/p", params={"a": 1}, page_size=10)
items = list(paginator)
assert items == [1, 2]
assert client.calls[0]["params"] == {"a": 1, "page": 0, "size": 10}
Expand Down
Loading