diff --git a/imednet/endpoints/_mixins.py b/imednet/endpoints/_mixins.py index b52bc25e..f20db0a8 100644 --- a/imednet/endpoints/_mixins.py +++ b/imednet/endpoints/_mixins.py @@ -56,6 +56,16 @@ class ListGetEndpointMixin(Generic[T]): _pop_study_filter: bool = False _missing_study_exception: type[Exception] = ValueError + def _extract_special_params(self, filters: Dict[str, Any]) -> Dict[str, Any]: + """ + Hook to extract special parameters from filters. + + Subclasses should override this method to handle parameters that need to be + passed separately (e.g. in extra_params) rather than in the filter string. + These parameters should be removed from the filters dictionary. + """ + return {} + def _parse_item(self, item: Any) -> T: """ Parse a single item into the model type. @@ -96,6 +106,14 @@ def _prepare_list_params( ) -> tuple[Optional[str], Any, Dict[str, Any], Dict[str, Any]]: # This method handles filter normalization and cache retrieval preparation filters = self._auto_filter(filters) # type: ignore[attr-defined] + + # Extract special parameters using the hook + special_params = self._extract_special_params(filters) + if special_params: + if extra_params is None: + extra_params = {} + extra_params.update(special_params) + if study_key: filters["studyKey"] = study_key diff --git a/imednet/endpoints/jobs.py b/imednet/endpoints/jobs.py index 8233b1d5..4b158f42 100644 --- a/imednet/endpoints/jobs.py +++ b/imednet/endpoints/jobs.py @@ -1,6 +1,6 @@ """Endpoint for checking job status in a study.""" -from typing import List +from typing import Any, List from imednet.core.parsing import get_model_parser from imednet.endpoints.base import BaseEndpoint @@ -17,6 +17,18 @@ class JobsEndpoint(BaseEndpoint): PATH = "/api/v1/edc/studies" + def _get_job_path(self, study_key: str, batch_id: str) -> str: + return self._build_path(study_key, "jobs", batch_id) + + def _get_jobs_list_path(self, study_key: str) -> str: + return self._build_path(study_key, "jobs") + + def _parse_job_status(self, response_data: Any, batch_id: str, study_key: str) -> JobStatus: + if not response_data: + raise ValueError(f"Job {batch_id} not found in study {study_key}") + parser = get_model_parser(JobStatus) + return parser(response_data) + def get(self, study_key: str, batch_id: str) -> JobStatus: """ Get a specific job by batch ID. @@ -34,13 +46,9 @@ def get(self, study_key: str, batch_id: str) -> JobStatus: Raises: ValueError: If the job is not found """ - endpoint = self._build_path(study_key, "jobs", batch_id) + endpoint = self._get_job_path(study_key, batch_id) response = self._client.get(endpoint) - data = response.json() - if not data: - raise ValueError(f"Job {batch_id} not found in study {study_key}") - parser = get_model_parser(JobStatus) - return parser(data) + return self._parse_job_status(response.json(), batch_id, study_key) async def async_get(self, study_key: str, batch_id: str) -> JobStatus: """ @@ -60,13 +68,9 @@ async def async_get(self, study_key: str, batch_id: str) -> JobStatus: ValueError: If the job is not found """ client = self._require_async_client() - endpoint = self._build_path(study_key, "jobs", batch_id) + endpoint = self._get_job_path(study_key, batch_id) response = await client.get(endpoint) - data = response.json() - if not data: - raise ValueError(f"Job {batch_id} not found in study {study_key}") - parser = get_model_parser(JobStatus) - return parser(data) + return self._parse_job_status(response.json(), batch_id, study_key) def list(self, study_key: str) -> List[Job]: """ @@ -78,7 +82,7 @@ def list(self, study_key: str) -> List[Job]: Returns: List of Job objects """ - endpoint = self._build_path(study_key, "jobs") + endpoint = self._get_jobs_list_path(study_key) response = self._client.get(endpoint) parser = get_model_parser(Job) return [parser(item) for item in response.json()] @@ -94,7 +98,7 @@ async def async_list(self, study_key: str) -> List[Job]: List of Job objects """ client = self._require_async_client() - endpoint = self._build_path(study_key, "jobs") + endpoint = self._get_jobs_list_path(study_key) response = await client.get(endpoint) parser = get_model_parser(Job) return [parser(item) for item in response.json()] diff --git a/imednet/endpoints/records.py b/imednet/endpoints/records.py index 8c445023..ca20bf3d 100644 --- a/imednet/endpoints/records.py +++ b/imednet/endpoints/records.py @@ -21,6 +21,24 @@ class RecordsEndpoint(ListGetEndpoint[Record]): _id_param = "recordId" _pop_study_filter = False + def _extract_special_params(self, filters: Dict[str, Any]) -> Dict[str, Any]: + record_data_filter = filters.pop("record_data_filter", None) + if record_data_filter: + return {"recordDataFilter": record_data_filter} + return {} + + def _prepare_create_request( + self, + study_key: str, + records_data: List[Dict[str, Any]], + email_notify: Union[bool, str, None], + schema: Optional[SchemaCache], + ) -> tuple[str, Dict[str, str]]: + self._validate_records_if_schema_present(schema, records_data) + headers = self._build_headers(email_notify) + path = self._build_path(study_key, self.PATH) + return path, headers + def _validate_records_if_schema_present( self, schema: Optional[SchemaCache], records_data: List[Dict[str, Any]] ) -> None: @@ -89,10 +107,7 @@ def create( Raises: ValueError: If email_notify contains invalid characters """ - self._validate_records_if_schema_present(schema, records_data) - headers = self._build_headers(email_notify) - - path = self._build_path(study_key, self.PATH) + path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema) response = self._client.post(path, json=records_data, headers=headers) return Job.from_json(response.json()) @@ -124,27 +139,6 @@ async def async_create( ValueError: If email_notify contains invalid characters """ client = self._require_async_client() - self._validate_records_if_schema_present(schema, records_data) - headers = self._build_headers(email_notify) - - path = self._build_path(study_key, self.PATH) + path, headers = self._prepare_create_request(study_key, records_data, email_notify, schema) response = await client.post(path, json=records_data, headers=headers) return Job.from_json(response.json()) - - def _list_impl( - self, - client: Any, - paginator_cls: type[Any], - *, - study_key: Optional[str] = None, - record_data_filter: Optional[str] = None, - **filters: Any, - ) -> Any: - extra = {"recordDataFilter": record_data_filter} if record_data_filter else None - return super()._list_impl( - client, - paginator_cls, - study_key=study_key, - extra_params=extra, - **filters, - ) diff --git a/imednet/endpoints/users.py b/imednet/endpoints/users.py index 711148c5..e8710556 100644 --- a/imednet/endpoints/users.py +++ b/imednet/endpoints/users.py @@ -1,9 +1,7 @@ """Endpoint for managing users in a study.""" -from typing import Any, Awaitable, Dict, List, Optional, Union +from typing import Any, Dict -from imednet.core.paginator import AsyncPaginator, Paginator -from imednet.core.protocols import AsyncRequestorProtocol, RequestorProtocol from imednet.endpoints._mixins import ListGetEndpoint from imednet.models.users import User @@ -20,25 +18,6 @@ class UsersEndpoint(ListGetEndpoint[User]): _id_param = "userId" _pop_study_filter = True - def _list_impl( - self, - client: RequestorProtocol | AsyncRequestorProtocol, - paginator_cls: Union[type[Paginator], type[AsyncPaginator]], - *, - study_key: Optional[str] = None, - refresh: bool = False, - extra_params: Optional[Dict[str, Any]] = None, - include_inactive: bool = False, - **filters: Any, - ) -> List[User] | Awaitable[List[User]]: - params = extra_params or {} - params["includeInactive"] = str(include_inactive).lower() - - return super()._list_impl( - client, - paginator_cls, - study_key=study_key, - refresh=refresh, - extra_params=params, - **filters, - ) + def _extract_special_params(self, filters: Dict[str, Any]) -> Dict[str, Any]: + include_inactive = filters.pop("include_inactive", False) + return {"includeInactive": str(include_inactive).lower()} diff --git a/poetry.lock b/poetry.lock index d347d36f..5b8bb219 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1600,6 +1600,18 @@ optional = ["typing-extensions (>=4)"] re2 = ["google-re2 (>=1.1)"] tests = ["pytest (>=9)", "typing-extensions (>=4.15)"] +[[package]] +name = "pip" +version = "26.0" +description = "The PyPA recommended tool for installing Python packages." +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pip-26.0-py3-none-any.whl", hash = "sha256:98436feffb9e31bc9339cf369fd55d3331b1580b6a6f1173bacacddcf9c34754"}, + {file = "pip-26.0.tar.gz", hash = "sha256:3ce220a0a17915972fbf1ab451baae1521c4539e778b28127efa79b974aff0fa"}, +] + [[package]] name = "platformdirs" version = "4.5.1" @@ -2793,4 +2805,4 @@ sqlalchemy = ["SQLAlchemy"] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "3f7910b623d64132666b1611c589229d5c3eea7e05b57e2323a3d45961b0b687" +content-hash = "f566c8f13cf3eb0c80d496ed8ecf21f9cdd87e0ad0277fe67db4440f8ab6bb09" diff --git a/pyproject.toml b/pyproject.toml index 7074011d..00cba1e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ openpyxl = "^3.1" sphinxcontrib-mermaid = "^0.9.2" doc8 = "^2.0.0" codespell = "^2.4.1" +pip = "^26.0" [build-system] requires = ["poetry-core>=2.0.0,<3.0.0"]