Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions imednet/endpoints/subjects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@ class SubjectsEndpoint(ListGetEndpoint[Subject]):
MODEL = Subject
_id_param = "subjectKey"

def _filter_by_site(self, subjects: List[Subject], site_id: str | int) -> List[Subject]:
# TUI Logic: Strict string comparison to handle int/str mismatch
target_site = str(site_id)
return [s for s in subjects if str(s.site_id) == target_site]

def list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]:
"""
List subjects filtered by a specific site ID.

Migrated from TUI logic to core SDK to support filtering.
"""
all_subjects = self.list(study_key)
# TUI Logic: Strict string comparison to handle int/str mismatch
target_site = str(site_id)
return [s for s in all_subjects if str(s.site_id) == target_site]
return self._filter_by_site(all_subjects, site_id)

async def async_list_by_site(self, study_key: str, site_id: str | int) -> List[Subject]:
"""Asynchronously list subjects filtered by a specific site ID."""
all_subjects = await self.async_list(study_key)
target_site = str(site_id)
return [s for s in all_subjects if str(s.site_id) == target_site]
return self._filter_by_site(all_subjects, site_id)
64 changes: 64 additions & 0 deletions tests/unit/endpoints/test_subjects_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from unittest.mock import AsyncMock, Mock

import pytest

from imednet.endpoints.subjects import SubjectsEndpoint
from imednet.models.subjects import Subject


@pytest.fixture
def subject_list():
return [
Subject(studyKey="sk", subjectId=1, siteId=101, subjectKey="s1"),
Subject(studyKey="sk", subjectId=2, siteId=102, subjectKey="s2"),
Subject(studyKey="sk", subjectId=3, siteId=101, subjectKey="s3"),
Subject(studyKey="sk", subjectId=4, siteId="101", subjectKey="s4"), # String siteId
Copy link

Copilot AI Feb 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment "String siteId" is slightly misleading. While the Subject is constructed with a string siteId value, Pydantic's normalization will convert it to an integer (via parse_int_or_default). The actual test coverage is for filtering when the site_id parameter is passed as either int or str, which is tested on lines 29-30 and 53-54. Consider updating the comment to clarify this, or remove it if it's not necessary.

Suggested change
Subject(studyKey="sk", subjectId=4, siteId="101", subjectKey="s4"), # String siteId
Subject(studyKey="sk", subjectId=4, siteId="101", subjectKey="s4"), # siteId passed as string (normalized to int by Pydantic)

Copilot uses AI. Check for mistakes.
Subject(studyKey="sk", subjectId=5, siteId="103", subjectKey="s5"),
]


def test_list_by_site_filtering(subject_list):
# Mock client and context as they are required by __init__ but not used if we mock list
mock_client = Mock()
mock_ctx = Mock()

endpoint = SubjectsEndpoint(mock_client, mock_ctx)
endpoint.list = Mock(return_value=subject_list)

# Act
filtered_int = endpoint.list_by_site("sk", 101)
filtered_str = endpoint.list_by_site("sk", "101")
filtered_mismatch = endpoint.list_by_site("sk", 999)

# Assert
assert len(filtered_int) == 3
assert {s.subject_id for s in filtered_int} == {1, 3, 4}

assert len(filtered_str) == 3
assert {s.subject_id for s in filtered_str} == {1, 3, 4}

assert len(filtered_mismatch) == 0


@pytest.mark.asyncio
async def test_async_list_by_site_filtering(subject_list):
# Mock client and context
mock_client = Mock()
mock_ctx = Mock()

endpoint = SubjectsEndpoint(mock_client, mock_ctx)
endpoint.async_list = AsyncMock(return_value=subject_list)

# Act
filtered_int = await endpoint.async_list_by_site("sk", 101)
filtered_str = await endpoint.async_list_by_site("sk", "101")
filtered_mismatch = await endpoint.async_list_by_site("sk", 999)

# Assert
assert len(filtered_int) == 3
assert {s.subject_id for s in filtered_int} == {1, 3, 4}

assert len(filtered_str) == 3
assert {s.subject_id for s in filtered_str} == {1, 3, 4}

assert len(filtered_mismatch) == 0
Loading