diff --git a/requirements-dev.txt b/requirements-dev.txt index b54a2d1..4557a45 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,4 +6,6 @@ pytest-cov>=5.0 ruff>=0.9.0 pip-audit>=2.7.0 hypothesis>=6.100.0 +PyYAML>=6.0 +types-PyYAML>=6.0 pytest-benchmark==5.2.3 diff --git a/tests/test_md_exporter_yaml.py b/tests/test_md_exporter_yaml.py new file mode 100644 index 0000000..a055f51 --- /dev/null +++ b/tests/test_md_exporter_yaml.py @@ -0,0 +1,121 @@ +"""YAML frontmatter escaping and round-trip tests for md_exporter.""" + +from __future__ import annotations + +import os +import sys + +import yaml +from hypothesis import given, settings, strategies as st + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from models.session import SessionDict +from utils.md_exporter import ( + _append_yaml_value, + _escape_yaml, + _session_frontmatter_dict, + session_to_markdown, +) + +FUZZ_SETTINGS = settings(max_examples=100) + + +def _extract_frontmatter_dict(markdown: str) -> dict: + lines = markdown.splitlines() + if not lines or lines[0].strip() != "---": + raise ValueError("missing opening frontmatter delimiter") + yaml_lines: list[str] = [] + for line in lines[1:]: + if line.strip() == "---": + break + yaml_lines.append(line) + else: + raise ValueError("missing closing frontmatter delimiter") + loaded = yaml.safe_load("\n".join(yaml_lines)) + return loaded if isinstance(loaded, dict) else {} + + +def _base_session(**overrides: object) -> SessionDict: + session: SessionDict = { + "session_id": "sess-001", + "title": "Hello", + "messages": [{"role": "user", "text": "hi"}], + "metadata": { + "session_id": "sess-001", + "models_used": ["claude-sonnet-4-20250514"], + "first_timestamp": "2026-01-02T12:00:00Z", + "last_timestamp": "2026-01-02T12:30:00Z", + "total_input_tokens": 120, + "total_output_tokens": 45, + "total_cache_read_tokens": 10, + "total_tool_calls": 2, + "tool_call_counts": {"Read": 2}, + "cwd": "/workspace", + "git_branch": "main", + "version": "1.0.0", + "permission_mode": "default", + }, + } + if overrides: + for key, value in overrides.items(): + if key == "metadata" and isinstance(value, dict): + session["metadata"].update(value) # type: ignore[typeddict-item] + else: + session[key] = value # type: ignore[literal-required] + return session + + +class TestYamlFrontmatterRoundtrip: + def test_yaml_frontmatter_roundtrip(self): + session = _base_session( + title="Fix: handle edge case #42", + metadata={ + "cwd": r"C:\Users\dev\project", + "git_branch": "feat#yaml", + "permission_mode": "true", + "stop_reasons": {"max_tokens": 1, "end_turn": 2}, + "tool_call_counts": {"Read": 1, "Fix: tool": 1}, + }, + ) + md = session_to_markdown(session) + assert _extract_frontmatter_dict(md) == _session_frontmatter_dict(session) + + def test_multiline_title_uses_quoted_scalar(self): + session = _base_session(title="line one\nline two") + md = session_to_markdown(session) + assert 'title: "line one\\nline two"' in md.split("---")[1] + assert _extract_frontmatter_dict(md)["title"] == "line one\nline two" + + def test_tab_and_hash_in_title(self): + session = _base_session(title="tab\there # not a comment") + md = session_to_markdown(session) + assert _extract_frontmatter_dict(md)["title"] == "tab\there # not a comment" + + def test_models_used_serializes_as_yaml_sequence(self): + session = _base_session( + metadata={"models_used": ["claude-sonnet-4", "claude-opus-4"]}, + ) + md = session_to_markdown(session) + frontmatter = _extract_frontmatter_dict(md) + assert frontmatter["models_used"] == ["claude-sonnet-4", "claude-opus-4"] + assert "models_used:\n" in md.split("---")[1] or "models_used:" in md.split("---")[1] + assert ' - "claude-sonnet-4"' in md.split("---")[1] + + +@FUZZ_SETTINGS +@given(st.text()) +def test_escape_yaml_roundtrip(s: str) -> None: + """Double-quoted scalars round-trip for arbitrary text.""" + loaded = yaml.safe_load(f"key: {_escape_yaml(s)}") + assert loaded["key"] == s + + +@FUZZ_SETTINGS +@given(st.text()) +def test_yaml_string_field_roundtrip(s: str) -> None: + """Frontmatter string serializer round-trips arbitrary text.""" + lines: list[str] = [] + _append_yaml_value(lines, "key", s) + loaded = yaml.safe_load("\n".join(lines)) + assert loaded["key"] == s diff --git a/utils/md_exporter.py b/utils/md_exporter.py index bc617b1..ab76d71 100644 --- a/utils/md_exporter.py +++ b/utils/md_exporter.py @@ -1,6 +1,7 @@ """Markdown export. Produces a .md with YAML frontmatter, a summary section (cost, files touched, commands run), and the full conversation.""" +import re from datetime import datetime from typing import Any @@ -25,63 +26,70 @@ def session_to_markdown(session: SessionDict, stats: SessionStatsDict | None = N def _build_frontmatter(session: SessionDict) -> str: - meta = session["metadata"] lines = ["---"] - lines.append(f'title: "{_escape_yaml(session["title"])}"') - if meta["first_timestamp"]: - lines.append(f"created: {meta['first_timestamp']}") - if meta["last_timestamp"]: - lines.append(f"updated: {meta['last_timestamp']}") - lines.append(f"session_id: {session['session_id']}") - if meta["models_used"]: - lines.append(f"models_used: {', '.join(meta['models_used'])}") - lines.append(f"total_input_tokens: {meta['total_input_tokens']}") - lines.append(f"total_output_tokens: {meta['total_output_tokens']}") - lines.append(f"total_cache_read_tokens: {meta['total_cache_read_tokens']}") + for key, value in _session_frontmatter_dict(session).items(): + _append_yaml_value(lines, key, value) + lines.append("---") + return "\n".join(lines) + + +def _session_frontmatter_dict(session: SessionDict) -> dict[str, Any]: + """Canonical frontmatter payload; used for export and round-trip tests.""" + meta = session["metadata"] + data: dict[str, Any] = { + "title": session["title"], + "session_id": session["session_id"], + "total_input_tokens": meta.get("total_input_tokens", 0), + "total_output_tokens": meta.get("total_output_tokens", 0), + "total_cache_read_tokens": meta.get("total_cache_read_tokens", 0), + "total_tool_calls": meta.get("total_tool_calls", 0), + "message_count": len(session["messages"]), + } + if meta.get("first_timestamp"): + data["created"] = meta["first_timestamp"] + if meta.get("last_timestamp"): + data["updated"] = meta["last_timestamp"] + if meta.get("models_used"): + data["models_used"] = list(meta["models_used"]) if meta.get("total_cache_creation_tokens", 0) > 0: - lines.append(f"total_cache_creation_tokens: {meta['total_cache_creation_tokens']}") - lines.append(f"total_tool_calls: {meta['total_tool_calls']}") - if meta["tool_call_counts"]: - lines.append("tool_call_breakdown:") - for tool, count in sorted(meta["tool_call_counts"].items(), key=lambda x: -x[1]): - lines.append(f" {tool}: {count}") + data["total_cache_creation_tokens"] = meta["total_cache_creation_tokens"] + if meta.get("tool_call_counts"): + data["tool_call_breakdown"] = dict( + sorted(meta["tool_call_counts"].items(), key=lambda item: -item[1]) + ) if meta.get("stop_reasons"): - lines.append("stop_reasons:") - for reason, count in sorted(meta["stop_reasons"].items(), key=lambda x: -x[1]): - lines.append(f" {reason}: {count}") - if meta["cwd"]: - lines.append(f'working_directory: "{_escape_yaml(meta["cwd"])}"') - if meta["git_branch"]: - lines.append(f"git_branch: {meta['git_branch']}") - if meta["version"]: - lines.append(f"claude_code_version: {meta['version']}") - if meta["permission_mode"]: - lines.append(f"permission_mode: {meta['permission_mode']}") + data["stop_reasons"] = dict(sorted(meta["stop_reasons"].items(), key=lambda item: -item[1])) + if meta.get("cwd"): + data["working_directory"] = meta["cwd"] + if meta.get("git_branch"): + data["git_branch"] = meta["git_branch"] + if meta.get("version"): + data["claude_code_version"] = meta["version"] + if meta.get("permission_mode"): + data["permission_mode"] = meta["permission_mode"] if meta.get("service_tiers"): - lines.append(f"service_tiers: {', '.join(meta['service_tiers'])}") - lines.append(f"message_count: {len(session['messages'])}") - if meta["compactions"] > 0: - lines.append(f"compactions: {meta['compactions']}") + data["service_tiers"] = list(meta["service_tiers"]) + if meta.get("compactions", 0) > 0: + data["compactions"] = meta["compactions"] if meta.get("api_errors", 0) > 0: - lines.append(f"api_errors: {meta['api_errors']}") + data["api_errors"] = meta["api_errors"] if meta.get("sidechain_messages", 0) > 0: - lines.append(f"sidechain_messages: {meta['sidechain_messages']}") + data["sidechain_messages"] = meta["sidechain_messages"] wall = meta.get("session_wall_time_seconds") if wall is not None: - lines.append(f"wall_clock_seconds: {int(wall)}") + data["wall_clock_seconds"] = int(wall) files_r = meta.get("files_read", []) files_w = meta.get("files_written", []) files_c = meta.get("files_created", []) if files_r or files_w or files_c: - lines.append(f"files_read: {len(files_r)}") - lines.append(f"files_written: {len(files_w)}") - lines.append(f"files_created: {len(files_c)}") + data["files_read"] = len(files_r) + data["files_written"] = len(files_w) + data["files_created"] = len(files_c) if meta.get("bash_commands"): - lines.append(f"commands_run: {len(meta['bash_commands'])}") + data["commands_run"] = len(meta["bash_commands"]) if meta.get("web_fetches"): - lines.append(f"web_fetches: {len(meta['web_fetches'])}") - lines.append("---") - return "\n".join(lines) + data["web_fetches"] = len(meta["web_fetches"]) + return data def _build_header(session: SessionDict) -> str: @@ -459,8 +467,74 @@ def _format_ts(ts: str | None) -> str: return ts +_PLAIN_YAML_KEY = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _yaml_mapping_key(key: str) -> str: + """Format a nested mapping key, quoting when it is not a plain identifier.""" + if _PLAIN_YAML_KEY.match(key): + return key + return _escape_yaml(key) + + def _escape_yaml(s: str) -> str: - return s.replace('"', '\\"').replace("\n", " ") + """Return a YAML double-quoted scalar for any string (including embedded newlines).""" + parts: list[str] = [] + for ch in s: + if ch == "\\": + parts.append("\\\\") + elif ch == '"': + parts.append('\\"') + elif ch == "\t": + parts.append("\\t") + elif ch == "\n": + parts.append("\\n") + elif ch == "\r": + parts.append("\\r") + elif not ch.isprintable(): + code = ord(ch) + if code <= 0xFF: + parts.append(f"\\x{code:02x}") + elif code <= 0xFFFF: + parts.append(f"\\u{code:04x}") + else: + parts.append(f"\\U{code:08x}") + else: + parts.append(ch) + return f'"{"".join(parts)}"' + + +def _append_yaml_value(lines: list[str], key: str, value: Any, *, indent: int = 0) -> None: + """Append one frontmatter field, quoting strings and nesting mappings safely.""" + prefix = " " * indent + if isinstance(value, dict): + lines.append(f"{prefix}{key}:") + for nested_key, nested_value in value.items(): + _append_yaml_value( + lines, + _yaml_mapping_key(nested_key), + nested_value, + indent=indent + 1, + ) + return + if isinstance(value, list): + lines.append(f"{prefix}{key}:") + for item in value: + if isinstance(item, str): + lines.append(f"{prefix} - {_escape_yaml(item)}") + elif isinstance(item, int): + lines.append(f"{prefix} - {item}") + else: + raise TypeError( + f"unsupported frontmatter sequence item type: {type(item).__name__}" + ) + return + if isinstance(value, int): + lines.append(f"{prefix}{key}: {value}") + return + if not isinstance(value, str): + raise TypeError(f"unsupported frontmatter value type: {type(value).__name__}") + lines.append(f"{prefix}{key}: {_escape_yaml(value)}") def _truncate(s: str, max_len: int) -> str: