Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
PHONE_AGENT_API_KEY: API key for model authentication (default: EMPTY)
PHONE_AGENT_MAX_STEPS: Maximum steps per task (default: 100)
PHONE_AGENT_DEVICE_ID: ADB device ID for multi-device setups
PHONE_AGENT_SYSTEM_PROMPT: Override system prompt text
PHONE_AGENT_SYSTEM_PROMPT_FILE: Path to UTF-8 system prompt file
"""

import argparse
Expand Down Expand Up @@ -368,6 +370,9 @@ def parse_args() -> argparse.Namespace:
# Use API key for authentication
python main.py --apikey sk-xxxxx

# Use a custom system prompt
python main.py --system-prompt-file ./prompt.txt "Open the Android launcher"

# Run with specific device
python main.py --device-id emulator-5554

Expand Down Expand Up @@ -430,6 +435,20 @@ def parse_args() -> argparse.Namespace:
help="Maximum steps per task",
)

parser.add_argument(
"--system-prompt",
type=str,
default=os.getenv("PHONE_AGENT_SYSTEM_PROMPT"),
help="Override system prompt text",
)

parser.add_argument(
"--system-prompt-file",
type=str,
default=os.getenv("PHONE_AGENT_SYSTEM_PROMPT_FILE"),
help="Path to a UTF-8 file containing the system prompt override",
)

# Device options
parser.add_argument(
"--device-id",
Expand Down Expand Up @@ -524,6 +543,22 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()


def resolve_system_prompt(
system_prompt: str | None, system_prompt_file: str | None
) -> str | None:
"""Resolve an optional system prompt override from text or a UTF-8 file."""
if system_prompt and system_prompt_file:
raise ValueError(
"--system-prompt and --system-prompt-file cannot be used together"
)

if system_prompt_file:
with open(system_prompt_file, encoding="utf-8") as f:
return f.read()

return system_prompt


def handle_ios_device_commands(args) -> bool:
"""
Handle iOS device-related commands.
Expand Down Expand Up @@ -684,6 +719,13 @@ def handle_device_commands(args) -> bool:
def main():
"""Main entry point."""
args = parse_args()
try:
system_prompt = resolve_system_prompt(
args.system_prompt, args.system_prompt_file
)
except (OSError, ValueError) as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)

# Set device type globally based on args
if args.device_type == "adb":
Expand Down Expand Up @@ -760,6 +802,7 @@ def main():
device_id=args.device_id,
verbose=not args.quiet,
lang=args.lang,
system_prompt=system_prompt,
)

agent = IOSPhoneAgent(
Expand All @@ -773,6 +816,7 @@ def main():
device_id=args.device_id,
verbose=not args.quiet,
lang=args.lang,
system_prompt=system_prompt,
)

agent = PhoneAgent(
Expand Down
81 changes: 81 additions & 0 deletions tests/test_system_prompt_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import sys
from argparse import Namespace

import pytest

import main


def test_parse_args_accepts_inline_system_prompt(monkeypatch):
monkeypatch.setattr(
sys,
"argv",
["main.py", "--system-prompt", "Use short Android launcher labels."],
)

args = main.parse_args()

assert args.system_prompt == "Use short Android launcher labels."
assert args.system_prompt_file is None


def test_resolve_system_prompt_reads_file(tmp_path):
prompt_file = tmp_path / "prompt.txt"
prompt_file.write_text("Report each launcher step to the user.", encoding="utf-8")

system_prompt = main.resolve_system_prompt(None, str(prompt_file))

assert system_prompt == "Report each launcher step to the user."


def test_resolve_system_prompt_rejects_inline_and_file(tmp_path):
prompt_file = tmp_path / "prompt.txt"
prompt_file.write_text("file prompt", encoding="utf-8")

with pytest.raises(ValueError, match="--system-prompt"):
main.resolve_system_prompt("inline prompt", str(prompt_file))


@pytest.mark.parametrize("device_type", ["adb", "ios"])
def test_main_passes_custom_system_prompt_to_agent_config(monkeypatch, device_type):
captured = {}
args = Namespace(
apikey="EMPTY",
base_url="http://localhost:8000/v1",
device_id=None,
device_type=device_type,
lang="en",
list_apps=False,
max_steps=3,
model="autoglm-phone-9b",
quiet=True,
system_prompt="Only use short launcher labels.",
system_prompt_file=None,
task="Open the launcher",
wda_url="http://localhost:8100",
)

class FakeAgent:
def __init__(self, model_config, agent_config):
captured["agent_config"] = agent_config

def run(self, task):
return "ok"

class FakeDeviceFactory:
def list_devices(self):
return []

monkeypatch.setattr(main, "parse_args", lambda: args)
monkeypatch.setattr(main, "handle_device_commands", lambda _args: False)
monkeypatch.setattr(main, "check_system_requirements", lambda *_, **__: True)
monkeypatch.setattr(main, "check_model_api", lambda *_, **__: True)
monkeypatch.setattr(main, "set_device_type", lambda _device_type: None)
monkeypatch.setattr(main, "get_device_factory", lambda: FakeDeviceFactory())
monkeypatch.setattr(main, "list_ios_devices", lambda: [])
monkeypatch.setattr(main, "PhoneAgent", FakeAgent)
monkeypatch.setattr(main, "IOSPhoneAgent", FakeAgent)

main.main()

assert captured["agent_config"].system_prompt == "Only use short launcher labels."