diff --git a/quantmind/flows/_paper_draft.py b/quantmind/flows/_paper_draft.py new file mode 100644 index 0000000..c07fbc8 --- /dev/null +++ b/quantmind/flows/_paper_draft.py @@ -0,0 +1,179 @@ +"""LLM-facing draft schema for ``paper_flow`` + lift into canonical ``Paper``. + +The canonical :class:`quantmind.knowledge.Paper` is the *store* schema: a flat +``dict[UUID, TreeNode]`` keyed by UUID for O(1) lookup and stable dedup keys. +That shape is hostile to LLM structured output for two reasons: + +1. A ``dict`` field serialises to ``additionalProperties``, which the Agents + SDK's strict-JSON-schema mode rejects outright. +2. Even with strict mode off, models do not reliably emit RFC-4122 UUIDs, so + ``UUID`` id fields fail Pydantic validation on free-form ids like + ``"intro_node"``. + +So the agent targets :class:`PaperDraft` instead — a nested tree of +:class:`PaperDraftNode` with plain-string-free structure (children are +embedded, not referenced by id) that is strict-schema clean. ``draft_to_paper`` +then assigns real UUIDs, wires ``parent_id`` / ``children_ids``, and injects +provenance (``source``, ``arxiv_id``, ``authors``) the flow already knows from +the fetch layer rather than trusting the model to hallucinate it. +""" + +from datetime import datetime, timezone +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field + +from quantmind.knowledge import ( + Citation, + ExtractionRef, + Paper, + SourceRef, + TreeNode, +) + +# Map the ``source`` discriminator emitted by ``_fetch_and_format`` onto the +# ``SourceRef.kind`` literal. ``inline`` (RawText) has no external origin, so it +# records as a manual entry. +_SOURCE_KIND: dict[str, str] = { + "arxiv": "arxiv", + "web": "http", + "local": "local", + "inline": "manual", +} + +# ``Citation.quote`` caps at 500 chars; models routinely over-quote. +_QUOTE_MAX = 500 + + +class DraftCitation(BaseModel): + """A citation as emitted by the model (no tree/node anchors yet).""" + + model_config = ConfigDict(extra="ignore") + + source_id: str + page: int | None = None + char_offset: int | None = None + quote: str | None = None + + +class PaperDraftNode(BaseModel): + """One section as emitted by the model. + + Children are embedded directly (a true tree), so the model never has to + keep a flat id map consistent — eliminating dangling-reference and + duplicate-id failure modes. ``draft_to_paper`` assigns identity. + """ + + model_config = ConfigDict(extra="ignore") + + title: str + summary: str + content: str | None = None + citations: list[DraftCitation] = Field(default_factory=list) + children: list["PaperDraftNode"] = Field(default_factory=list) + + +class PaperDraft(BaseModel): + """Strict-schema-safe extraction target for ``paper_flow``. + + Carries only what the model can author from the text. Provenance and + identity are supplied by :func:`draft_to_paper`. + """ + + model_config = ConfigDict(extra="ignore") + + root: PaperDraftNode + arxiv_id: str | None = None + authors: list[str] = Field(default_factory=list) + asset_classes: list[str] = Field(default_factory=list) + + +def _to_citations(drafts: list[DraftCitation]) -> list[Citation]: + return [ + Citation( + source_id=d.source_id, + page=d.page, + char_offset=d.char_offset, + quote=None if d.quote is None else d.quote[:_QUOTE_MAX], + ) + for d in drafts + ] + + +def _source_ref(source_meta: dict[str, Any]) -> SourceRef: + origin = source_meta.get("source", "manual") + kind = _SOURCE_KIND.get(origin, "manual") + if origin == "arxiv": + aid = source_meta.get("arxiv_id") + uri = f"arxiv:{aid}" if aid else None + elif origin == "web": + uri = source_meta.get("url") + elif origin == "local": + uri = source_meta.get("path") + else: + uri = None + return SourceRef(kind=kind, uri=uri) # type: ignore[arg-type] + + +def draft_to_paper( + draft: PaperDraft, + *, + source_meta: dict[str, Any], + model: str, +) -> Paper: + """Lift a validated ``PaperDraft`` into a canonical ``Paper``. + + Args: + draft: The model's nested extraction output. + source_meta: The ``(_, meta)`` dict from ``_fetch_and_format`` — the + authoritative provenance (origin, arxiv id, authors, published + date) known to the flow. + model: Model identifier recorded on the ``ExtractionRef``. + + Returns: + A fully-formed ``Paper`` with UUID identity and injected provenance. + """ + nodes: dict[Any, TreeNode] = {} + + def build(dn: PaperDraftNode, parent_id: Any, position: int) -> Any: + node_id = uuid4() + children_ids = [ + build(child, node_id, pos) for pos, child in enumerate(dn.children) + ] + nodes[node_id] = TreeNode( + node_id=node_id, + parent_id=parent_id, + position=position, + title=dn.title, + summary=dn.summary, + content=dn.content, + citations=_to_citations(dn.citations), + children_ids=children_ids, + ) + return node_id + + root_id = build(draft.root, None, 0) + + published = source_meta.get("published_at") + as_of = ( + published + if isinstance(published, datetime) + else datetime.now(timezone.utc) + ) + authors = source_meta.get("authors") or draft.authors + + return Paper( + as_of=as_of, + source=_source_ref(source_meta), + extraction=ExtractionRef( + flow="paper_flow", + model=model, + extracted_at=datetime.now(timezone.utc), + ), + root_node_id=root_id, + nodes=nodes, + arxiv_id=source_meta.get("arxiv_id") or draft.arxiv_id, + authors=list(authors), + asset_classes=draft.asset_classes, + ) diff --git a/quantmind/flows/paper.py b/quantmind/flows/paper.py index 5b8373b..e0bcde1 100644 --- a/quantmind/flows/paper.py +++ b/quantmind/flows/paper.py @@ -24,6 +24,7 @@ PaperInput, RawText, ) +from quantmind.flows._paper_draft import PaperDraft, draft_to_paper from quantmind.flows._runner import run_with_observability from quantmind.knowledge import Paper from quantmind.preprocess.fetch import ( @@ -37,10 +38,11 @@ P = TypeVar("P", bound=Paper) _DEFAULT_INSTRUCTIONS = """\ -You are extracting a research paper into a structured QuantMind ``Paper`` -TreeKnowledge object. Build the section tree top-down: every node has a -title and a short summary; leaf nodes additionally carry the section -markdown content. Cite supporting passages on each node. +You are extracting a research paper into a structured QuantMind paper tree. +Build the section tree top-down by nesting each child section under its +parent's ``children``: every node has a title and a short summary; leaf +nodes additionally carry the section markdown content. Cite supporting +passages on each node. Honour these flags from the run config: - extract_methodology={extract_methodology}: when true, every methodology @@ -48,12 +50,10 @@ - extract_limitations={extract_limitations}: when true, surface limitations as a dedicated top-level child rather than inlining them. - asset_class_hint={asset_class_hint!r}: when set, prefer this asset - class for ``Paper.asset_classes`` if the paper does not state one - explicitly. + class for ``asset_classes`` if the paper does not state one explicitly. -Set ``as_of`` to the publication date when given; otherwise use today's -date. Set the ``source`` provenance ref using the metadata supplied in -the prompt. +Identity and provenance (node ids, source, dates, authors) are added by +the framework afterwards from the fetch metadata — do not invent them. """ @@ -85,7 +85,10 @@ async def paper_flow( unpaywall fallback is its own follow-up issue). """ cfg = cfg or PaperFlowCfg() - out_type: type[Paper] = output_type or Paper # type: ignore[assignment] + # The agent targets the strict-schema-safe ``PaperDraft`` by default; a + # caller-supplied ``output_type`` opts out of the draft mechanism and is + # forwarded verbatim (its result is returned unconverted below). + out_type: Any = output_type or PaperDraft raw_md, source_meta = await _fetch_and_format(input) @@ -105,13 +108,19 @@ async def paper_flow( if cfg.model_settings is not None: agent_kwargs["model_settings"] = cfg.model_settings agent: Agent[Any] = Agent(**agent_kwargs) - return await run_with_observability( + result = await run_with_observability( agent, _format_input(raw_md, source_meta), cfg=cfg, memory=memory, extra_run_hooks=list(extra_run_hooks or []), ) + # A caller-supplied ``output_type`` already yields a ``Paper`` (subclass); + # the default path yields a ``PaperDraft`` we lift into the canonical + # store schema, injecting flow-known provenance. + if isinstance(result, Paper): + return result + return draft_to_paper(result, source_meta=source_meta, model=cfg.model) async def _fetch_and_format( @@ -126,6 +135,7 @@ async def _fetch_and_format( "arxiv_id": raw.arxiv_id, "title": raw.title, "authors": list(raw.authors), + "published_at": raw.published_at, } if isinstance(input, HttpUrl): raw = await fetch_url(input.url) diff --git a/tests/flows/test_paper.py b/tests/flows/test_paper.py index 2a143c7..f461e4b 100644 --- a/tests/flows/test_paper.py +++ b/tests/flows/test_paper.py @@ -17,6 +17,7 @@ LocalFilePath, RawText, ) +from quantmind.flows._paper_draft import PaperDraft, PaperDraftNode from quantmind.flows.paper import ( UnsupportedContentTypeError, _compose_instructions, @@ -357,3 +358,48 @@ def _capture_agent(*_a: Any, **kwargs: Any) -> Any: ): await paper_flow(RawText(text="x")) self.assertNotIn("model_settings", seen) + + async def test_default_agent_output_type_is_paper_draft(self) -> None: + # The agent must target the strict-schema-safe draft, not the + # canonical (UUID/dict) Paper, unless the caller overrides it. + seen: dict[str, Any] = {} + + def _capture_agent(*_a: Any, **kwargs: Any) -> Any: + seen.update(kwargs) + return MagicMock() + + with ( + patch("quantmind.flows.paper.Agent", side_effect=_capture_agent), + _patch_runner(_stub_paper()), + ): + await paper_flow(RawText(text="x")) + self.assertIs(seen["output_type"], PaperDraft) + + async def test_draft_result_converted_to_canonical_paper(self) -> None: + raw_paper = RawPaper( + bytes=b"%PDF", + content_type="application/pdf", + arxiv_id="2604.12345", + authors=("Alice",), + ) + draft = PaperDraft( + root=PaperDraftNode(title="Extracted Title", summary="s") + ) + with ( + patch( + "quantmind.flows.paper.fetch_arxiv", + new=AsyncMock(return_value=raw_paper), + ), + patch( + "quantmind.flows.paper.pdf_to_markdown", + new=AsyncMock(return_value="MD"), + ), + _patch_runner(draft), + ): + out = await paper_flow(ArxivIdentifier(id="2604.12345")) + # Converted into the canonical store schema with injected provenance. + self.assertIsInstance(out, Paper) + self.assertEqual(out.root().title, "Extracted Title") + self.assertEqual(out.arxiv_id, "2604.12345") + self.assertEqual(out.source.kind, "arxiv") + self.assertEqual(out.authors, ["Alice"]) diff --git a/tests/flows/test_paper_draft.py b/tests/flows/test_paper_draft.py new file mode 100644 index 0000000..f3e3b11 --- /dev/null +++ b/tests/flows/test_paper_draft.py @@ -0,0 +1,181 @@ +"""Tests for ``quantmind.flows._paper_draft``. + +The draft layer is the LLM-facing extraction schema (strict-structured-output +safe: nested children, plain-string ids) plus ``draft_to_paper`` which lifts a +validated draft into the canonical ``quantmind.knowledge.Paper`` store schema +(UUID identity, flat ``nodes`` map, injected provenance). +""" + +import unittest +from datetime import datetime, timezone + +from quantmind.flows._paper_draft import ( + DraftCitation, + PaperDraft, + PaperDraftNode, + draft_to_paper, +) +from quantmind.knowledge import Paper + +_ARXIV_META = { + "source": "arxiv", + "arxiv_id": "2604.12345", + "title": "Cross-Sectional Momentum", + "authors": ["Alice", "Bob"], +} + + +def _leaf( + title: str, summary: str = "s", content: str | None = "body" +) -> PaperDraftNode: + return PaperDraftNode(title=title, summary=summary, content=content) + + +class DraftToPaperStructureTests(unittest.TestCase): + def test_single_node_draft_becomes_valid_paper(self) -> None: + draft = PaperDraft(root=PaperDraftNode(title="Root", summary="top")) + paper = draft_to_paper( + draft, source_meta=_ARXIV_META, model="gpt-4o-mini" + ) + self.assertIsInstance(paper, Paper) + # Root resolves and carries the draft's title/summary. + self.assertEqual(paper.root().title, "Root") + self.assertEqual(paper.root().summary, "top") + # root_node_id points into the nodes map. + self.assertIn(paper.root_node_id, paper.nodes) + self.assertEqual(len(paper.nodes), 1) + + def test_nested_children_get_uuids_and_parent_child_wiring(self) -> None: + draft = PaperDraft( + root=PaperDraftNode( + title="Root", + summary="top", + children=[_leaf("Intro"), _leaf("Method")], + ) + ) + paper = draft_to_paper( + draft, source_meta=_ARXIV_META, model="gpt-4o-mini" + ) + root = paper.root() + self.assertEqual(len(paper.nodes), 3) + self.assertEqual(len(root.children_ids), 2) + children = paper.children_of(root.node_id) + self.assertEqual([c.title for c in children], ["Intro", "Method"]) + # Each child points back to the root and preserves declared order. + for pos, child in enumerate(children): + self.assertEqual(child.parent_id, root.node_id) + self.assertEqual(child.position, pos) + # Root has no parent. + self.assertIsNone(root.parent_id) + + def test_leaf_content_preserved(self) -> None: + draft = PaperDraft( + root=PaperDraftNode( + title="Root", + summary="top", + children=[_leaf("Body", content="full markdown")], + ) + ) + paper = draft_to_paper( + draft, source_meta=_ARXIV_META, model="gpt-4o-mini" + ) + leaf = paper.children_of(paper.root_node_id)[0] + self.assertEqual(leaf.content, "full markdown") + + +class DraftToPaperProvenanceTests(unittest.TestCase): + def test_source_injected_from_arxiv_meta_not_llm(self) -> None: + draft = PaperDraft(root=PaperDraftNode(title="R", summary="s")) + paper = draft_to_paper( + draft, source_meta=_ARXIV_META, model="gpt-4o-mini" + ) + self.assertEqual(paper.source.kind, "arxiv") + self.assertIn("2604.12345", paper.source.uri or "") + + def test_web_meta_maps_to_http_source(self) -> None: + draft = PaperDraft(root=PaperDraftNode(title="R", summary="s")) + meta = { + "source": "web", + "url": "https://example.com/x.pdf", + "content_type": "application/pdf", + } + paper = draft_to_paper(draft, source_meta=meta, model="m") + self.assertEqual(paper.source.kind, "http") + self.assertEqual(paper.source.uri, "https://example.com/x.pdf") + + def test_inline_meta_maps_to_manual_source(self) -> None: + draft = PaperDraft(root=PaperDraftNode(title="R", summary="s")) + paper = draft_to_paper( + draft, source_meta={"source": "inline"}, model="m" + ) + self.assertEqual(paper.source.kind, "manual") + + def test_extraction_records_flow_and_model(self) -> None: + draft = PaperDraft(root=PaperDraftNode(title="R", summary="s")) + paper = draft_to_paper(draft, source_meta=_ARXIV_META, model="gpt-x") + self.assertIsNotNone(paper.extraction) + assert paper.extraction is not None # narrow for type-checker + self.assertEqual(paper.extraction.flow, "paper_flow") + self.assertEqual(paper.extraction.model, "gpt-x") + + def test_arxiv_id_and_authors_taken_from_meta(self) -> None: + # Draft leaves them empty; the flow knows them from fetch metadata. + draft = PaperDraft(root=PaperDraftNode(title="R", summary="s")) + paper = draft_to_paper( + draft, source_meta=_ARXIV_META, model="gpt-4o-mini" + ) + self.assertEqual(paper.arxiv_id, "2604.12345") + self.assertEqual(paper.authors, ["Alice", "Bob"]) + + def test_asset_classes_taken_from_draft(self) -> None: + draft = PaperDraft( + root=PaperDraftNode(title="R", summary="s"), + asset_classes=["equities", "rates"], + ) + paper = draft_to_paper(draft, source_meta=_ARXIV_META, model="m") + self.assertEqual(paper.asset_classes, ["equities", "rates"]) + + def test_as_of_uses_published_at_when_present(self) -> None: + published = datetime(2025, 3, 1, tzinfo=timezone.utc) + meta = {**_ARXIV_META, "published_at": published} + draft = PaperDraft(root=PaperDraftNode(title="R", summary="s")) + paper = draft_to_paper(draft, source_meta=meta, model="m") + self.assertEqual(paper.as_of, published) + + def test_as_of_defaults_to_aware_now_without_published_at(self) -> None: + draft = PaperDraft(root=PaperDraftNode(title="R", summary="s")) + paper = draft_to_paper(draft, source_meta=_ARXIV_META, model="m") + self.assertIsNotNone(paper.as_of.tzinfo) + + +class DraftToPaperCitationTests(unittest.TestCase): + def test_citation_mapped_onto_node(self) -> None: + draft = PaperDraft( + root=PaperDraftNode( + title="R", + summary="s", + citations=[ + DraftCitation(source_id="arxiv:1", page=2, quote="hi") + ], + ) + ) + paper = draft_to_paper(draft, source_meta=_ARXIV_META, model="m") + cites = paper.root().citations + self.assertEqual(len(cites), 1) + self.assertEqual(cites[0].source_id, "arxiv:1") + self.assertEqual(cites[0].page, 2) + + def test_overlong_quote_truncated_to_schema_limit(self) -> None: + draft = PaperDraft( + root=PaperDraftNode( + title="R", + summary="s", + citations=[DraftCitation(source_id="x", quote="z" * 600)], + ) + ) + paper = draft_to_paper(draft, source_meta=_ARXIV_META, model="m") + self.assertEqual(len(paper.root().citations[0].quote or ""), 500) + + +if __name__ == "__main__": + unittest.main()