Skip to content

Commit b351521

Browse files
committed
Expand SDOH LLM docstrings
1 parent 18a2722 commit b351521

1 file changed

Lines changed: 54 additions & 6 deletions

File tree

pyhealth/models/sdoh_icd9_llm.py

Lines changed: 54 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -176,8 +176,12 @@ def _load_prompt_template() -> str:
176176
class SDOHICD9LLM:
177177
"""Admission-level SDOH ICD-9 V-code detector using an LLM.
178178
179-
This model runs an LLM on each note for an admission, parses the predicted
180-
ICD-9 V-codes, and aggregates predictions across notes (union).
179+
This model sends each note for an admission to an LLM, parses predicted
180+
ICD-9 V-codes, and aggregates the codes across notes (set union).
181+
182+
Notes:
183+
- Use ``dry_run=True`` to skip LLM calls while exercising the pipeline.
184+
- Predictions are derived entirely from the LLM response parsing logic.
181185
182186
Examples:
183187
>>> from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM
@@ -207,6 +211,21 @@ def __init__(
207211
max_notes: Optional[int] = None,
208212
dry_run: bool = False,
209213
) -> None:
214+
"""Initialize the LLM wrapper.
215+
216+
Args:
217+
target_codes: Target ICD-9 codes to retain after parsing.
218+
model_name: OpenAI model name.
219+
prompt_template: Optional prompt template override. Uses built-in
220+
SDOH template if not provided.
221+
api_key: OpenAI API key. Defaults to ``OPENAI_API_KEY`` env var.
222+
max_tokens: Max tokens for LLM response.
223+
max_chars: Max chars from each note to send.
224+
temperature: LLM temperature.
225+
sleep_s: Delay between per-note requests (seconds).
226+
max_notes: Optional limit on notes per admission.
227+
dry_run: If True, skips API calls and returns "None" responses.
228+
"""
210229
self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES)
211230
self.model_name = model_name
212231
self.prompt_template = prompt_template or _load_prompt_template()
@@ -233,7 +252,14 @@ def _get_client(self):
233252
return self._client
234253

235254
def _call_openai_api(self, text: str) -> str:
236-
"""Send a single note to the LLM and return the raw response."""
255+
"""Send a single note to the LLM and return the raw response.
256+
257+
Args:
258+
text: Note text to send.
259+
260+
Returns:
261+
Raw string response from the LLM.
262+
"""
237263
self._write_prompt_preview(text)
238264

239265
if self.dry_run:
@@ -262,7 +288,11 @@ def _write_prompt_preview(self, text: str) -> None:
262288
f.write(prompt)
263289

264290
def _parse_llm_response(self, response: str) -> Set[str]:
265-
"""Parse the LLM response into a set of valid target codes."""
291+
"""Parse the LLM response into a set of valid target codes.
292+
293+
Returns:
294+
A set of ICD-9 codes intersected with ``target_codes``.
295+
"""
266296
if not response:
267297
return set()
268298

@@ -293,7 +323,16 @@ def _predict_admission(
293323
note_categories: Optional[Iterable[str]] = None,
294324
chartdates: Optional[Iterable[str]] = None,
295325
) -> Tuple[Set[str], List[dict]]:
296-
"""Run per-note predictions and aggregate codes for one admission."""
326+
"""Run per-note predictions and aggregate codes for one admission.
327+
328+
Args:
329+
notes: Iterable of note texts.
330+
note_categories: Optional note categories aligned to ``notes``.
331+
chartdates: Optional chart dates aligned to ``notes``.
332+
333+
Returns:
334+
A tuple of (aggregated_codes, per_note_results).
335+
"""
297336
aggregated: Set[str] = set()
298337
note_results: List[dict] = []
299338
categories = list(note_categories) if note_categories is not None else []
@@ -330,5 +369,14 @@ def predict_admission_with_notes(
330369
note_categories: Optional[Iterable[str]] = None,
331370
chartdates: Optional[Iterable[str]] = None,
332371
) -> Tuple[Set[str], List[dict]]:
333-
"""Public helper to predict and return codes for one admission."""
372+
"""Predict codes for one admission using per-note LLM calls.
373+
374+
Args:
375+
notes: Iterable of note texts.
376+
note_categories: Optional note categories aligned to ``notes``.
377+
chartdates: Optional chart dates aligned to ``notes``.
378+
379+
Returns:
380+
A tuple of (aggregated_codes, per_note_results).
381+
"""
334382
return self._predict_admission(notes, note_categories, chartdates)

0 commit comments

Comments
 (0)