diff --git a/README.md b/README.md index d2b82507c..547065994 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ ART is an open-source RL framework that improves agent reliability by allowing L | Agent Task | Example Notebook | Description | Comparative Performance | | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | **ART•E [Serverless]** | [🏋️ Train agent](https://colab.research.google.com/github/openpipe/art-notebooks/blob/main/examples/art-e.ipynb) | Qwen3 14B learns to search emails using RULER | [benchmarks](/dev/art-e/art_e/evaluate/display_benchmarks.ipynb) | +| **ART•E Local** | [Example](/examples/art_e) | Lightweight local email-search agent task with deterministic inbox fixtures | Local smoke-test example | | **2048 [Serverless]** | [🏋️ Train agent](https://colab.research.google.com/github/openpipe/art-notebooks/blob/main/examples/2048/2048.ipynb) | Qwen3 14B learns to play 2048 | [benchmarks](/examples/2048/display_benchmarks.ipynb) | | **ART•E LangGraph** | [🏋️ Train agent](https://colab.research.google.com/github/openpipe/art-notebooks/blob/main/examples/langgraph/art-e-langgraph.ipynb) | Qwen 2.5 7B learns to search emails using LangGraph | [Link coming soon] | | **MCP•RL** | [🏋️ Train agent](https://colab.research.google.com/github/openpipe/art-notebooks/blob/main/examples/mcp-rl/mcp-rl.ipynb) | Qwen 2.5 3B masters the NWS MCP server | [Link coming soon] | diff --git a/examples/art_e/README.md b/examples/art_e/README.md new file mode 100644 index 000000000..94cf44f79 --- /dev/null +++ b/examples/art_e/README.md @@ -0,0 +1,46 @@ +# ART-E Email Search Example + +This example is a lightweight, local version of the ART-E email search task. It +shows how to train an agent to search a small inbox, read relevant messages, and +return a grounded answer with supporting message IDs. + +The example is intentionally small: + +- No external email service is required. +- The inboxes are deterministic Python fixtures. +- The rollout uses a simple text protocol instead of provider-specific tool + calling so it works across most chat models. +- The reward combines exact answer matching and citation correctness. + +For the full ART-E research context, see the +[ART-E blog post](https://openpipe.ai/blog/art-e-mail-agent). + +## Files + +- `scenarios.py` defines inbox fixtures, search/read helpers, and answer + scoring. +- `rollout.py` runs one multi-turn email-search trajectory. +- `train.py` trains a small model with ART using the local scenarios. + +## Run One Rollout + +Set an inference API key for the provider used by your `art.Model`, then run: + +```bash +python examples/art_e/rollout.py +``` + +The script uses an OpenRouter model by default for a cheap smoke test. You can +change the model configuration at the bottom of `rollout.py`. + +## Train + +Training requires the normal ART local backend setup: + +```bash +python examples/art_e/train.py +``` + +The default training configuration is deliberately modest so the example is easy +to inspect. Increase `SIMULTANEOUS_ROLLOUTS`, `TRAIN_STEPS`, or the base model +when running on a larger GPU. diff --git a/examples/art_e/__init__.py b/examples/art_e/__init__.py new file mode 100644 index 000000000..ef66a8b65 --- /dev/null +++ b/examples/art_e/__init__.py @@ -0,0 +1 @@ +"""Lightweight ART-E email search example.""" diff --git a/examples/art_e/rollout.py b/examples/art_e/rollout.py new file mode 100644 index 000000000..9e49bbcdc --- /dev/null +++ b/examples/art_e/rollout.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import asyncio +import json +import os +import random + +from dotenv import load_dotenv +import openai +import requests +from scenarios import ( + SCENARIOS, + EmailScenario, + parse_json_command, + read_email, + score_answer, + search_emails, +) + +import art + +load_dotenv() + +MAX_TURNS = 6 + +SYSTEM_PROMPT = """You are ART-E, an email research agent. + +You need to answer the user's question by searching and reading their inbox. + +Use exactly one command per assistant message: + +{"keywords":["keyword"],"sent_before":"YYYY-MM-DD"} +{"message_id":"message-id"} +{"answer":"final answer","reference_message_ids":["message-id"]} + +Rules: +- Search before answering unless the answer is already present in the context. +- Read a message before citing it. +- Cite only message IDs that support your answer. +- Keep final answers concise and factual. +""" + + +def tool_message(payload: object) -> dict[str, str]: + return { + "role": "user", + "content": "Tool result:\n" + json.dumps(payload, indent=2, sort_keys=True), + } + + +@art.retry(exceptions=(openai.LengthFinishReasonError, requests.ReadTimeout)) +async def rollout( + model: art.Model, + scenario: EmailScenario, + step: int = 0, + is_validation: bool = False, + verbose: bool = False, +) -> art.Trajectory: + trajectory = art.Trajectory( + messages_and_choices=[ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": ( + f"Inbox: {scenario.inbox_address}\n" + f"Today: {scenario.query_date}\n" + f"Question: {scenario.question}" + ), + }, + ], + metadata={ + "scenario_id": scenario.id, + "step": step, + "validation": is_validation, + }, + reward=0, + ) + + client = model.openai_client() + searched = False + read_message_ids: set[str] = set() + + for _ in range(MAX_TURNS): + completion = await client.chat.completions.create( + max_completion_tokens=256, + messages=trajectory.messages(), + model=model.get_inference_name(), + ) + choice = completion.choices[0] + content = choice.message.content or "" + trajectory.messages_and_choices.append(choice) + + if verbose: + print(content) + + if search_payload := parse_json_command(content, "search"): + keywords = search_payload.get("keywords", []) + if not isinstance(keywords, list): + trajectory.reward = -0.25 + trajectory.metrics["invalid_command"] = 1 + break + sent_before = search_payload.get("sent_before", scenario.query_date) + results = search_emails( + scenario, + [str(keyword) for keyword in keywords], + sent_before=str(sent_before) if sent_before else None, + ) + searched = True + trajectory.messages_and_choices.append( + tool_message({"search_results": results}) + ) + continue + + if read_payload := parse_json_command(content, "read"): + message_id = read_payload.get("message_id") + if not isinstance(message_id, str): + trajectory.reward = -0.25 + trajectory.metrics["invalid_command"] = 1 + break + email = read_email(scenario, message_id) + if email is None: + trajectory.messages_and_choices.append( + tool_message({"error": f"Message not found: {message_id}"}) + ) + else: + read_message_ids.add(message_id) + trajectory.messages_and_choices.append(tool_message({"email": email})) + continue + + if answer_payload := parse_json_command(content, "answer"): + answer = answer_payload.get("answer", "") + reference_message_ids = answer_payload.get("reference_message_ids", []) + if not isinstance(answer, str) or not isinstance( + reference_message_ids, list + ): + trajectory.reward = -0.25 + trajectory.metrics["invalid_command"] = 1 + break + + references = [str(message_id) for message_id in reference_message_ids] + reward, metrics = score_answer(scenario, answer, references) + unread_citation = any( + message_id not in read_message_ids for message_id in references + ) + if unread_citation: + reward *= 0.5 + trajectory.reward = reward + trajectory.metrics.update(metrics) + trajectory.metrics["searched"] = float(searched) + trajectory.metrics["unread_citation"] = float(unread_citation) + break + + trajectory.messages_and_choices.append( + tool_message( + { + "error": ( + "Invalid command. Use one of , , or " + " with a JSON payload." + ) + } + ) + ) + else: + trajectory.reward = -0.1 + trajectory.metrics["ran_out_of_turns"] = 1 + + return trajectory + + +if __name__ == "__main__": + random.seed(42) + + smoke_test_model = art.Model( + name="gpt-4o-mini", + project="art-e", + inference_model_name="openai/gpt-4o-mini", + inference_base_url="https://openrouter.ai/api/v1", + inference_api_key=os.getenv("OPENROUTER_API_KEY"), + ) + + async def main() -> None: + trajectory = await rollout( + smoke_test_model, + random.choice(SCENARIOS), + is_validation=True, + verbose=True, + ) + print("reward:", trajectory.reward) + print("metrics:", trajectory.metrics) + + asyncio.run(main()) diff --git a/examples/art_e/scenarios.py b/examples/art_e/scenarios.py new file mode 100644 index 000000000..775404040 --- /dev/null +++ b/examples/art_e/scenarios.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import date +import json +import re +from typing import Any, TypedDict + + +class EmailMessage(TypedDict): + id: str + sender: str + recipients: list[str] + sent_at: str + subject: str + body: str + + +class SearchResult(TypedDict): + id: str + sender: str + sent_at: str + subject: str + snippet: str + + +@dataclass(frozen=True) +class EmailScenario: + id: str + inbox_address: str + query_date: str + question: str + answer: str + reference_message_ids: tuple[str, ...] + inbox: tuple[EmailMessage, ...] + + +SCENARIOS: tuple[EmailScenario, ...] = ( + EmailScenario( + id="quarterly-budget-owner", + inbox_address="alex@acme.test", + query_date="2026-03-15", + question=( + "Who owns the quarterly budget deck, and when did they say the " + "draft would be ready?" + ), + answer="Maya owns the quarterly budget deck and said the draft would be ready by Friday.", + reference_message_ids=("msg-budget-2",), + inbox=( + { + "id": "msg-budget-1", + "sender": "nora@acme.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-03-01", + "subject": "Budget planning kickoff", + "body": "Let's collect assumptions for the quarterly budget review.", + }, + { + "id": "msg-budget-2", + "sender": "maya@acme.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-03-08", + "subject": "Quarterly budget deck owner", + "body": ( + "I will own the quarterly budget deck. The draft will be " + "ready by Friday so finance can review it before Monday." + ), + }, + { + "id": "msg-budget-3", + "sender": "finance@acme.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-03-12", + "subject": "Reminder: travel budget", + "body": "Please submit travel budget updates before the end of the month.", + }, + ), + ), + EmailScenario( + id="customer-escalation-time", + inbox_address="alex@acme.test", + query_date="2026-04-05", + question=( + "What time is the Northwind escalation call, and which customer " + "issue should be discussed first?" + ), + answer="The Northwind escalation call is at 3 PM UTC, and the login outage should be discussed first.", + reference_message_ids=("msg-northwind-2",), + inbox=( + { + "id": "msg-northwind-1", + "sender": "support@acme.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-04-01", + "subject": "Northwind weekly notes", + "body": "Northwind asked for the usual weekly usage report.", + }, + { + "id": "msg-northwind-2", + "sender": "sam@acme.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-04-03", + "subject": "Northwind escalation call", + "body": ( + "The Northwind escalation call is at 3 PM UTC. Please " + "discuss the login outage first, then the reporting delay." + ), + }, + { + "id": "msg-northwind-3", + "sender": "calendar@acme.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-04-04", + "subject": "Daily standup moved", + "body": "Daily standup is moving to 10 AM local time next week.", + }, + ), + ), + EmailScenario( + id="contract-renewal-discount", + inbox_address="alex@acme.test", + query_date="2026-04-20", + question="What renewal discount did Finch Labs approve?", + answer="Finch Labs approved a 12% renewal discount.", + reference_message_ids=("msg-finch-3",), + inbox=( + { + "id": "msg-finch-1", + "sender": "sales@acme.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-04-09", + "subject": "Finch Labs renewal", + "body": "Finch Labs is reviewing renewal pricing this week.", + }, + { + "id": "msg-finch-2", + "sender": "legal@acme.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-04-11", + "subject": "Finch Labs contract language", + "body": "The renewal contract language is approved from legal.", + }, + { + "id": "msg-finch-3", + "sender": "riley@finch.test", + "recipients": ["alex@acme.test"], + "sent_at": "2026-04-18", + "subject": "Re: Finch Labs renewal", + "body": "We approved the 12% renewal discount. Please send the final order form.", + }, + ), + ), +) + + +def search_emails( + scenario: EmailScenario, + keywords: list[str], + sent_before: str | None = None, + limit: int = 5, +) -> list[SearchResult]: + normalized_keywords = [keyword.lower() for keyword in keywords if keyword.strip()] + try: + before = date.fromisoformat(sent_before) if sent_before else None + except ValueError: + before = None + results: list[SearchResult] = [] + + for email in scenario.inbox: + sent_at = date.fromisoformat(email["sent_at"]) + if before and sent_at >= before: + continue + + haystack = " ".join( + [email["sender"], email["subject"], email["body"]] + ).lower() + if normalized_keywords and not all( + keyword in haystack for keyword in normalized_keywords + ): + continue + + results.append( + { + "id": email["id"], + "sender": email["sender"], + "sent_at": email["sent_at"], + "subject": email["subject"], + "snippet": email["body"][:160], + } + ) + + return results[:limit] + + +def read_email(scenario: EmailScenario, message_id: str) -> EmailMessage | None: + for email in scenario.inbox: + if email["id"] == message_id: + return email + return None + + +def parse_json_command(content: str, tag: str) -> dict[str, Any] | None: + match = re.search(rf"<{tag}>\s*(\{{.*?\}})\s*", content, re.DOTALL) + if not match: + return None + try: + payload = json.loads(match.group(1)) + except json.JSONDecodeError: + return None + return payload if isinstance(payload, dict) else None + + +def normalize_text(value: str) -> str: + value = value.lower() + value = re.sub(r"[^a-z0-9% ]+", " ", value) + return re.sub(r"\s+", " ", value).strip() + + +def score_answer( + scenario: EmailScenario, + answer: str, + reference_message_ids: list[str], +) -> tuple[float, dict[str, float]]: + normalized_answer = normalize_text(answer) + expected_answer = normalize_text(scenario.answer) + + answer_score = 1.0 if expected_answer in normalized_answer else 0.0 + expected_refs = set(scenario.reference_message_ids) + provided_refs = set(reference_message_ids) + citation_score = 1.0 if expected_refs.issubset(provided_refs) else 0.0 + reward = (0.75 * answer_score) + (0.25 * citation_score) + + return reward, { + "answer_correct": answer_score, + "citations_correct": citation_score, + } diff --git a/examples/art_e/train.py b/examples/art_e/train.py new file mode 100644 index 000000000..e4a2f7040 --- /dev/null +++ b/examples/art_e/train.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import asyncio +import random + +from dotenv import load_dotenv +from rollout import rollout +from scenarios import SCENARIOS + +import art +from art.local import LocalBackend + +load_dotenv() + +random.seed(42) + +TRAIN_STEPS = 30 +SIMULTANEOUS_ROLLOUTS = 12 +VALIDATION_ROLLOUTS = 3 + + +async def train() -> None: + backend = LocalBackend() + + model = art.TrainableModel( + name="art-e-email-search-001", + project="art-e", + base_model="Qwen/Qwen2.5-3B-Instruct", + ) + + await model.register(backend) + + scenarios = list(SCENARIOS) + for step in range(await model.get_step(), TRAIN_STEPS): + random.shuffle(scenarios) + + train_groups = await art.gather_trajectory_groups( + ( + art.TrajectoryGroup( + rollout(model, scenario, step=step, is_validation=False) + for _ in range(SIMULTANEOUS_ROLLOUTS) + ) + for scenario in scenarios + ), + pbar_desc="train", + max_exceptions=10, + ) + + val_groups = await art.gather_trajectory_groups( + ( + art.TrajectoryGroup( + rollout(model, scenario, step=step, is_validation=True) + for _ in range(VALIDATION_ROLLOUTS) + ) + for scenario in SCENARIOS + ), + pbar_desc="val", + max_exceptions=10, + ) + + await model.log(val_groups) + await model.delete_checkpoints() + result = await backend.train(model, train_groups, learning_rate=1e-5) + await model.log( + train_groups, + metrics=result.metrics, + step=result.step, + split="train", + ) + + +if __name__ == "__main__": + asyncio.run(train()) diff --git a/tests/test_art_e_example.py b/tests/test_art_e_example.py new file mode 100644 index 000000000..f61c39ded --- /dev/null +++ b/tests/test_art_e_example.py @@ -0,0 +1,41 @@ +from examples.art_e.scenarios import ( + SCENARIOS, + parse_json_command, + read_email, + score_answer, + search_emails, +) + + +def test_search_emails_finds_matching_message() -> None: + scenario = SCENARIOS[0] + + results = search_emails(scenario, ["quarterly", "deck"]) + + assert [result["id"] for result in results] == ["msg-budget-2"] + + +def test_read_email_returns_message_by_id() -> None: + scenario = SCENARIOS[1] + + email = read_email(scenario, "msg-northwind-2") + + assert email is not None + assert "3 PM UTC" in email["body"] + + +def test_score_answer_rewards_answer_and_citation() -> None: + scenario = SCENARIOS[2] + + reward, metrics = score_answer( + scenario, + "Finch Labs approved a 12% renewal discount.", + ["msg-finch-3"], + ) + + assert reward == 1.0 + assert metrics == {"answer_correct": 1.0, "citations_correct": 1.0} + + +def test_parse_json_command_ignores_invalid_json() -> None: + assert parse_json_command("{invalid}", "search") is None