Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions quantmind/flows/_paper_draft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""LLM-facing draft schema for ``paper_flow`` + lift into canonical ``Paper``.

The canonical :class:`quantmind.knowledge.Paper` is the *store* schema: a flat
``dict[UUID, TreeNode]`` keyed by UUID for O(1) lookup and stable dedup keys.
That shape is hostile to LLM structured output for two reasons:

1. A ``dict`` field serialises to ``additionalProperties``, which the Agents
SDK's strict-JSON-schema mode rejects outright.
2. Even with strict mode off, models do not reliably emit RFC-4122 UUIDs, so
``UUID`` id fields fail Pydantic validation on free-form ids like
``"intro_node"``.

So the agent targets :class:`PaperDraft` instead — a nested tree of
:class:`PaperDraftNode` with no node ids (children are embedded, not
referenced by id), which keeps it strict-schema clean. ``draft_to_paper``
then assigns real UUIDs, wires ``parent_id`` / ``children_ids``, and injects
provenance (``source``, ``arxiv_id``, ``authors``) the flow already knows from
the fetch layer rather than trusting the model to hallucinate it.
"""

from datetime import datetime, timezone
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, ConfigDict, Field

from quantmind.knowledge import (
Citation,
ExtractionRef,
Paper,
SourceRef,
TreeNode,
)

# Map the ``source`` discriminator emitted by ``_fetch_and_format`` onto the
# ``SourceRef.kind`` literal. ``inline`` (RawText) has no external origin, so it
# records as a manual entry.
_SOURCE_KIND: dict[str, str] = {
"arxiv": "arxiv",
"web": "http",
"local": "local",
"inline": "manual",
}

# ``Citation.quote`` caps at 500 chars; models routinely over-quote.
_QUOTE_MAX = 500


class DraftCitation(BaseModel):
    """Citation exactly as the model writes it, before tree/node anchoring.

    Keys the schema does not declare are ignored (``extra="ignore"``).
    """

    model_config = ConfigDict(extra="ignore")

    # Identifier of the cited source, kept as a free-form string.
    source_id: str
    # Optional location hints; each is ``None`` when the model omits it.
    page: int | None = None
    char_offset: int | None = None
    # Optional supporting excerpt.
    quote: str | None = None


class PaperDraftNode(BaseModel):
    """A single section of the paper in the shape the model produces.

    Sub-sections are held inline in ``children``, so the draft is a real tree.
    Because nothing is referenced by id, the model has no flat id map to keep
    consistent, which removes the dangling-reference and duplicate-id failure
    modes. Identity is assigned afterwards by ``draft_to_paper``.
    """

    model_config = ConfigDict(extra="ignore")

    # Required on every node.
    title: str
    summary: str
    # Optional section body; ``None`` when the model gives none.
    content: str | None = None
    # Supporting passages for this section; empty when the model cites nothing.
    citations: list[DraftCitation] = Field(default_factory=list)
    # Nested sub-sections; the string forward reference makes the model
    # recursive.
    children: list["PaperDraftNode"] = Field(default_factory=list)


class PaperDraft(BaseModel):
    """Extraction target for ``paper_flow`` that is safe under strict schemas.

    Holds only what the model can write from the paper text. Identity and
    provenance are filled in by :func:`draft_to_paper`, which falls back to
    the ``arxiv_id`` and ``authors`` given here only when the fetch metadata
    has no value for them.
    """

    model_config = ConfigDict(extra="ignore")

    # Top of the nested section tree.
    root: PaperDraftNode
    # Model-supplied fallbacks for provenance fields.
    arxiv_id: str | None = None
    authors: list[str] = Field(default_factory=list)
    # Copied onto the resulting ``Paper`` as-is.
    asset_classes: list[str] = Field(default_factory=list)


def _to_citations(drafts: list[DraftCitation]) -> list[Citation]:
    """Convert model-emitted draft citations into ``Citation`` objects.

    Every field is copied across unchanged except ``quote``, which is cut to
    ``_QUOTE_MAX`` characters when it is present.

    Args:
        drafts: Citations as they appear on a draft node.

    Returns:
        One ``Citation`` per draft citation, in the same order.
    """
    converted: list[Citation] = []
    for draft in drafts:
        quote = draft.quote
        if quote is not None:
            # Models routinely over-quote; keep only the leading slice.
            quote = quote[:_QUOTE_MAX]
        converted.append(
            Citation(
                source_id=draft.source_id,
                page=draft.page,
                char_offset=draft.char_offset,
                quote=quote,
            )
        )
    return converted


def _source_ref(source_meta: dict[str, Any]) -> SourceRef:
    """Build the ``SourceRef`` provenance record from fetch metadata.

    Args:
        source_meta: Metadata dict whose ``source`` entry names the origin
            (treated as ``"manual"`` when the key is missing).

    Returns:
        A ``SourceRef`` whose ``kind`` comes from ``_SOURCE_KIND`` (unknown
        origins map to ``"manual"``) and whose ``uri`` is ``arxiv:<id>`` for
        arXiv input with an id, the ``url`` / ``path`` entry for web / local
        input, and ``None`` otherwise.
    """
    origin = source_meta.get("source", "manual")
    kind = _SOURCE_KIND.get(origin, "manual")

    # Origins whose URI is stored verbatim under a single metadata key.
    uri_keys = {"web": "url", "local": "path"}

    uri = None
    if origin == "arxiv":
        arxiv_id = source_meta.get("arxiv_id")
        if arxiv_id:
            uri = f"arxiv:{arxiv_id}"
    elif origin in uri_keys:
        uri = source_meta.get(uri_keys[origin])

    return SourceRef(kind=kind, uri=uri)  # type: ignore[arg-type]


def draft_to_paper(
    draft: PaperDraft,
    *,
    source_meta: dict[str, Any],
    model: str,
) -> Paper:
    """Turn a validated ``PaperDraft`` into the canonical ``Paper``.

    Args:
        draft: Nested extraction output produced by the model.
        source_meta: The ``(_, meta)`` dict from ``_fetch_and_format`` — the
            flow's authoritative provenance (origin, arxiv id, authors,
            published date).
        model: Model identifier stored on the ``ExtractionRef``.

    Returns:
        A ``Paper`` whose nodes carry fresh UUIDs and parent/child links, with
        provenance taken from ``source_meta`` and the draft used only as a
        fallback for ``arxiv_id`` and ``authors``.
    """
    nodes: dict[Any, TreeNode] = {}

    def _lift(section: PaperDraftNode, parent_id: Any, position: int) -> Any:
        # Mint the id before descending so the children can point back at it.
        node_id = uuid4()
        children_ids = []
        for index, child in enumerate(section.children):
            children_ids.append(_lift(child, node_id, index))
        # The subtree is registered first, so a parent lands in ``nodes``
        # after all of its descendants.
        nodes[node_id] = TreeNode(
            node_id=node_id,
            parent_id=parent_id,
            position=position,
            title=section.title,
            summary=section.summary,
            content=section.content,
            citations=_to_citations(section.citations),
            children_ids=children_ids,
        )
        return node_id

    root_id = _lift(draft.root, None, 0)

    # Use the fetched publication date only when it is an actual datetime;
    # anything else falls back to the current UTC time.
    published = source_meta.get("published_at")
    if isinstance(published, datetime):
        as_of = published
    else:
        as_of = datetime.now(timezone.utc)

    # Fetch-layer values win; the draft's values are used only when the
    # metadata entry is missing or empty.
    authors = source_meta.get("authors") or draft.authors

    source = _source_ref(source_meta)
    extraction = ExtractionRef(
        flow="paper_flow",
        model=model,
        extracted_at=datetime.now(timezone.utc),
    )
    arxiv_id = source_meta.get("arxiv_id") or draft.arxiv_id

    return Paper(
        as_of=as_of,
        source=source,
        extraction=extraction,
        root_node_id=root_id,
        nodes=nodes,
        arxiv_id=arxiv_id,
        authors=list(authors),
        asset_classes=draft.asset_classes,
    )
32 changes: 21 additions & 11 deletions quantmind/flows/paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
PaperInput,
RawText,
)
from quantmind.flows._paper_draft import PaperDraft, draft_to_paper
from quantmind.flows._runner import run_with_observability
from quantmind.knowledge import Paper
from quantmind.preprocess.fetch import (
Expand All @@ -37,23 +38,22 @@
P = TypeVar("P", bound=Paper)

_DEFAULT_INSTRUCTIONS = """\
You are extracting a research paper into a structured QuantMind ``Paper``
TreeKnowledge object. Build the section tree top-down: every node has a
title and a short summary; leaf nodes additionally carry the section
markdown content. Cite supporting passages on each node.
You are extracting a research paper into a structured QuantMind paper tree.
Build the section tree top-down by nesting each child section under its
parent's ``children``: every node has a title and a short summary; leaf
nodes additionally carry the section markdown content. Cite supporting
passages on each node.

Honour these flags from the run config:
- extract_methodology={extract_methodology}: when true, every methodology
section becomes its own subtree with a per-step summary.
- extract_limitations={extract_limitations}: when true, surface
limitations as a dedicated top-level child rather than inlining them.
- asset_class_hint={asset_class_hint!r}: when set, prefer this asset
class for ``Paper.asset_classes`` if the paper does not state one
explicitly.
class for ``asset_classes`` if the paper does not state one explicitly.

Set ``as_of`` to the publication date when given; otherwise use today's
date. Set the ``source`` provenance ref using the metadata supplied in
the prompt.
Identity and provenance (node ids, source, dates, authors) are added by
the framework afterwards from the fetch metadata — do not invent them.
"""


Expand Down Expand Up @@ -85,7 +85,10 @@ async def paper_flow(
unpaywall fallback is its own follow-up issue).
"""
cfg = cfg or PaperFlowCfg()
out_type: type[Paper] = output_type or Paper # type: ignore[assignment]
# The agent targets the strict-schema-safe ``PaperDraft`` by default; a
# caller-supplied ``output_type`` opts out of the draft mechanism and is
# forwarded verbatim (its result is returned unconverted below).
out_type: Any = output_type or PaperDraft

raw_md, source_meta = await _fetch_and_format(input)

Expand All @@ -105,13 +108,19 @@ async def paper_flow(
if cfg.model_settings is not None:
agent_kwargs["model_settings"] = cfg.model_settings
agent: Agent[Any] = Agent(**agent_kwargs)
return await run_with_observability(
result = await run_with_observability(
agent,
_format_input(raw_md, source_meta),
cfg=cfg,
memory=memory,
extra_run_hooks=list(extra_run_hooks or []),
)
# A caller-supplied ``output_type`` already yields a ``Paper`` (subclass);
# the default path yields a ``PaperDraft`` we lift into the canonical
# store schema, injecting flow-known provenance.
if isinstance(result, Paper):
return result
return draft_to_paper(result, source_meta=source_meta, model=cfg.model)


async def _fetch_and_format(
Expand All @@ -126,6 +135,7 @@ async def _fetch_and_format(
"arxiv_id": raw.arxiv_id,
"title": raw.title,
"authors": list(raw.authors),
"published_at": raw.published_at,
}
if isinstance(input, HttpUrl):
raw = await fetch_url(input.url)
Expand Down
46 changes: 46 additions & 0 deletions tests/flows/test_paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
LocalFilePath,
RawText,
)
from quantmind.flows._paper_draft import PaperDraft, PaperDraftNode
from quantmind.flows.paper import (
UnsupportedContentTypeError,
_compose_instructions,
Expand Down Expand Up @@ -357,3 +358,48 @@ def _capture_agent(*_a: Any, **kwargs: Any) -> Any:
):
await paper_flow(RawText(text="x"))
self.assertNotIn("model_settings", seen)

async def test_default_agent_output_type_is_paper_draft(self) -> None:
    """Without an override, the agent is built to output ``PaperDraft``."""
    # The agent must target the strict-schema-safe draft, not the canonical
    # (UUID/dict) Paper, unless the caller overrides the output type.
    captured: dict[str, Any] = {}

    def _record_agent_kwargs(*_args: Any, **kwargs: Any) -> Any:
        # Stand in for ``Agent`` and remember how it was constructed.
        captured.update(kwargs)
        return MagicMock()

    agent_patch = patch(
        "quantmind.flows.paper.Agent",
        side_effect=_record_agent_kwargs,
    )
    with agent_patch, _patch_runner(_stub_paper()):
        await paper_flow(RawText(text="x"))

    self.assertIs(captured["output_type"], PaperDraft)

async def test_draft_result_converted_to_canonical_paper(self) -> None:
    """A ``PaperDraft`` from the runner comes back as a canonical ``Paper``."""
    fetched = RawPaper(
        bytes=b"%PDF",
        content_type="application/pdf",
        arxiv_id="2604.12345",
        authors=("Alice",),
    )
    root = PaperDraftNode(title="Extracted Title", summary="s")
    draft = PaperDraft(root=root)

    # Replace the fetch and PDF-conversion calls with canned async results.
    fetch_patch = patch(
        "quantmind.flows.paper.fetch_arxiv",
        new=AsyncMock(return_value=fetched),
    )
    markdown_patch = patch(
        "quantmind.flows.paper.pdf_to_markdown",
        new=AsyncMock(return_value="MD"),
    )
    with fetch_patch, markdown_patch, _patch_runner(draft):
        out = await paper_flow(ArxivIdentifier(id="2604.12345"))

    # Converted into the canonical store schema with injected provenance.
    self.assertIsInstance(out, Paper)
    self.assertEqual(out.root().title, "Extracted Title")
    self.assertEqual(out.arxiv_id, "2604.12345")
    self.assertEqual(out.source.kind, "arxiv")
    self.assertEqual(out.authors, ["Alice"])
Loading