From cca2bab4a4759e81478102711584c6b9be29ed6a Mon Sep 17 00:00:00 2001 From: rustopian <96253492+rustopian@users.noreply.github.com> Date: Sat, 23 May 2026 13:52:15 +0100 Subject: [PATCH 1/6] Add core angr MCP analysis tools --- CHANGELOG.md | 13 ++ README.md | 10 +- angr_decompile.py | 444 ++++++++++++++++++++++++++++++++++++++++--- bridge_mcp_ghidra.py | 281 +++++++++++++++++++++++++++ tests/test_bridge.py | 180 ++++++++++++++++++ 5 files changed, 897 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bfadf4a9..5d296bcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,19 @@ This is a fork of [LaurieWired/GhidraMCP](https://github.com/LaurieWired/GhidraM The "Unreleased" section accumulates changes since the upstream `v1-4` release (commit `27f316f`). +## [Unreleased] + +### Added +- **Expanded core angr MCP capabilities**: added `angr_reachability`, + `angr_cfg_summary`, `angr_callgraph_summary`, `angr_lift_block`, + `angr_solve_constraints_at`, and `angr_compare_decompilers`. +- **Richer symbolic solving**: `angr_solve_constraints_at` reaches a target + address, applies JSON-described register/memory/stdin/argv constraints, and + evaluates requested registers, memory, stdin, and symbolic inputs. +- **IR and graph inspection**: MCP callers can now request VEX/AIL block + lifting, static CFG reachability, CFG summaries, and callgraph edge samples + without requiring AngryGhidra. + ## [1.6.0] - 2026-05-23 ### Added diff --git a/README.md b/README.md index 9a4e13f8..12225d9c 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,13 @@ tools. - For Solana/eBPF ELFs, pass `pcode_language="eBPF:LE:64:default"` or let the bridge infer it from Ghidra's program language id. The helper patches CLE at runtime for Solana's e_machine 263 and uses angr's p-code engine. -- `angr_symbolic_find` exposes core angr path search without AngryGhidra. It can - find a path to a target address, avoid addresses, and solve symbolic - stdin/argv, memory, and register values. +- Core angr tools do not require AngryGhidra: + `angr_symbolic_find` searches for a path to a target address; + `angr_solve_constraints_at` adds JSON-described constraints at the found + state and evaluates requested values; `angr_reachability` checks static CFG + reachability; `angr_cfg_summary` and `angr_callgraph_summary` summarize + recovered graph structure; `angr_lift_block` lifts a block to VEX/AIL; and + `angr_compare_decompilers` batches Ghidra-vs-Oxidizer decompiler output. - AngryGhidra support is optional. `angryghidra_*` tools look for `ANGRYGHIDRA_SCRIPT`, `ANGRYGHIDRA_HOME/angryghidra_script/angryghidra.py`, or a sibling `AngryGhidra/angryghidra_script/angryghidra.py`. If none is diff --git a/angr_decompile.py b/angr_decompile.py index 792d6419..f62e42a6 100644 --- a/angr_decompile.py +++ b/angr_decompile.py @@ -85,25 +85,65 @@ def parse_csv_ints(value: str) -> list[int]: return [int(part.strip(), 0) for part in value.split(",") if part.strip()] -def run_check(args: argparse.Namespace) -> int: - import angr +def parse_csv_strings(value: str) -> list[str]: + if not value: + return [] + return [part.strip() for part in value.split(",") if part.strip()] - print(f"python: {sys.executable}") - print(f"angr: {angr.__version__}") - if args.binary: - project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) - print(f"binary: {args.binary}") - print(f"arch: {project.arch}") - print(f"min_addr: {project.loader.main_object.min_addr:#x}") - print(f"max_addr: {project.loader.main_object.max_addr:#x}") - return 0 +def hex_addr(value: int) -> str: + return f"0x{value:x}" -def run_symbolic_find(args: argparse.Namespace) -> int: - import angr + +def build_cfg(project, args, function_starts: list[int] | None = None): + cfg_kwargs = { + "normalize": True, + "force_complete_scan": args.complete_cfg, + } + if function_starts: + cfg_kwargs["function_starts"] = function_starts + cfg_kwargs["force_complete_scan"] = False + return project.analyses.CFGFast(**cfg_kwargs) + + +def get_cfg_node(cfg, address: int): + return cfg.model.get_any_node(address, anyaddr=True) + + +def function_label(project, address: int | None) -> str: + if address is None: + return "unknown" + func = project.kb.functions.get_by_addr(address) + if func is None: + return hex_addr(address) + return f"{func.name} @ {hex_addr(func.addr)}" + + +def normalize_register_name(project, reg_name: str) -> str: + if reg_name in project.arch.registers: + return reg_name + lowered = reg_name.lower() + if lowered in project.arch.registers: + return lowered + uppered = reg_name.upper() + if uppered in project.arch.registers: + return uppered + return reg_name + + +def make_block(project, address: int, block_size: int = 0, num_inst: int = 0): + kwargs = {} + if block_size > 0: + kwargs["size"] = block_size + if num_inst > 0: + kwargs["num_inst"] = num_inst + return project.factory.block(address, **kwargs) + + +def setup_symbolic_execution(args: argparse.Namespace, target_value: str): import claripy - target_addr = parse_address(args.symbolic_find) + target_addr = parse_address(target_value) avoid = [parse_address(value) for value in args.avoid_address.split(",") if value.strip()] project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) @@ -111,7 +151,7 @@ def run_symbolic_find(args: argparse.Namespace) -> int: symbolic_argv = [] for index, length in enumerate(parse_csv_ints(args.argv_bytes), start=1): sym_arg = claripy.BVS(f"argv{index}", length * 8) - symbolic_argv.append((index, sym_arg)) + symbolic_argv.append((index, length, sym_arg)) argv.append(sym_arg) symbolic_stdin = None @@ -145,33 +185,140 @@ def run_symbolic_find(args: argparse.Namespace) -> int: mem_addr = parse_address(str(addr)) if isinstance(value, str): concrete = int(value, 0) - byte_len = max(1, (concrete.bit_length() + 7) // 8) else: concrete = int(value) - byte_len = max(1, (concrete.bit_length() + 7) // 8) + byte_len = max(1, (concrete.bit_length() + 7) // 8) state.memory.store(mem_addr, concrete, size=byte_len) symbolic_registers = {} registers = parse_json_map(args.registers_json, "registers_json") for reg_name, value in registers.items(): + reg_name = normalize_register_name(project, reg_name) if isinstance(value, str) and value.startswith("sv"): byte_len = int(value[2:], 0) sym_reg = claripy.BVS(f"reg_{reg_name}", byte_len * 8) - symbolic_registers[reg_name] = sym_reg + symbolic_registers[reg_name] = (byte_len, sym_reg) setattr(state.regs, reg_name, sym_reg) else: setattr(state.regs, reg_name, int(str(value), 0)) + symbols = { + "stdin": (args.stdin_bytes, symbolic_stdin), + "argv": symbolic_argv, + "memory": symbolic_memory, + "registers": symbolic_registers, + } + return project, state, symbols, target_addr, avoid + + +def run_explorer(project, state, target_addr: int, avoid: list[int], max_steps: int): + import angr + simgr = project.factory.simulation_manager(state) explorer_kwargs = {"find": target_addr} if avoid: explorer_kwargs["avoid"] = avoid simgr.use_technique(angr.exploration_techniques.Explorer(**explorer_kwargs)) - if args.max_steps > 0: - simgr.run(n=args.max_steps) + if max_steps > 0: + simgr.run(n=max_steps) else: simgr.run() + return simgr + + +def describe_symbolic_solution(found, symbols) -> list[str]: + lines = [] + stdin_len, symbolic_stdin = symbols["stdin"] + if symbolic_stdin is not None: + lines.append(f"stdin = {found.solver.eval(symbolic_stdin, cast_to=bytes)!r}") + for index, _length, sym_arg in symbols["argv"]: + lines.append(f"argv[{index}] = {found.solver.eval(sym_arg, cast_to=bytes)!r}") + for mem_addr, (mem_len, sym_mem) in symbols["memory"].items(): + lines.append(f"mem[{hex_addr(mem_addr)}:{mem_len}] = {found.solver.eval(sym_mem, cast_to=bytes)!r}") + for reg_name, (_byte_len, sym_reg) in symbols["registers"].items(): + lines.append(f"reg[{reg_name}] = {found.solver.eval(sym_reg):#x}") + return lines + + +def get_symbolic_ast(state, symbols, item: dict): + target_type = item.get("type") + if target_type == "reg": + reg_name = normalize_register_name(state.project, item["name"]) + return getattr(state.regs, reg_name) + if target_type == "mem": + return state.memory.load(parse_address(str(item["address"])), int(str(item["length"]), 0)) + if target_type == "stdin": + _length, symbolic_stdin = symbols["stdin"] + if symbolic_stdin is None: + raise ValueError("constraint references stdin, but stdin is not symbolic") + return symbolic_stdin + if target_type == "argv": + index = int(str(item["index"]), 0) + for arg_index, _length, sym_arg in symbols["argv"]: + if arg_index == index: + return sym_arg + raise ValueError(f"constraint references argv[{index}], but it is not symbolic") + raise ValueError(f"unsupported constraint type: {target_type!r}") + + +def concrete_bvv(state, item: dict, bits: int): + import claripy + + if "value_hex" in item: + raw = bytes.fromhex(str(item["value_hex"]).removeprefix("0x")) + return claripy.BVV(raw) + if "value_bytes" in item: + return claripy.BVV(str(item["value_bytes"]).encode()) + if "value" not in item: + raise ValueError("constraint is missing value, value_hex, or value_bytes") + return claripy.BVV(int(str(item["value"]), 0), bits) + + +def constraint_expr(state, symbols, item: dict): + ast = get_symbolic_ast(state, symbols, item) + value = concrete_bvv(state, item, ast.size()) + op = item.get("op", "==") + if op in {"==", "eq"}: + return ast == value + if op in {"!=", "ne"}: + return ast != value + if op in {"<", "ult"}: + return ast < value + if op in {"<=", "ule"}: + return ast <= value + if op in {">", "ugt"}: + return ast > value + if op in {">=", "uge"}: + return ast >= value + if op == "slt": + return ast.SLT(value) + if op == "sle": + return ast.SLE(value) + if op == "sgt": + return ast.SGT(value) + if op == "sge": + return ast.SGE(value) + raise ValueError(f"unsupported constraint op: {op!r}") + + +def run_check(args: argparse.Namespace) -> int: + import angr + + print(f"python: {sys.executable}") + print(f"angr: {angr.__version__}") + if args.binary: + project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + print(f"binary: {args.binary}") + print(f"arch: {project.arch}") + print(f"min_addr: {project.loader.main_object.min_addr:#x}") + print(f"max_addr: {project.loader.main_object.max_addr:#x}") + return 0 + + +def run_symbolic_find(args: argparse.Namespace) -> int: + project, state, symbols, target_addr, avoid = setup_symbolic_execution(args, args.symbolic_find) + simgr = run_explorer(project, state, target_addr, avoid, args.max_steps) print(f"binary: {args.binary}") print(f"arch: {project.arch}") @@ -195,15 +342,214 @@ def run_symbolic_find(args: argparse.Namespace) -> int: for addr in found.history.bbl_addrs.hardcopy: print(f" {addr:#x}") - if symbolic_stdin is not None: - print(f"stdin = {found.solver.eval(symbolic_stdin, cast_to=bytes)!r}") - for index, sym_arg in symbolic_argv: - print(f"argv[{index}] = {found.solver.eval(sym_arg, cast_to=bytes)!r}") - for mem_addr, (mem_len, sym_mem) in symbolic_memory.items(): - print(f"mem[{mem_addr:#x}:{mem_len}] = {found.solver.eval(sym_mem, cast_to=bytes)!r}") - for reg_name, sym_reg in symbolic_registers.items(): - print(f"reg[{reg_name}] = {found.solver.eval(sym_reg):#x}") + for line in describe_symbolic_solution(found, symbols): + print(line) + + return 0 + + +def run_solve_at(args: argparse.Namespace) -> int: + project, state, symbols, target_addr, avoid = setup_symbolic_execution(args, args.solve_at) + simgr = run_explorer(project, state, target_addr, avoid, args.max_steps) + + print(f"binary: {args.binary}") + print(f"arch: {project.arch}") + print(f"target: {target_addr:#x}") + print(f"max_steps: {args.max_steps}") + + if not simgr.found: + print("found: false") + print(f"active_states: {len(simgr.active)}") + print(f"deadended_states: {len(simgr.deadended)}") + print(f"avoid_states: {len(simgr.avoid)}") + return 2 + + found = simgr.found[0] + print("found: true") + + if not args.constraints_json: + parsed_constraints = [] + else: + decoded_constraints = json.loads(args.constraints_json) + if isinstance(decoded_constraints, dict): + parsed_constraints = decoded_constraints.get("constraints", []) + else: + parsed_constraints = decoded_constraints + if not isinstance(parsed_constraints, list): + raise ValueError("constraints_json constraints must be a list") + + for item in parsed_constraints: + if not isinstance(item, dict): + raise ValueError("each constraint must be a JSON object") + found.add_constraints(constraint_expr(found, symbols, item)) + + satisfiable = found.solver.satisfiable() + print(f"satisfiable: {str(satisfiable).lower()}") + if not satisfiable: + return 3 + + for line in describe_symbolic_solution(found, symbols): + print(line) + + eval_registers = parse_csv_strings(args.eval_registers) + for reg_name in eval_registers: + reg_name = normalize_register_name(project, reg_name) + print(f"eval_reg[{reg_name}] = {found.solver.eval(getattr(found.regs, reg_name)):#x}") + + for addr, length in parse_json_map(args.eval_memory_json, "eval_memory_json").items(): + mem_addr = parse_address(str(addr)) + mem_len = int(str(length), 0) + value = found.solver.eval(found.memory.load(mem_addr, mem_len), cast_to=bytes) + print(f"eval_mem[{hex_addr(mem_addr)}:{mem_len}] = {value!r}") + + if args.eval_stdin_bytes > 0: + stdin_len, symbolic_stdin = symbols["stdin"] + if symbolic_stdin is None: + print("eval_stdin = ") + else: + byte_len = min(stdin_len, args.eval_stdin_bytes) + print(f"eval_stdin[{byte_len}] = {found.solver.eval(symbolic_stdin, cast_to=bytes)[:byte_len]!r}") + + return 0 + + +def run_reachability(args: argparse.Namespace) -> int: + source = parse_address(args.reachability_from) + target = parse_address(args.reachability_to) + project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + cfg = build_cfg(project, args, function_starts=[source]) + source_node = get_cfg_node(cfg, source) + target_node = get_cfg_node(cfg, target) + + print(f"binary: {args.binary}") + print(f"arch: {project.arch}") + print(f"source: {source:#x}") + print(f"target: {target:#x}") + print(f"cfg_nodes: {cfg.graph.number_of_nodes()}") + print(f"cfg_edges: {cfg.graph.number_of_edges()}") + + if source_node is None: + print("reachable: false") + print("reason: source node not found") + return 2 + if target_node is None: + print("reachable: false") + print("reason: target node not found") + return 2 + + queue = [source_node] + predecessor = {source_node: None} + while queue: + node = queue.pop(0) + if node == target_node: + break + for successor in cfg.graph.successors(node): + if successor not in predecessor: + predecessor[successor] = node + queue.append(successor) + + if target_node not in predecessor: + print("reachable: false") + return 0 + + path = [] + node = target_node + while node is not None: + path.append(node) + node = predecessor[node] + path.reverse() + + print("reachable: true") + print(f"path_length: {len(path)}") + if args.include_path: + print("path:") + for node in path[: args.summary_limit]: + print(f" {node.addr:#x}") + if len(path) > args.summary_limit: + print(f" ... {len(path) - args.summary_limit} more nodes") + return 0 + + +def run_cfg_summary(args: argparse.Namespace) -> int: + project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + function_addr = parse_address(args.function_address) if args.function_address else None + cfg = build_cfg(project, args, function_starts=[function_addr] if function_addr is not None else None) + + print(f"binary: {args.binary}") + print(f"arch: {project.arch}") + print(f"cfg_nodes: {cfg.graph.number_of_nodes()}") + print(f"cfg_edges: {cfg.graph.number_of_edges()}") + print(f"functions: {len(project.kb.functions)}") + + if function_addr is not None: + func = project.kb.functions.get_by_addr(function_addr) + if func is None: + print(f"function: not found @ {function_addr:#x}") + return 2 + print(f"function: {func.name} @ {func.addr:#x}") + block_addrs = sorted(func.block_addrs_set) + print(f"blocks: {len(block_addrs)}") + print("block_addresses:") + for addr in block_addrs[: args.summary_limit]: + print(f" {addr:#x}") + if len(block_addrs) > args.summary_limit: + print(f" ... {len(block_addrs) - args.summary_limit} more blocks") + call_sites = list(func.get_call_sites()) + print(f"call_sites: {len(call_sites)}") + for site in call_sites[: args.summary_limit]: + target = func.get_call_target(site) + print(f" {site:#x} -> {function_label(project, target)}") + return 0 + + print("functions_sample:") + for func in list(project.kb.functions.values())[: args.summary_limit]: + print(f" {func.addr:#x} {func.name} blocks={len(func.block_addrs_set)}") + return 0 + + +def run_callgraph_summary(args: argparse.Namespace) -> int: + project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + build_cfg(project, args) + callgraph = project.kb.functions.callgraph + print(f"binary: {args.binary}") + print(f"arch: {project.arch}") + print(f"functions: {callgraph.number_of_nodes()}") + print(f"calls: {callgraph.number_of_edges()}") + + edges = list(callgraph.edges()) + print("edges:") + for src, dst in edges[: args.summary_limit]: + print(f" {function_label(project, src)} -> {function_label(project, dst)}") + if len(edges) > args.summary_limit: + print(f" ... {len(edges) - args.summary_limit} more edges") + return 0 + + +def run_lift_block(args: argparse.Namespace) -> int: + from angr import ailment + from angr.ailment.manager import Manager + + address = parse_address(args.lift_block) + project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + block = make_block(project, address, args.block_size, args.num_inst) + print(f"binary: {args.binary}") + print(f"arch: {project.arch}") + print(f"block: {address:#x}") + print(f"size: {block.size}") + print(f"instruction_addresses: {', '.join(hex_addr(addr) for addr in block.instruction_addrs)}") + + if args.lift_format in {"vex", "both"}: + print() + print("VEX:") + print(block.vex) + + if args.lift_format in {"ail", "both"}: + print() + print("AIL:") + manager = Manager(arch=project.arch) + ail_block = ailment.IRSBConverter.convert(block.vex, manager) + print(ail_block.dbg_repr().rstrip()) return 0 @@ -276,6 +622,23 @@ def main() -> int: parser.add_argument("--memory-json", default="", help="JSON object mapping address to concrete integer/hex value") parser.add_argument("--registers-json", default="", help='JSON object mapping register names to values or "svN" symbolic byte lengths') parser.add_argument("--max-steps", type=int, default=10000, help="Maximum symbolic execution steps, or 0 for unbounded") + parser.add_argument("--solve-at", help="Find an address and solve/evaluate requested constraints there") + parser.add_argument("--constraints-json", default="", help="JSON list of constraints, or object with a constraints list") + parser.add_argument("--eval-registers", default="", help="Comma-separated register names to evaluate after solve-at") + parser.add_argument("--eval-memory-json", default="", help="JSON object mapping address to byte length to evaluate after solve-at") + parser.add_argument("--eval-stdin-bytes", type=int, default=0, help="Evaluate this many symbolic stdin bytes after solve-at") + parser.add_argument("--reachability-from", help="Source address for static CFG reachability") + parser.add_argument("--reachability-to", help="Target address for static CFG reachability") + parser.add_argument("--cfg-summary", action="store_true", help="Summarize angr CFGFast output") + parser.add_argument("--callgraph-summary", action="store_true", help="Summarize angr's recovered callgraph") + parser.add_argument("--function-address", default="", help="Optional function address for CFG summaries") + parser.add_argument("--complete-cfg", action="store_true", help="Force complete CFG scan instead of targeted/default scan") + parser.add_argument("--summary-limit", type=int, default=50, help="Maximum items to print in summaries") + parser.add_argument("--include-path", action="store_true", help="Include a static reachability path when found") + parser.add_argument("--lift-block", help="Lift a basic block at this address") + parser.add_argument("--lift-format", choices=["vex", "ail", "both"], default="both", help="IR format to print for --lift-block") + parser.add_argument("--block-size", type=int, default=0, help="Optional block size for lifting") + parser.add_argument("--num-inst", type=int, default=0, help="Optional instruction count for lifting") rust_group = parser.add_mutually_exclusive_group() rust_group.add_argument("--rust", dest="rust", action="store_true", default=True) rust_group.add_argument("--no-rust", dest="rust", action="store_false") @@ -295,6 +658,31 @@ def main() -> int: parser.error("--binary is required with --symbolic-find") return run_symbolic_find(args) + if args.solve_at: + if not args.binary: + parser.error("--binary is required with --solve-at") + return run_solve_at(args) + + if args.reachability_from or args.reachability_to: + if not args.binary or not args.reachability_from or not args.reachability_to: + parser.error("--binary, --reachability-from, and --reachability-to are required together") + return run_reachability(args) + + if args.cfg_summary: + if not args.binary: + parser.error("--binary is required with --cfg-summary") + return run_cfg_summary(args) + + if args.callgraph_summary: + if not args.binary: + parser.error("--binary is required with --callgraph-summary") + return run_callgraph_summary(args) + + if args.lift_block: + if not args.binary: + parser.error("--binary is required with --lift-block") + return run_lift_block(args) + if not args.binary or not args.address: parser.error("--binary and --address are required unless --check is used") diff --git a/bridge_mcp_ghidra.py b/bridge_mcp_ghidra.py index 41793a73..9c8d3081 100644 --- a/bridge_mcp_ghidra.py +++ b/bridge_mcp_ghidra.py @@ -202,6 +202,45 @@ def parse_optional_json(value: str, field_name: str): except json.JSONDecodeError as e: raise ValueError(f"{field_name} must be valid JSON: {e}") from e +def resolve_angr_defaults(binary_path: str = "", pcode_language: str = "") -> tuple[str, str]: + program_info = {} + if not binary_path or not pcode_language: + program_info = parse_key_value_lines(safe_get("program_info")) + if not binary_path: + binary_path = program_info.get("executable_path", "") + if not pcode_language: + pcode_language = infer_pcode_language(program_info.get("language_id")) + return binary_path, pcode_language + +def require_binary_path(binary_path: str) -> str: + if not binary_path: + return "No binary_path provided and Ghidra did not return an executable_path" + return "" + +def append_common_angr_args(args: list[str], pcode_language: str = "", base_address: str = "") -> None: + if pcode_language: + args.extend(["--pcode-language", pcode_language]) + if base_address: + args.extend(["--base-address", normalize_ghidra_address(base_address)]) + +def append_json_arg(args: list[str], option: str, value: str, field_name: str) -> str: + try: + parsed = parse_optional_json(value, field_name) + except ValueError as e: + return str(e) + if parsed is not None: + args.extend([option, json.dumps(parsed)]) + return "" + +def split_addresses(addresses: str, max_addresses: int) -> list[str]: + normalized = addresses.replace("\n", ",").replace(" ", ",") + result = [ + normalize_ghidra_address(address) + for address in normalized.split(",") + if address.strip() + ] + return result[:max(1, max_addresses)] + @mcp.tool() def list_methods(offset: int = 0, limit: int = 100) -> list: """ @@ -479,6 +518,248 @@ def angr_symbolic_find( return run_angr_helper(args, max(1, timeout)) +@mcp.tool() +def angr_solve_constraints_at( + address: str, + binary_path: str = "", + start_address: str = "", + avoid_addresses: str = "", + pcode_language: str = "", + base_address: str = "", + stdin_bytes: int = 0, + argv_bytes: str = "", + symbolic_memory_json: str = "", + memory_json: str = "", + registers_json: str = "", + constraints_json: str = "", + eval_registers: str = "", + eval_memory_json: str = "", + eval_stdin_bytes: int = 0, + timeout: int = 120, + max_steps: int = 10000, +) -> str: + """ + Find an execution path to an address, add constraints, and solve values. + + constraints_json accepts a JSON list (or {"constraints": [...]}) of objects + like {"type":"reg","name":"r1","op":"==","value":"0x10"} or + {"type":"mem","address":"0x2000","length":4,"op":"!=","value_hex":"00000000"}. + Supported types are reg, mem, stdin, and argv. + """ + binary_path, pcode_language = resolve_angr_defaults(binary_path, pcode_language) + missing = require_binary_path(binary_path) + if missing: + return missing + + args = [ + "--binary", binary_path, + "--solve-at", normalize_ghidra_address(address), + "--max-steps", str(max(0, max_steps)), + ] + if start_address: + args.extend(["--start-address", normalize_ghidra_address(start_address)]) + if avoid_addresses: + normalized_avoid = ",".join( + normalize_ghidra_address(address_part) + for address_part in avoid_addresses.split(",") + ) + args.extend(["--avoid-address", normalized_avoid]) + append_common_angr_args(args, pcode_language, base_address) + if stdin_bytes > 0: + args.extend(["--stdin-bytes", str(stdin_bytes)]) + if argv_bytes: + args.extend(["--argv-bytes", argv_bytes]) + for option, value, field_name in [ + ("--symbolic-memory-json", symbolic_memory_json, "symbolic_memory_json"), + ("--memory-json", memory_json, "memory_json"), + ("--registers-json", registers_json, "registers_json"), + ("--constraints-json", constraints_json, "constraints_json"), + ("--eval-memory-json", eval_memory_json, "eval_memory_json"), + ]: + error = append_json_arg(args, option, value, field_name) + if error: + return error + if eval_registers: + args.extend(["--eval-registers", eval_registers]) + if eval_stdin_bytes > 0: + args.extend(["--eval-stdin-bytes", str(eval_stdin_bytes)]) + + return run_angr_helper(args, max(1, timeout)) + +@mcp.tool() +def angr_reachability( + source_address: str, + target_address: str, + binary_path: str = "", + pcode_language: str = "", + base_address: str = "", + complete_cfg: bool = False, + include_path: bool = True, + summary_limit: int = 50, + timeout: int = 120, +) -> str: + """ + Use angr CFGFast to check static reachability from one address to another. + """ + binary_path, pcode_language = resolve_angr_defaults(binary_path, pcode_language) + missing = require_binary_path(binary_path) + if missing: + return missing + + args = [ + "--binary", binary_path, + "--reachability-from", normalize_ghidra_address(source_address), + "--reachability-to", normalize_ghidra_address(target_address), + "--summary-limit", str(max(1, summary_limit)), + ] + append_common_angr_args(args, pcode_language, base_address) + if complete_cfg: + args.append("--complete-cfg") + if include_path: + args.append("--include-path") + return run_angr_helper(args, max(1, timeout)) + +@mcp.tool() +def angr_cfg_summary( + binary_path: str = "", + function_address: str = "", + pcode_language: str = "", + base_address: str = "", + complete_cfg: bool = False, + summary_limit: int = 50, + timeout: int = 120, +) -> str: + """ + Summarize angr CFGFast output for the whole binary or a single function. + """ + binary_path, pcode_language = resolve_angr_defaults(binary_path, pcode_language) + missing = require_binary_path(binary_path) + if missing: + return missing + + args = [ + "--binary", binary_path, + "--cfg-summary", + "--summary-limit", str(max(1, summary_limit)), + ] + append_common_angr_args(args, pcode_language, base_address) + if function_address: + args.extend(["--function-address", normalize_ghidra_address(function_address)]) + if complete_cfg: + args.append("--complete-cfg") + return run_angr_helper(args, max(1, timeout)) + +@mcp.tool() +def angr_callgraph_summary( + binary_path: str = "", + pcode_language: str = "", + base_address: str = "", + complete_cfg: bool = False, + summary_limit: int = 100, + timeout: int = 180, +) -> str: + """ + Summarize angr's recovered callgraph edges. + """ + binary_path, pcode_language = resolve_angr_defaults(binary_path, pcode_language) + missing = require_binary_path(binary_path) + if missing: + return missing + + args = [ + "--binary", binary_path, + "--callgraph-summary", + "--summary-limit", str(max(1, summary_limit)), + ] + append_common_angr_args(args, pcode_language, base_address) + if complete_cfg: + args.append("--complete-cfg") + return run_angr_helper(args, max(1, timeout)) + +@mcp.tool() +def angr_lift_block( + address: str, + binary_path: str = "", + pcode_language: str = "", + base_address: str = "", + lift_format: str = "both", + block_size: int = 0, + num_inst: int = 0, + timeout: int = 60, +) -> str: + """ + Lift a basic block to VEX, AIL, or both. + """ + binary_path, pcode_language = resolve_angr_defaults(binary_path, pcode_language) + missing = require_binary_path(binary_path) + if missing: + return missing + if lift_format not in {"vex", "ail", "both"}: + return "lift_format must be one of: vex, ail, both" + + args = [ + "--binary", binary_path, + "--lift-block", normalize_ghidra_address(address), + "--lift-format", lift_format, + ] + append_common_angr_args(args, pcode_language, base_address) + if block_size > 0: + args.extend(["--block-size", str(block_size)]) + if num_inst > 0: + args.extend(["--num-inst", str(num_inst)]) + return run_angr_helper(args, max(1, timeout)) + +@mcp.tool() +def angr_compare_decompilers( + addresses: str, + binary_path: str = "", + pcode_language: str = "", + rust: bool = True, + run_rust_setup: bool = False, + timeout_per_function: int = 120, + max_functions: int = 10, +) -> str: + """ + Batch-compare Ghidra decompiler output with angr/Oxidizer output. + + addresses accepts comma, space, or newline-separated function entry + addresses. Results are returned in side-by-side text sections. + """ + binary_path, pcode_language = resolve_angr_defaults(binary_path, pcode_language) + missing = require_binary_path(binary_path) + if missing: + return missing + + selected_addresses = split_addresses(addresses, max_functions) + if not selected_addresses: + return "No addresses provided" + + sections = [] + per_function_timeout = max(1, min(timeout_per_function, TIMEOUT_DECOMPILE_MAX)) + for address in selected_addresses: + ghidra_output = "\n".join(safe_get( + "decompile_function", + {"address": address, "timeout": per_function_timeout}, + timeout=float(per_function_timeout), + )) + angr_args = ["--binary", binary_path, "--address", address] + if rust: + angr_args.append("--rust") + if not run_rust_setup: + angr_args.append("--skip-rust-setup") + else: + angr_args.append("--no-rust") + if pcode_language: + angr_args.extend(["--pcode-language", pcode_language]) + oxidizer_output = run_angr_helper(angr_args, per_function_timeout) + sections.append( + f"## {address}\n\n" + f"### Ghidra\n{ghidra_output}\n\n" + f"### angr/Oxidizer\n{oxidizer_output}" + ) + + return "\n\n".join(sections) + @mcp.tool() def angryghidra_check_setup() -> str: """ diff --git a/tests/test_bridge.py b/tests/test_bridge.py index 4edc8100..f8855cd5 100644 --- a/tests/test_bridge.py +++ b/tests/test_bridge.py @@ -309,6 +309,186 @@ def fake_run(args, timeout): "--registers-json", '{"r1": "sv8", "r2": "0x10"}', ] + def test_angr_solve_constraints_at_builds_rich_solver_args( + self, bridge_module, httpx_mock, monkeypatch): + httpx_mock.add_response( + url=_url("program_info"), + text="executable_path: /tmp/eternal.so\n" + "language_id: eBPF:LE:64:default") + calls = {} + + def fake_run(args, timeout): + calls["args"] = args + calls["timeout"] = timeout + return "satisfiable: true" + + monkeypatch.setattr(bridge_module, "run_angr_helper", fake_run) + out = bridge_module.angr_solve_constraints_at( + address="ram:00000180", + start_address="ram:00000120", + constraints_json='[{"type":"reg","name":"r1","op":"==","value":"0x10"}]', + eval_registers="r0,r1", + eval_memory_json='{"0x3000": 8}', + timeout=88, + max_steps=44) + + assert out == "satisfiable: true" + assert calls["timeout"] == 88 + assert calls["args"] == [ + "--binary", "/tmp/eternal.so", + "--solve-at", "0x180", + "--max-steps", "44", + "--start-address", "0x120", + "--pcode-language", "eBPF:LE:64:default", + "--constraints-json", '[{"type": "reg", "name": "r1", "op": "==", "value": "0x10"}]', + "--eval-memory-json", '{"0x3000": 8}', + "--eval-registers", "r0,r1", + ] + + def test_angr_reachability_builds_cfg_args( + self, bridge_module, httpx_mock, monkeypatch): + httpx_mock.add_response( + url=_url("program_info"), + text="executable_path: /tmp/eternal.so\n" + "language_id: eBPF:LE:64:default") + calls = {} + monkeypatch.setattr( + bridge_module, + "run_angr_helper", + lambda args, timeout: calls.setdefault("data", (args, timeout)) or "reachable: true") + + bridge_module.angr_reachability( + "ram:00000120", + "ram:00000180", + complete_cfg=True, + include_path=False, + summary_limit=7, + timeout=99) + + args, timeout = calls["data"] + assert timeout == 99 + assert args == [ + "--binary", "/tmp/eternal.so", + "--reachability-from", "0x120", + "--reachability-to", "0x180", + "--summary-limit", "7", + "--pcode-language", "eBPF:LE:64:default", + "--complete-cfg", + ] + + def test_angr_cfg_summary_builds_function_args( + self, bridge_module, httpx_mock, monkeypatch): + httpx_mock.add_response( + url=_url("program_info"), + text="executable_path: /tmp/eternal.so\n" + "language_id: eBPF:LE:64:default") + calls = {} + monkeypatch.setattr( + bridge_module, + "run_angr_helper", + lambda args, timeout: calls.setdefault("data", (args, timeout)) or "cfg") + + bridge_module.angr_cfg_summary( + function_address="ram:00000120", + summary_limit=5, + timeout=66) + + args, timeout = calls["data"] + assert timeout == 66 + assert args == [ + "--binary", "/tmp/eternal.so", + "--cfg-summary", + "--summary-limit", "5", + "--pcode-language", "eBPF:LE:64:default", + "--function-address", "0x120", + ] + + def test_angr_callgraph_summary_builds_args( + self, bridge_module, httpx_mock, monkeypatch): + httpx_mock.add_response( + url=_url("program_info"), + text="executable_path: /tmp/eternal.so\n" + "language_id: eBPF:LE:64:default") + calls = {} + monkeypatch.setattr( + bridge_module, + "run_angr_helper", + lambda args, timeout: calls.setdefault("data", (args, timeout)) or "callgraph") + + bridge_module.angr_callgraph_summary(summary_limit=3, timeout=77) + + args, timeout = calls["data"] + assert timeout == 77 + assert args == [ + "--binary", "/tmp/eternal.so", + "--callgraph-summary", + "--summary-limit", "3", + "--pcode-language", "eBPF:LE:64:default", + ] + + def test_angr_lift_block_builds_args( + self, bridge_module, httpx_mock, monkeypatch): + httpx_mock.add_response( + url=_url("program_info"), + text="executable_path: /tmp/eternal.so\n" + "language_id: eBPF:LE:64:default") + calls = {} + monkeypatch.setattr( + bridge_module, + "run_angr_helper", + lambda args, timeout: calls.setdefault("data", (args, timeout)) or "AIL") + + bridge_module.angr_lift_block( + "ram:00000120", + lift_format="ail", + num_inst=4, + timeout=33) + + args, timeout = calls["data"] + assert timeout == 33 + assert args == [ + "--binary", "/tmp/eternal.so", + "--lift-block", "0x120", + "--lift-format", "ail", + "--pcode-language", "eBPF:LE:64:default", + "--num-inst", "4", + ] + + def test_angr_compare_decompilers_batches_ghidra_and_oxidizer( + self, bridge_module, httpx_mock, monkeypatch): + httpx_mock.add_response( + url=_url("program_info"), + text="executable_path: /tmp/eternal.so\n" + "language_id: eBPF:LE:64:default") + httpx_mock.add_response( + url=_url("decompile_function?address=0x120&timeout=12"), + text="ghidra one") + httpx_mock.add_response( + url=_url("decompile_function?address=0x180&timeout=12"), + text="ghidra two") + calls = [] + + def fake_run(args, timeout): + calls.append((args, timeout)) + return "oxidizer" + + monkeypatch.setattr(bridge_module, "run_angr_helper", fake_run) + out = bridge_module.angr_compare_decompilers( + "ram:00000120, ram:00000180", + timeout_per_function=12, + max_functions=2) + + assert "ghidra one" in out + assert "ghidra two" in out + assert len(calls) == 2 + assert calls[0] == ([ + "--binary", "/tmp/eternal.so", + "--address", "0x120", + "--rust", + "--skip-rust-setup", + "--pcode-language", "eBPF:LE:64:default", + ], 12) + def test_angryghidra_check_setup_missing_is_clear( self, bridge_module, monkeypatch): monkeypatch.setattr(bridge_module, "find_angryghidra_script", lambda: "") From c7258cf479ba9ff86e86e222225b15d2bde30ee6 Mon Sep 17 00:00:00 2001 From: rustopian <96253492+rustopian@users.noreply.github.com> Date: Sat, 23 May 2026 14:13:35 +0100 Subject: [PATCH 2/6] Prefer AngryGhidra for symbolic paths --- CHANGELOG.md | 5 + README.md | 11 +- angr_decompile.py | 74 ++++++++-- bridge_mcp_ghidra.py | 328 +++++++++++++++++++++++++++++++++++++------ tests/test_bridge.py | 136 +++++++++++++++--- 5 files changed, 477 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d296bcb..ff5721ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,11 @@ The "Unreleased" section accumulates changes since the upstream `v1-4` release - **Expanded core angr MCP capabilities**: added `angr_reachability`, `angr_cfg_summary`, `angr_callgraph_summary`, `angr_lift_block`, `angr_solve_constraints_at`, and `angr_compare_decompilers`. +- **AngryGhidra-first symbolic path search**: `angr_symbolic_find` now supports + `engine="auto"|"angryghidra"|"core"` and uses AngryGhidra when installed and + compatible with the request, while preserving the core helper fallback. +- **Writeable angr annotations**: added `angr_annotate_symbolic_path` to run a + symbolic path search and write the recovered trace into Ghidra comments. - **Richer symbolic solving**: `angr_solve_constraints_at` reaches a target address, applies JSON-described register/memory/stdin/argv constraints, and evaluates requested registers, memory, stdin, and symbolic inputs. diff --git a/README.md b/README.md index 12225d9c..8e5f1438 100644 --- a/README.md +++ b/README.md @@ -49,13 +49,20 @@ tools. - For Solana/eBPF ELFs, pass `pcode_language="eBPF:LE:64:default"` or let the bridge infer it from Ghidra's program language id. The helper patches CLE at runtime for Solana's e_machine 263 and uses angr's p-code engine. -- Core angr tools do not require AngryGhidra: - `angr_symbolic_find` searches for a path to a target address; +- `angr_symbolic_find` defaults to `engine="auto"`: it uses AngryGhidra when + the script is installed and the request fits AngryGhidra's native symbolic + executor, then falls back to the core helper when needed. Use + `engine="angryghidra"` to require AngryGhidra or `engine="core"` to force the + direct helper. +- Additional core angr tools do not require AngryGhidra: `angr_solve_constraints_at` adds JSON-described constraints at the found state and evaluates requested values; `angr_reachability` checks static CFG reachability; `angr_cfg_summary` and `angr_callgraph_summary` summarize recovered graph structure; `angr_lift_block` lifts a block to VEX/AIL; and `angr_compare_decompilers` batches Ghidra-vs-Oxidizer decompiler output. +- `angr_annotate_symbolic_path` is an explicit write endpoint: it runs symbolic + path search and writes the recovered trace as Ghidra disassembly and/or + decompiler comments. - AngryGhidra support is optional. `angryghidra_*` tools look for `ANGRYGHIDRA_SCRIPT`, `ANGRYGHIDRA_HOME/angryghidra_script/angryghidra.py`, or a sibling `AngryGhidra/angryghidra_script/angryghidra.py`. If none is diff --git a/angr_decompile.py b/angr_decompile.py index f62e42a6..495c3935 100644 --- a/angr_decompile.py +++ b/angr_decompile.py @@ -5,6 +5,7 @@ import os import sys import traceback +from collections import deque from contextlib import redirect_stderr BRIDGE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -45,11 +46,17 @@ def extract_arch_with_pcode_fallback(reader): ELF.extract_arch = staticmethod(extract_arch_with_pcode_fallback) -def make_project(binary_path: str, pcode_language: str, rust: bool, base_address: str = ""): +def make_project( + binary_path: str, + pcode_language: str, + rust: bool, + base_address: str = "", + auto_load_libs: bool = False, +): import angr import archinfo - kwargs = {"auto_load_libs": False} + kwargs = {"auto_load_libs": auto_load_libs} main_opts = {} if pcode_language: patch_elf_pcode_loader(pcode_language) @@ -145,7 +152,13 @@ def setup_symbolic_execution(args: argparse.Namespace, target_value: str): target_addr = parse_address(target_value) avoid = [parse_address(value) for value in args.avoid_address.split(",") if value.strip()] - project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + project = make_project( + args.binary, + args.pcode_language, + rust=False, + base_address=args.base_address, + auto_load_libs=args.auto_load_libs, + ) argv = [args.binary] symbolic_argv = [] @@ -308,7 +321,13 @@ def run_check(args: argparse.Namespace) -> int: print(f"python: {sys.executable}") print(f"angr: {angr.__version__}") if args.binary: - project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + project = make_project( + args.binary, + args.pcode_language, + rust=False, + base_address=args.base_address, + auto_load_libs=args.auto_load_libs, + ) print(f"binary: {args.binary}") print(f"arch: {project.arch}") print(f"min_addr: {project.loader.main_object.min_addr:#x}") @@ -416,8 +435,14 @@ def run_solve_at(args: argparse.Namespace) -> int: def run_reachability(args: argparse.Namespace) -> int: source = parse_address(args.reachability_from) target = parse_address(args.reachability_to) - project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) - cfg = build_cfg(project, args, function_starts=[source]) + project = make_project( + args.binary, + args.pcode_language, + rust=False, + base_address=args.base_address, + auto_load_libs=args.auto_load_libs, + ) + cfg = build_cfg(project, args, function_starts=None if args.complete_cfg else [source]) source_node = get_cfg_node(cfg, source) target_node = get_cfg_node(cfg, target) @@ -437,10 +462,10 @@ def run_reachability(args: argparse.Namespace) -> int: print("reason: target node not found") return 2 - queue = [source_node] + queue = deque([source_node]) predecessor = {source_node: None} while queue: - node = queue.pop(0) + node = queue.popleft() if node == target_node: break for successor in cfg.graph.successors(node): @@ -471,7 +496,13 @@ def run_reachability(args: argparse.Namespace) -> int: def run_cfg_summary(args: argparse.Namespace) -> int: - project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + project = make_project( + args.binary, + args.pcode_language, + rust=False, + base_address=args.base_address, + auto_load_libs=args.auto_load_libs, + ) function_addr = parse_address(args.function_address) if args.function_address else None cfg = build_cfg(project, args, function_starts=[function_addr] if function_addr is not None else None) @@ -508,7 +539,13 @@ def run_cfg_summary(args: argparse.Namespace) -> int: def run_callgraph_summary(args: argparse.Namespace) -> int: - project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + project = make_project( + args.binary, + args.pcode_language, + rust=False, + base_address=args.base_address, + auto_load_libs=args.auto_load_libs, + ) build_cfg(project, args) callgraph = project.kb.functions.callgraph print(f"binary: {args.binary}") @@ -530,7 +567,13 @@ def run_lift_block(args: argparse.Namespace) -> int: from angr.ailment.manager import Manager address = parse_address(args.lift_block) - project = make_project(args.binary, args.pcode_language, rust=False, base_address=args.base_address) + project = make_project( + args.binary, + args.pcode_language, + rust=False, + base_address=args.base_address, + auto_load_libs=args.auto_load_libs, + ) block = make_block(project, address, args.block_size, args.num_inst) print(f"binary: {args.binary}") @@ -557,7 +600,13 @@ def run_decompile(args: argparse.Namespace) -> int: import angr # noqa: F401 target_addr = parse_address(args.address) - project = make_project(args.binary, args.pcode_language, args.rust, base_address=args.base_address) + project = make_project( + args.binary, + args.pcode_language, + args.rust, + base_address=args.base_address, + auto_load_libs=args.auto_load_libs, + ) cfg_kwargs = { "normalize": True, @@ -611,6 +660,7 @@ def main() -> int: parser.add_argument("--address", help="Function entry address") parser.add_argument("--pcode-language", default="", help="Optional pypcode language id") parser.add_argument("--base-address", default="", help="Optional image base for raw/blob-style loads") + parser.add_argument("--auto-load-libs", action="store_true", help="Ask angr to load imported libraries") parser.add_argument("--skip-rust-setup", action="store_true", help="Skip Rust recovery setup analyses") parser.add_argument("--check", action="store_true", help="Only verify angr imports and optional binary load") parser.add_argument("--symbolic-find", help="Find an execution path to this address") diff --git a/bridge_mcp_ghidra.py b/bridge_mcp_ghidra.py index 9c8d3081..b5970f7d 100644 --- a/bridge_mcp_ghidra.py +++ b/bridge_mcp_ghidra.py @@ -241,6 +241,162 @@ def split_addresses(addresses: str, max_addresses: int) -> list[str]: ] return result[:max(1, max_addresses)] +def normalize_angryghidra_map_values( + parsed: dict, + symbolic_prefix_ok: bool = False, + integers_as_hex: bool = False, +) -> dict: + normalized = {} + for key, value in parsed.items(): + if isinstance(value, int) and integers_as_hex: + normalized[str(key)] = hex(value) + elif symbolic_prefix_ok and isinstance(value, str) and value.startswith("sv"): + normalized[str(key)] = value + else: + normalized[str(key)] = str(value) + return normalized + +def make_angryghidra_arguments(argv_bytes: str) -> dict: + arguments = {} + for index, length in enumerate((part.strip() for part in argv_bytes.split(",")), start=1): + if length: + arguments[str(index)] = length + return arguments + +def run_angryghidra_options(options: dict, timeout: int) -> str: + script = find_angryghidra_script() + if not script: + return angryghidra_missing_message() + + python = os.environ.get("ANGRYGHIDRA_PYTHON") or os.environ.get("GHIDRA_MCP_ANGR_PYTHON") or default_angr_python() + options_path = "" + try: + with tempfile.NamedTemporaryFile("w", suffix="-angryghidra.json", delete=False) as options_file: + json.dump(options, options_file) + options_path = options_file.name + completed = subprocess.run( + [python, script, options_path], + capture_output=True, + text=True, + timeout=max(1, timeout), + check=False, + ) + except FileNotFoundError as e: + return f"Failed to start AngryGhidra: {e}" + except subprocess.TimeoutExpired: + return f"AngryGhidra timed out after {timeout} seconds" + finally: + if options_path: + try: + os.unlink(options_path) + except OSError: + pass + + output = completed.stdout.strip() + errors = completed.stderr.strip() + if completed.returncode == 0: + return output if output else "(AngryGhidra returned no solution)" + + details = output + if errors: + details = f"{details}\n\nstderr:\n{errors}" if details else f"stderr:\n{errors}" + return f"AngryGhidra failed with exit code {completed.returncode}\n\n{details}".strip() + +def build_angryghidra_symbolic_options( + find_address: str, + binary_path: str, + start_address: str = "", + avoid_addresses: str = "", + base_address: str = "", + raw_binary_arch: str = "", + auto_load_libs: bool = False, + argv_bytes: str = "", + symbolic_memory_json: str = "", + memory_json: str = "", + registers_json: str = "", +) -> tuple[dict | None, str]: + if not binary_path: + return None, "No binary_path provided and Ghidra did not return an executable_path" + if not base_address: + program_info = parse_key_value_lines(safe_get("program_info")) + base_address = program_info.get("min_address", "0x0") + + try: + options = { + "binary_file": binary_path, + "base_address": normalize_ghidra_address(base_address), + "find_address": normalize_ghidra_address(find_address), + "auto_load_libs": auto_load_libs, + } + if start_address: + options["blank_state"] = normalize_ghidra_address(start_address) + if avoid_addresses: + options["avoid_address"] = ",".join( + normalize_ghidra_address(address) + for address in avoid_addresses.split(",") + if address.strip() + ) + if raw_binary_arch: + options["raw_binary_arch"] = raw_binary_arch + arguments = make_angryghidra_arguments(argv_bytes) + if arguments: + options["arguments"] = arguments + for key, value, field_name, symbolic_prefix_ok, integers_as_hex in [ + ("vectors", symbolic_memory_json, "symbolic_memory_json", False, False), + ("mem_store", memory_json, "memory_json", False, True), + ("regs_vals", registers_json, "registers_json", True, True), + ]: + parsed = parse_optional_json(value, field_name) + if parsed is not None: + if not isinstance(parsed, dict): + return None, f"{field_name} must be a JSON object" + options[key] = normalize_angryghidra_map_values( + parsed, + symbolic_prefix_ok, + integers_as_hex, + ) + except ValueError as e: + return None, str(e) + + return options, "" + +def angryghidra_symbolic_unsupported_reason( + stdin_bytes: int, + pcode_language: str = "", + raw_binary_arch: str = "", +) -> str: + unsupported = [] + if stdin_bytes > 0: + unsupported.append("symbolic stdin is not supported by AngryGhidra's native script") + if pcode_language and not raw_binary_arch: + unsupported.append("p-code language loading requires the core angr helper unless raw_binary_arch is provided") + return "; ".join(unsupported) + +def extract_trace_addresses(output: str) -> list[str]: + addresses = [] + in_core_path = False + for line in output.splitlines(): + stripped = line.strip() + if stripped == "path:": + in_core_path = True + continue + if stripped.startswith("t:"): + addresses.append(normalize_ghidra_address(stripped[2:])) + in_core_path = False + continue + if in_core_path and stripped.startswith("0x"): + addresses.append(normalize_ghidra_address(stripped.split()[0])) + continue + if in_core_path and stripped and not stripped.startswith("..."): + in_core_path = False + deduped = [] + seen = set() + for address in addresses: + if address not in seen: + deduped.append(address) + seen.add(address) + return deduped + @mcp.tool() def list_methods(offset: int = 0, limit: int = 100) -> list: """ @@ -443,19 +599,24 @@ def angr_symbolic_find( avoid_addresses: str = "", pcode_language: str = "", base_address: str = "", + raw_binary_arch: str = "", + auto_load_libs: bool = False, stdin_bytes: int = 0, argv_bytes: str = "", symbolic_memory_json: str = "", memory_json: str = "", registers_json: str = "", + engine: str = "auto", timeout: int = 120, max_steps: int = 10000, ) -> str: """ - Use core angr symbolic execution to find a path to an address. + Find a symbolic execution path to an address. - This is independent of AngryGhidra. It can make stdin/argv symbolic, seed - symbolic memory, seed concrete memory/register values, and avoid addresses. + engine="auto" prefers AngryGhidra when it is installed and the request fits + AngryGhidra's native script, then falls back to the core angr helper. Use + engine="angryghidra" to require AngryGhidra, or engine="core" to force the + bridge's direct angr helper. Args: find_address: Address to reach. @@ -466,6 +627,8 @@ def angr_symbolic_find( avoid_addresses: Optional comma-separated addresses to avoid. pcode_language: Optional pypcode language id. base_address: Optional loader base address. + raw_binary_arch: Optional AngryGhidra raw blob architecture. + auto_load_libs: Whether AngryGhidra/core angr should load shared libs. stdin_bytes: Symbolic stdin length in bytes. argv_bytes: Comma-separated symbolic argv byte lengths, e.g. "8,16". symbolic_memory_json: JSON object mapping address to symbolic byte @@ -473,11 +636,17 @@ def angr_symbolic_find( memory_json: JSON object mapping address to concrete integer/hex value. registers_json: JSON object mapping register names to values or "svN" for an N-byte symbolic register. + engine: "auto", "angryghidra", or "core". timeout: Maximum helper runtime in seconds. - max_steps: Maximum symbolic execution steps. Use 0 for unbounded. + max_steps: Maximum core-helper symbolic execution steps. AngryGhidra's + native script runs until it finds a path or the timeout hits. """ + requested_engine = engine.lower().strip() + if requested_engine not in {"auto", "angryghidra", "core"}: + return 'engine must be one of: "auto", "angryghidra", "core"' + program_info = {} - if not binary_path or not pcode_language: + if not binary_path or not pcode_language or not base_address: program_info = parse_key_value_lines(safe_get("program_info")) if not binary_path: binary_path = program_info.get("executable_path", "") @@ -485,6 +654,37 @@ def angr_symbolic_find( return "No binary_path provided and Ghidra did not return an executable_path" if not pcode_language: pcode_language = infer_pcode_language(program_info.get("language_id")) + if not base_address: + base_address = program_info.get("min_address", "") + + if requested_engine in {"auto", "angryghidra"}: + script = find_angryghidra_script() + unsupported = angryghidra_symbolic_unsupported_reason( + stdin_bytes, + pcode_language, + raw_binary_arch, + ) + if script and not unsupported: + options, error = build_angryghidra_symbolic_options( + find_address=find_address, + binary_path=binary_path, + start_address=start_address, + avoid_addresses=avoid_addresses, + base_address=base_address, + raw_binary_arch=raw_binary_arch, + auto_load_libs=auto_load_libs, + argv_bytes=argv_bytes, + symbolic_memory_json=symbolic_memory_json, + memory_json=memory_json, + registers_json=registers_json, + ) + if error: + return error + return "engine: AngryGhidra\n" + run_angryghidra_options(options, timeout) + if requested_engine == "angryghidra": + if not script: + return angryghidra_missing_message() + return f"AngryGhidra cannot run this request: {unsupported}" args = [ "--binary", binary_path, @@ -503,6 +703,8 @@ def angr_symbolic_find( args.extend(["--pcode-language", pcode_language]) if base_address: args.extend(["--base-address", normalize_ghidra_address(base_address)]) + if auto_load_libs: + args.append("--auto-load-libs") if stdin_bytes > 0: args.extend(["--stdin-bytes", str(stdin_bytes)]) if argv_bytes: @@ -516,7 +718,84 @@ def angr_symbolic_find( if parsed is not None: args.extend([option, json.dumps(parsed)]) - return run_angr_helper(args, max(1, timeout)) + return "engine: core angr\n" + run_angr_helper(args, max(1, timeout)) + +@mcp.tool() +def angr_annotate_symbolic_path( + find_address: str, + binary_path: str = "", + start_address: str = "", + avoid_addresses: str = "", + pcode_language: str = "", + base_address: str = "", + raw_binary_arch: str = "", + auto_load_libs: bool = False, + stdin_bytes: int = 0, + argv_bytes: str = "", + symbolic_memory_json: str = "", + memory_json: str = "", + registers_json: str = "", + engine: str = "auto", + comment_kind: str = "disasm", + comment_prefix: str = "angr symbolic path", + max_comments: int = 100, + timeout: int = 120, + max_steps: int = 10000, +) -> str: + """ + Run a symbolic path search and write path comments into the Ghidra program. + + This is an explicit write endpoint. comment_kind may be "disasm", + "decomp", or "both"; comments are applied only when a trace/path is found. + The underlying path search prefers AngryGhidra in engine="auto" when the + request fits AngryGhidra's native script. + """ + if comment_kind not in {"disasm", "decomp", "both"}: + return 'comment_kind must be one of: "disasm", "decomp", "both"' + + result = angr_symbolic_find( + find_address=find_address, + binary_path=binary_path, + start_address=start_address, + avoid_addresses=avoid_addresses, + pcode_language=pcode_language, + base_address=base_address, + raw_binary_arch=raw_binary_arch, + auto_load_libs=auto_load_libs, + stdin_bytes=stdin_bytes, + argv_bytes=argv_bytes, + symbolic_memory_json=symbolic_memory_json, + memory_json=memory_json, + registers_json=registers_json, + engine=engine, + timeout=timeout, + max_steps=max_steps, + ) + + trace_addresses = extract_trace_addresses(result)[: max(1, max_comments)] + if not trace_addresses: + return f"{result}\n\nNo trace addresses found; no comments were written." + + endpoints = [] + if comment_kind in {"disasm", "both"}: + endpoints.append("set_disassembly_comment") + if comment_kind in {"decomp", "both"}: + endpoints.append("set_decompiler_comment") + + writes = [] + total = len(trace_addresses) + normalized_target = normalize_ghidra_address(find_address) + for index, address in enumerate(trace_addresses, start=1): + comment = f"{comment_prefix}: step {index}/{total} toward {normalized_target}" + for endpoint in endpoints: + response = safe_post(endpoint, {"address": address, "comment": comment}) + writes.append(f"{endpoint} {address}: {response}") + + return ( + f"{result}\n\n" + f"Annotated {len(trace_addresses)} trace address(es) with {len(writes)} comment write(s).\n" + + "\n".join(writes) + ) @mcp.tool() def angr_solve_constraints_at( @@ -794,8 +1073,7 @@ def angryghidra_symbolic_execute( not installed, this returns a clear error and leaves all other bridge tools working normally. """ - script = find_angryghidra_script() - if not script: + if not find_angryghidra_script(): return angryghidra_missing_message() program_info = {} @@ -839,39 +1117,7 @@ def angryghidra_symbolic_execute( except ValueError as e: return str(e) - python = os.environ.get("ANGRYGHIDRA_PYTHON") or os.environ.get("GHIDRA_MCP_ANGR_PYTHON") or default_angr_python() - options_path = "" - try: - with tempfile.NamedTemporaryFile("w", suffix="-angryghidra.json", delete=False) as options_file: - json.dump(options, options_file) - options_path = options_file.name - completed = subprocess.run( - [python, script, options_path], - capture_output=True, - text=True, - timeout=max(1, timeout), - check=False, - ) - except FileNotFoundError as e: - return f"Failed to start AngryGhidra: {e}" - except subprocess.TimeoutExpired: - return f"AngryGhidra timed out after {timeout} seconds" - finally: - if options_path: - try: - os.unlink(options_path) - except OSError: - pass - - output = completed.stdout.strip() - errors = completed.stderr.strip() - if completed.returncode == 0: - return output if output else "(AngryGhidra returned no solution)" - - details = output - if errors: - details = f"{details}\n\nstderr:\n{errors}" if details else f"stderr:\n{errors}" - return f"AngryGhidra failed with exit code {completed.returncode}\n\n{details}".strip() + return run_angryghidra_options(options, timeout) @mcp.tool() def decompile_by_addr(address: str, timeout: int = 120) -> str: diff --git a/tests/test_bridge.py b/tests/test_bridge.py index f8855cd5..1c172c7c 100644 --- a/tests/test_bridge.py +++ b/tests/test_bridge.py @@ -271,7 +271,7 @@ def fake_run(args, timeout): "--pcode-language", "eBPF:LE:64:default", ] - def test_angr_symbolic_find_uses_core_angr_without_angryghidra( + def test_angr_symbolic_find_falls_back_to_core_when_angryghidra_cannot_represent_request( self, bridge_module, httpx_mock, monkeypatch): httpx_mock.add_response( url=_url("program_info"), @@ -295,7 +295,7 @@ def fake_run(args, timeout): timeout=55, max_steps=123) - assert out == "found: true" + assert out == "engine: core angr\nfound: true" assert calls["timeout"] == 55 assert calls["args"] == [ "--binary", "/tmp/eternal.so", @@ -309,6 +309,90 @@ def fake_run(args, timeout): "--registers-json", '{"r1": "sv8", "r2": "0x10"}', ] + def test_angr_symbolic_find_prefers_angryghidra_when_available( + self, bridge_module, httpx_mock, monkeypatch): + httpx_mock.add_response( + url=_url("program_info"), + text="executable_path: /tmp/a.out\n" + "language_id: x86:LE:64:default\n" + "min_address: 0x400000") + monkeypatch.setattr(bridge_module, "find_angryghidra_script", lambda: "/opt/angryghidra.py") + calls = {} + + def fake_run(options, timeout): + calls["options"] = options + calls["timeout"] = timeout + return "t:0x401000\nt:0x401020\nargv[1] = b'ok'" + + monkeypatch.setattr(bridge_module, "run_angryghidra_options", fake_run) + out = bridge_module.angr_symbolic_find( + find_address="0x401020", + start_address="0x401000", + argv_bytes="4", + symbolic_memory_json='{"0x404000": 8}', + memory_json='{"0x405000": "0x41"}', + registers_json='{"rax": "sv8"}', + timeout=44) + + assert out.startswith("engine: AngryGhidra\n") + assert calls["timeout"] == 44 + assert calls["options"] == { + "binary_file": "/tmp/a.out", + "base_address": "0x400000", + "find_address": "0x401020", + "auto_load_libs": False, + "blank_state": "0x401000", + "arguments": {"1": "4"}, + "vectors": {"0x404000": "8"}, + "mem_store": {"0x405000": "0x41"}, + "regs_vals": {"rax": "sv8"}, + } + + def test_angr_symbolic_find_forced_angryghidra_missing_is_clear( + self, bridge_module, httpx_mock, monkeypatch): + httpx_mock.add_response( + url=_url("program_info"), + text="executable_path: /tmp/a.out\n" + "language_id: x86:LE:64:default\n" + "min_address: 0x400000") + monkeypatch.setattr(bridge_module, "find_angryghidra_script", lambda: "") + + out = bridge_module.angr_symbolic_find( + find_address="0x401020", + engine="angryghidra") + + assert "AngryGhidra is not installed or configured" in out + + def test_angr_annotate_symbolic_path_writes_trace_comments( + self, bridge_module, httpx_mock, monkeypatch): + monkeypatch.setattr( + bridge_module, + "angr_symbolic_find", + lambda **_kwargs: "engine: AngryGhidra\nt:0x401000\nt:0x401020") + httpx_mock.add_response( + method="POST", + url=_url("set_disassembly_comment"), + match_content=( + b"address=0x401000&comment=angr+symbolic+path%3A+step+1%2F2+" + b"toward+0x401020" + ), + text="Comment set successfully") + httpx_mock.add_response( + method="POST", + url=_url("set_disassembly_comment"), + match_content=( + b"address=0x401020&comment=angr+symbolic+path%3A+step+2%2F2+" + b"toward+0x401020" + ), + text="Comment set successfully") + + out = bridge_module.angr_annotate_symbolic_path( + find_address="0x401020", + comment_kind="disasm") + + assert "Annotated 2 trace address(es)" in out + assert "set_disassembly_comment 0x401000: Comment set successfully" in out + def test_angr_solve_constraints_at_builds_rich_solver_args( self, bridge_module, httpx_mock, monkeypatch): httpx_mock.add_response( @@ -352,12 +436,13 @@ def test_angr_reachability_builds_cfg_args( text="executable_path: /tmp/eternal.so\n" "language_id: eBPF:LE:64:default") calls = {} - monkeypatch.setattr( - bridge_module, - "run_angr_helper", - lambda args, timeout: calls.setdefault("data", (args, timeout)) or "reachable: true") + def fake_run(args, timeout): + calls["data"] = (args, timeout) + return "reachable: true" - bridge_module.angr_reachability( + monkeypatch.setattr(bridge_module, "run_angr_helper", fake_run) + + out = bridge_module.angr_reachability( "ram:00000120", "ram:00000180", complete_cfg=True, @@ -365,6 +450,7 @@ def test_angr_reachability_builds_cfg_args( summary_limit=7, timeout=99) + assert out == "reachable: true" args, timeout = calls["data"] assert timeout == 99 assert args == [ @@ -383,16 +469,18 @@ def test_angr_cfg_summary_builds_function_args( text="executable_path: /tmp/eternal.so\n" "language_id: eBPF:LE:64:default") calls = {} - monkeypatch.setattr( - bridge_module, - "run_angr_helper", - lambda args, timeout: calls.setdefault("data", (args, timeout)) or "cfg") + def fake_run(args, timeout): + calls["data"] = (args, timeout) + return "cfg" - bridge_module.angr_cfg_summary( + monkeypatch.setattr(bridge_module, "run_angr_helper", fake_run) + + out = bridge_module.angr_cfg_summary( function_address="ram:00000120", summary_limit=5, timeout=66) + assert out == "cfg" args, timeout = calls["data"] assert timeout == 66 assert args == [ @@ -410,13 +498,15 @@ def test_angr_callgraph_summary_builds_args( text="executable_path: /tmp/eternal.so\n" "language_id: eBPF:LE:64:default") calls = {} - monkeypatch.setattr( - bridge_module, - "run_angr_helper", - lambda args, timeout: calls.setdefault("data", (args, timeout)) or "callgraph") + def fake_run(args, timeout): + calls["data"] = (args, timeout) + return "callgraph" + + monkeypatch.setattr(bridge_module, "run_angr_helper", fake_run) - bridge_module.angr_callgraph_summary(summary_limit=3, timeout=77) + out = bridge_module.angr_callgraph_summary(summary_limit=3, timeout=77) + assert out == "callgraph" args, timeout = calls["data"] assert timeout == 77 assert args == [ @@ -433,17 +523,19 @@ def test_angr_lift_block_builds_args( text="executable_path: /tmp/eternal.so\n" "language_id: eBPF:LE:64:default") calls = {} - monkeypatch.setattr( - bridge_module, - "run_angr_helper", - lambda args, timeout: calls.setdefault("data", (args, timeout)) or "AIL") + def fake_run(args, timeout): + calls["data"] = (args, timeout) + return "AIL" + + monkeypatch.setattr(bridge_module, "run_angr_helper", fake_run) - bridge_module.angr_lift_block( + out = bridge_module.angr_lift_block( "ram:00000120", lift_format="ail", num_inst=4, timeout=33) + assert out == "AIL" args, timeout = calls["data"] assert timeout == 33 assert args == [ From d90028e1439ff96ac254743b4d385d63bbb66075 Mon Sep 17 00:00:00 2001 From: rustopian <96253492+rustopian@users.noreply.github.com> Date: Sat, 23 May 2026 14:34:51 +0100 Subject: [PATCH 3/6] Harden angr MCP tools --- CHANGELOG.md | 5 +- README.md | 10 +- angr_decompile.py | 151 ++++++++++++--- bridge_mcp_ghidra.py | 434 ++++++++++++++++++++++++++++++++++++++++--- tests/test_bridge.py | 59 +++++- 5 files changed, 604 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff5721ca..fae01581 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,10 @@ The "Unreleased" section accumulates changes since the upstream `v1-4` release `engine="auto"|"angryghidra"|"core"` and uses AngryGhidra when installed and compatible with the request, while preserving the core helper fallback. - **Writeable angr annotations**: added `angr_annotate_symbolic_path` to run a - symbolic path search and write the recovered trace into Ghidra comments. + symbolic path search, preview recovered trace comments by default, and write + them only with explicit overwrite confirmation. +- **angr safety caps**: bounded helper output, symbolic input sizes, execution + steps, summary output, lift size, and decompiler comparison batches. - **Richer symbolic solving**: `angr_solve_constraints_at` reaches a target address, applies JSON-described register/memory/stdin/argv constraints, and evaluates requested registers, memory, stdin, and symbolic inputs. diff --git a/README.md b/README.md index 8e5f1438..194acdac 100644 --- a/README.md +++ b/README.md @@ -60,9 +60,13 @@ tools. reachability; `angr_cfg_summary` and `angr_callgraph_summary` summarize recovered graph structure; `angr_lift_block` lifts a block to VEX/AIL; and `angr_compare_decompilers` batches Ghidra-vs-Oxidizer decompiler output. -- `angr_annotate_symbolic_path` is an explicit write endpoint: it runs symbolic - path search and writes the recovered trace as Ghidra disassembly and/or - decompiler comments. +- `angr_annotate_symbolic_path` previews by default. To write the recovered + trace as Ghidra disassembly and/or decompiler comments, call it with + `apply=true` and `overwrite_existing=true`; the underlying Ghidra comment + endpoints replace existing comments. +- angr/AngryGhidra execution is bounded by conservative limits on helper output, + symbolic input sizes, symbolic steps, summary output, lift size, and batch + comparison size. - AngryGhidra support is optional. `angryghidra_*` tools look for `ANGRYGHIDRA_SCRIPT`, `ANGRYGHIDRA_HOME/angryghidra_script/angryghidra.py`, or a sibling `AngryGhidra/angryghidra_script/angryghidra.py`. If none is diff --git a/angr_decompile.py b/angr_decompile.py index 495c3935..41272774 100644 --- a/angr_decompile.py +++ b/angr_decompile.py @@ -7,12 +7,25 @@ import traceback from collections import deque from contextlib import redirect_stderr +from itertools import islice BRIDGE_DIR = os.path.dirname(os.path.abspath(__file__)) os.environ.setdefault("XDG_CACHE_HOME", os.path.join(BRIDGE_DIR, ".angr-cache")) +MAX_JSON_INPUT_CHARS = 100_000 +MAX_SYMBOLIC_BYTES = 4096 +MAX_TOTAL_SYMBOLIC_BYTES = 16_384 +MAX_SYMBOLIC_ARGS = 16 +MAX_SYMBOLIC_REGIONS = 64 +MAX_SYMBOLIC_REGISTER_BYTES = 64 +MAX_CONSTRAINTS = 128 +MAX_STEPS = 100_000 +MAX_SUMMARY_LIMIT = 500 +MAX_BLOCK_SIZE = 4096 +MAX_NUM_INST = 256 + def parse_address(value: str) -> int: if value.startswith("0x") and ":" in value: @@ -77,6 +90,8 @@ def make_project( def parse_json_map(value: str, field_name: str) -> dict: if not value: return {} + if len(value) > MAX_JSON_INPUT_CHARS: + raise ValueError(f"{field_name} exceeds {MAX_JSON_INPUT_CHARS} characters") try: parsed = json.loads(value) except json.JSONDecodeError as exc: @@ -86,10 +101,31 @@ def parse_json_map(value: str, field_name: str) -> dict: return parsed -def parse_csv_ints(value: str) -> list[int]: +def checked_int(value, field_name: str, minimum: int, maximum: int) -> int: + try: + parsed = int(str(value), 0) + except ValueError as exc: + raise ValueError(f"{field_name} must be an integer") from exc + if parsed < minimum or parsed > maximum: + raise ValueError(f"{field_name} must be between {minimum} and {maximum}") + return parsed + + +def parse_csv_ints( + value: str, + field_name: str, + max_items: int, + max_value: int, +) -> list[int]: if not value: return [] - return [int(part.strip(), 0) for part in value.split(",") if part.strip()] + parts = [part.strip() for part in value.split(",") if part.strip()] + if len(parts) > max_items: + raise ValueError(f"{field_name} may contain at most {max_items} entries") + return [ + checked_int(part, f"{field_name} entry", 1, max_value) + for part in parts + ] def parse_csv_strings(value: str) -> list[str]: @@ -141,8 +177,10 @@ def normalize_register_name(project, reg_name: str) -> str: def make_block(project, address: int, block_size: int = 0, num_inst: int = 0): kwargs = {} if block_size > 0: + block_size = checked_int(block_size, "block_size", 1, MAX_BLOCK_SIZE) kwargs["size"] = block_size if num_inst > 0: + num_inst = checked_int(num_inst, "num_inst", 1, MAX_NUM_INST) kwargs["num_inst"] = num_inst return project.factory.block(address, **kwargs) @@ -162,16 +200,26 @@ def setup_symbolic_execution(args: argparse.Namespace, target_value: str): argv = [args.binary] symbolic_argv = [] - for index, length in enumerate(parse_csv_ints(args.argv_bytes), start=1): + total_symbolic = 0 + argv_lengths = parse_csv_ints( + args.argv_bytes, + "argv_bytes", + MAX_SYMBOLIC_ARGS, + MAX_SYMBOLIC_BYTES, + ) + for index, length in enumerate(argv_lengths, start=1): sym_arg = claripy.BVS(f"argv{index}", length * 8) symbolic_argv.append((index, length, sym_arg)) argv.append(sym_arg) + total_symbolic += length symbolic_stdin = None stdin = None if args.stdin_bytes > 0: + args.stdin_bytes = checked_int(args.stdin_bytes, "stdin_bytes", 1, MAX_SYMBOLIC_BYTES) symbolic_stdin = claripy.BVS("stdin", args.stdin_bytes * 8) stdin = symbolic_stdin + total_symbolic += args.stdin_bytes if args.start_address: state_kwargs = {} @@ -187,34 +235,61 @@ def setup_symbolic_execution(args: argparse.Namespace, target_value: str): state = project.factory.entry_state(**state_kwargs) symbolic_memory = {} - for addr, length in parse_json_map(args.symbolic_memory_json, "symbolic_memory_json").items(): + symbolic_memory_map = parse_json_map(args.symbolic_memory_json, "symbolic_memory_json") + if len(symbolic_memory_map) > MAX_SYMBOLIC_REGIONS: + raise ValueError(f"symbolic_memory_json may contain at most {MAX_SYMBOLIC_REGIONS} entries") + for addr, length in symbolic_memory_map.items(): mem_addr = parse_address(str(addr)) - mem_len = int(str(length), 0) + mem_len = checked_int( + length, + f"symbolic_memory_json[{addr!r}]", + 1, + MAX_SYMBOLIC_BYTES, + ) sym_mem = claripy.BVS(f"mem_{mem_addr:x}", mem_len * 8) symbolic_memory[mem_addr] = (mem_len, sym_mem) state.memory.store(mem_addr, sym_mem) + total_symbolic += mem_len - for addr, value in parse_json_map(args.memory_json, "memory_json").items(): + memory_map = parse_json_map(args.memory_json, "memory_json") + if len(memory_map) > MAX_SYMBOLIC_REGIONS: + raise ValueError(f"memory_json may contain at most {MAX_SYMBOLIC_REGIONS} entries") + for addr, value in memory_map.items(): mem_addr = parse_address(str(addr)) if isinstance(value, str): concrete = int(value, 0) else: concrete = int(value) + if concrete < 0: + raise ValueError(f"memory_json[{addr!r}] must be non-negative") byte_len = max(1, (concrete.bit_length() + 7) // 8) + if byte_len > MAX_SYMBOLIC_BYTES: + raise ValueError(f"memory_json[{addr!r}] may contain at most {MAX_SYMBOLIC_BYTES} bytes") state.memory.store(mem_addr, concrete, size=byte_len) symbolic_registers = {} registers = parse_json_map(args.registers_json, "registers_json") + if len(registers) > MAX_SYMBOLIC_REGIONS: + raise ValueError(f"registers_json may contain at most {MAX_SYMBOLIC_REGIONS} entries") for reg_name, value in registers.items(): reg_name = normalize_register_name(project, reg_name) if isinstance(value, str) and value.startswith("sv"): - byte_len = int(value[2:], 0) + byte_len = checked_int( + value[2:], + f"registers_json[{reg_name!r}]", + 1, + MAX_SYMBOLIC_REGISTER_BYTES, + ) sym_reg = claripy.BVS(f"reg_{reg_name}", byte_len * 8) symbolic_registers[reg_name] = (byte_len, sym_reg) setattr(state.regs, reg_name, sym_reg) + total_symbolic += byte_len else: setattr(state.regs, reg_name, int(str(value), 0)) + if total_symbolic > MAX_TOTAL_SYMBOLIC_BYTES: + raise ValueError(f"total symbolic input may not exceed {MAX_TOTAL_SYMBOLIC_BYTES} bytes") + symbols = { "stdin": (args.stdin_bytes, symbolic_stdin), "argv": symbolic_argv, @@ -227,16 +302,14 @@ def setup_symbolic_execution(args: argparse.Namespace, target_value: str): def run_explorer(project, state, target_addr: int, avoid: list[int], max_steps: int): import angr + max_steps = checked_int(max_steps, "max_steps", 1, MAX_STEPS) simgr = project.factory.simulation_manager(state) explorer_kwargs = {"find": target_addr} if avoid: explorer_kwargs["avoid"] = avoid simgr.use_technique(angr.exploration_techniques.Explorer(**explorer_kwargs)) - if max_steps > 0: - simgr.run(n=max_steps) - else: - simgr.run() + simgr.run(n=max_steps) return simgr @@ -260,7 +333,8 @@ def get_symbolic_ast(state, symbols, item: dict): reg_name = normalize_register_name(state.project, item["name"]) return getattr(state.regs, reg_name) if target_type == "mem": - return state.memory.load(parse_address(str(item["address"])), int(str(item["length"]), 0)) + mem_len = checked_int(item["length"], "constraint memory length", 1, MAX_SYMBOLIC_BYTES) + return state.memory.load(parse_address(str(item["address"])), mem_len) if target_type == "stdin": _length, symbolic_stdin = symbols["stdin"] if symbolic_stdin is None: @@ -280,9 +354,14 @@ def concrete_bvv(state, item: dict, bits: int): if "value_hex" in item: raw = bytes.fromhex(str(item["value_hex"]).removeprefix("0x")) + if len(raw) * 8 != bits: + raise ValueError(f"value_hex is {len(raw) * 8} bits, expected {bits}") return claripy.BVV(raw) if "value_bytes" in item: - return claripy.BVV(str(item["value_bytes"]).encode()) + raw = str(item["value_bytes"]).encode() + if len(raw) * 8 != bits: + raise ValueError(f"value_bytes is {len(raw) * 8} bits, expected {bits}") + return claripy.BVV(raw) if "value" not in item: raise ValueError("constraint is missing value, value_hex, or value_bytes") return claripy.BVV(int(str(item["value"]), 0), bits) @@ -389,6 +468,8 @@ def run_solve_at(args: argparse.Namespace) -> int: if not args.constraints_json: parsed_constraints = [] else: + if len(args.constraints_json) > MAX_JSON_INPUT_CHARS: + raise ValueError(f"constraints_json exceeds {MAX_JSON_INPUT_CHARS} characters") decoded_constraints = json.loads(args.constraints_json) if isinstance(decoded_constraints, dict): parsed_constraints = decoded_constraints.get("constraints", []) @@ -396,6 +477,8 @@ def run_solve_at(args: argparse.Namespace) -> int: parsed_constraints = decoded_constraints if not isinstance(parsed_constraints, list): raise ValueError("constraints_json constraints must be a list") + if len(parsed_constraints) > MAX_CONSTRAINTS: + raise ValueError(f"constraints_json may contain at most {MAX_CONSTRAINTS} entries") for item in parsed_constraints: if not isinstance(item, dict): @@ -415,13 +498,22 @@ def run_solve_at(args: argparse.Namespace) -> int: reg_name = normalize_register_name(project, reg_name) print(f"eval_reg[{reg_name}] = {found.solver.eval(getattr(found.regs, reg_name)):#x}") - for addr, length in parse_json_map(args.eval_memory_json, "eval_memory_json").items(): + eval_memory = parse_json_map(args.eval_memory_json, "eval_memory_json") + if len(eval_memory) > MAX_SYMBOLIC_REGIONS: + raise ValueError(f"eval_memory_json may contain at most {MAX_SYMBOLIC_REGIONS} entries") + for addr, length in eval_memory.items(): mem_addr = parse_address(str(addr)) - mem_len = int(str(length), 0) + mem_len = checked_int(length, f"eval_memory_json[{addr!r}]", 1, MAX_SYMBOLIC_BYTES) value = found.solver.eval(found.memory.load(mem_addr, mem_len), cast_to=bytes) print(f"eval_mem[{hex_addr(mem_addr)}:{mem_len}] = {value!r}") if args.eval_stdin_bytes > 0: + args.eval_stdin_bytes = checked_int( + args.eval_stdin_bytes, + "eval_stdin_bytes", + 1, + MAX_SYMBOLIC_BYTES, + ) stdin_len, symbolic_stdin = symbols["stdin"] if symbolic_stdin is None: print("eval_stdin = ") @@ -433,6 +525,7 @@ def run_solve_at(args: argparse.Namespace) -> int: def run_reachability(args: argparse.Namespace) -> int: + args.summary_limit = checked_int(args.summary_limit, "summary_limit", 1, MAX_SUMMARY_LIMIT) source = parse_address(args.reachability_from) target = parse_address(args.reachability_to) project = make_project( @@ -496,6 +589,7 @@ def run_reachability(args: argparse.Namespace) -> int: def run_cfg_summary(args: argparse.Namespace) -> int: + args.summary_limit = checked_int(args.summary_limit, "summary_limit", 1, MAX_SUMMARY_LIMIT) project = make_project( args.binary, args.pcode_language, @@ -533,12 +627,13 @@ def run_cfg_summary(args: argparse.Namespace) -> int: return 0 print("functions_sample:") - for func in list(project.kb.functions.values())[: args.summary_limit]: + for func in islice(project.kb.functions.values(), args.summary_limit): print(f" {func.addr:#x} {func.name} blocks={len(func.block_addrs_set)}") return 0 def run_callgraph_summary(args: argparse.Namespace) -> int: + args.summary_limit = checked_int(args.summary_limit, "summary_limit", 1, MAX_SUMMARY_LIMIT) project = make_project( args.binary, args.pcode_language, @@ -553,12 +648,12 @@ def run_callgraph_summary(args: argparse.Namespace) -> int: print(f"functions: {callgraph.number_of_nodes()}") print(f"calls: {callgraph.number_of_edges()}") - edges = list(callgraph.edges()) + edge_count = callgraph.number_of_edges() print("edges:") - for src, dst in edges[: args.summary_limit]: + for src, dst in islice(callgraph.edges(), args.summary_limit): print(f" {function_label(project, src)} -> {function_label(project, dst)}") - if len(edges) > args.summary_limit: - print(f" ... {len(edges) - args.summary_limit} more edges") + if edge_count > args.summary_limit: + print(f" ... {edge_count - args.summary_limit} more edges") return 0 @@ -671,7 +766,7 @@ def main() -> int: parser.add_argument("--symbolic-memory-json", default="", help="JSON object mapping address to symbolic byte length") parser.add_argument("--memory-json", default="", help="JSON object mapping address to concrete integer/hex value") parser.add_argument("--registers-json", default="", help='JSON object mapping register names to values or "svN" symbolic byte lengths') - parser.add_argument("--max-steps", type=int, default=10000, help="Maximum symbolic execution steps, or 0 for unbounded") + parser.add_argument("--max-steps", type=int, default=10000, help=f"Maximum symbolic execution steps, 1-{MAX_STEPS}") parser.add_argument("--solve-at", help="Find an address and solve/evaluate requested constraints there") parser.add_argument("--constraints-json", default="", help="JSON list of constraints, or object with a constraints list") parser.add_argument("--eval-registers", default="", help="Comma-separated register names to evaluate after solve-at") @@ -693,6 +788,14 @@ def main() -> int: rust_group.add_argument("--rust", dest="rust", action="store_true", default=True) rust_group.add_argument("--no-rust", dest="rust", action="store_false") args = parser.parse_args() + args.max_steps = checked_int(args.max_steps, "max_steps", 1, MAX_STEPS) + args.stdin_bytes = checked_int(args.stdin_bytes, "stdin_bytes", 0, MAX_SYMBOLIC_BYTES) + args.eval_stdin_bytes = checked_int(args.eval_stdin_bytes, "eval_stdin_bytes", 0, MAX_SYMBOLIC_BYTES) + args.summary_limit = checked_int(args.summary_limit, "summary_limit", 1, MAX_SUMMARY_LIMIT) + if args.block_size < 0: + raise ValueError("block_size must be non-negative") + if args.num_inst < 0: + raise ValueError("num_inst must be non-negative") if args.check: if not args.binary: @@ -744,6 +847,12 @@ def main() -> int: try: with redirect_stderr(stderr): exit_code = main() + except ValueError as exc: + captured = stderr.getvalue().strip() + if captured: + print(captured, file=sys.stderr) + print(f"error: {exc}", file=sys.stderr) + sys.exit(1) except Exception as exc: # pylint: disable=broad-exception-caught captured = stderr.getvalue().strip() if captured: diff --git a/bridge_mcp_ghidra.py b/bridge_mcp_ghidra.py index b5970f7d..1f5d6b29 100644 --- a/bridge_mcp_ghidra.py +++ b/bridge_mcp_ghidra.py @@ -49,6 +49,22 @@ # Configurable timeouts (in seconds) TIMEOUT_DECOMPILE_MAX = 1800 # Maximum decompilation timeout (30 minutes) +ANGR_HELPER_OUTPUT_MAX_CHARS = 200_000 +ANGR_JSON_INPUT_MAX_CHARS = 100_000 +ANGR_OPTIONS_JSON_MAX_CHARS = 100_000 +ANGR_MAX_SYMBOLIC_BYTES = 4096 +ANGR_MAX_TOTAL_SYMBOLIC_BYTES = 16_384 +ANGR_MAX_SYMBOLIC_ARGS = 16 +ANGR_MAX_SYMBOLIC_REGIONS = 64 +ANGR_MAX_SYMBOLIC_REGISTER_BYTES = 64 +ANGR_MAX_HOOKS = 64 +ANGR_MAX_STEPS = 100_000 +ANGR_MAX_SUMMARY_LIMIT = 500 +ANGR_MAX_BLOCK_SIZE = 4096 +ANGR_MAX_NUM_INST = 256 +ANGR_MAX_COMPARE_FUNCTIONS = 25 +ANGR_MAX_COMMENTS = 100 +ANGR_MAX_COMMENT_PREFIX_CHARS = 200 def get_http_client(): global _http_client @@ -149,21 +165,30 @@ def run_angr_helper(args: list[str], timeout: int) -> str: helper = os.environ.get("GHIDRA_MCP_ANGR_HELPER", DEFAULT_ANGR_HELPER) python = os.environ.get("GHIDRA_MCP_ANGR_PYTHON", default_angr_python()) cmd = [python, helper, *args] + effective_timeout = max(1, min(timeout, TIMEOUT_DECOMPILE_MAX)) try: - completed = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=timeout, - check=False, - ) + with tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stdout_file, \ + tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stderr_file: + completed = subprocess.run( + cmd, + stdout=stdout_file, + stderr=stderr_file, + timeout=effective_timeout, + check=False, + ) + output, output_truncated = read_limited_stream(stdout_file, ANGR_HELPER_OUTPUT_MAX_CHARS) + errors, errors_truncated = read_limited_stream(stderr_file, ANGR_HELPER_OUTPUT_MAX_CHARS) except FileNotFoundError as e: return f"Failed to start angr helper: {e}" except subprocess.TimeoutExpired: - return f"angr helper timed out after {timeout} seconds" - - output = completed.stdout.strip() - errors = completed.stderr.strip() + return f"angr helper timed out after {effective_timeout} seconds" + + errors = errors.strip() + output = output.strip() + if output_truncated: + output += truncation_note("angr stdout", ANGR_HELPER_OUTPUT_MAX_CHARS) + if errors_truncated: + errors += truncation_note("angr stderr", ANGR_HELPER_OUTPUT_MAX_CHARS) if completed.returncode == 0: return output if output else "(angr returned no output)" @@ -197,11 +222,266 @@ def angryghidra_missing_message() -> str: def parse_optional_json(value: str, field_name: str): if not value: return None + if len(value) > ANGR_JSON_INPUT_MAX_CHARS: + raise ValueError(f"{field_name} exceeds {ANGR_JSON_INPUT_MAX_CHARS} characters") try: return json.loads(value) except json.JSONDecodeError as e: raise ValueError(f"{field_name} must be valid JSON: {e}") from e +def bounded_int(value: int, field_name: str, minimum: int, maximum: int) -> tuple[int, str]: + try: + parsed = int(value) + except (TypeError, ValueError): + return minimum, f"{field_name} must be an integer" + if parsed < minimum or parsed > maximum: + return parsed, f"{field_name} must be between {minimum} and {maximum}" + return parsed, "" + +def read_limited_stream(stream, limit: int) -> tuple[str, bool]: + stream.seek(0) + text = stream.read(limit + 1) + if len(text) > limit: + return text[:limit], True + return text, False + +def truncation_note(label: str, limit: int) -> str: + return f"\n\n[{label} truncated after {limit} characters]" + +def parse_capped_csv_ints(value: str, field_name: str, max_items: int, max_value: int) -> tuple[list[int], str]: + if not value: + return [], "" + parts = [part.strip() for part in value.split(",") if part.strip()] + if len(parts) > max_items: + return [], f"{field_name} may contain at most {max_items} entries" + result = [] + for part in parts: + try: + parsed = int(part, 0) + except ValueError: + return [], f"{field_name} entry {part!r} is not an integer" + if parsed < 1 or parsed > max_value: + return [], f"{field_name} entries must be between 1 and {max_value}" + result.append(parsed) + return result, "" + +def parse_capped_json_map(value: str, field_name: str, max_items: int) -> tuple[dict, str]: + try: + parsed = parse_optional_json(value, field_name) + except ValueError as e: + return {}, str(e) + if parsed is None: + return {}, "" + if not isinstance(parsed, dict): + return {}, f"{field_name} must be a JSON object" + if len(parsed) > max_items: + return {}, f"{field_name} may contain at most {max_items} entries" + return parsed, "" + +def symbolic_length(value, field_name: str, max_value: int = ANGR_MAX_SYMBOLIC_BYTES) -> tuple[int, str]: + try: + parsed = int(str(value), 0) + except ValueError: + return 0, f"{field_name} must be an integer byte length" + if parsed < 1 or parsed > max_value: + return parsed, f"{field_name} must be between 1 and {max_value} bytes" + return parsed, "" + +def validate_symbolic_input_caps( + stdin_bytes: int = 0, + argv_bytes: str = "", + symbolic_memory_json: str = "", + memory_json: str = "", + registers_json: str = "", +) -> str: + if stdin_bytes: + _stdin_bytes, error = bounded_int(stdin_bytes, "stdin_bytes", 0, ANGR_MAX_SYMBOLIC_BYTES) + if error: + return error + argv_lengths, error = parse_capped_csv_ints( + argv_bytes, + "argv_bytes", + ANGR_MAX_SYMBOLIC_ARGS, + ANGR_MAX_SYMBOLIC_BYTES, + ) + if error: + return error + total_symbolic = sum(argv_lengths) + max(0, int(stdin_bytes or 0)) + + symbolic_memory, error = parse_capped_json_map( + symbolic_memory_json, + "symbolic_memory_json", + ANGR_MAX_SYMBOLIC_REGIONS, + ) + if error: + return error + for addr, length in symbolic_memory.items(): + byte_len, error = symbolic_length(length, f"symbolic_memory_json[{addr!r}]") + if error: + return error + total_symbolic += byte_len + + memory, error = parse_capped_json_map( + memory_json, + "memory_json", + ANGR_MAX_SYMBOLIC_REGIONS, + ) + if error: + return error + for addr, value in memory.items(): + try: + concrete = int(str(value), 0) + except ValueError: + return f"memory_json[{addr!r}] must be an integer or hex string" + if concrete < 0: + return f"memory_json[{addr!r}] must be non-negative" + byte_len = max(1, (concrete.bit_length() + 7) // 8) + if byte_len > ANGR_MAX_SYMBOLIC_BYTES: + return f"memory_json[{addr!r}] may contain at most {ANGR_MAX_SYMBOLIC_BYTES} bytes" + + registers, error = parse_capped_json_map( + registers_json, + "registers_json", + ANGR_MAX_SYMBOLIC_REGIONS, + ) + if error: + return error + for reg_name, value in registers.items(): + if isinstance(value, str) and value.startswith("sv"): + byte_len, error = symbolic_length( + value[2:], + f"registers_json[{reg_name!r}]", + ANGR_MAX_SYMBOLIC_REGISTER_BYTES, + ) + if error: + return error + total_symbolic += byte_len + + if total_symbolic > ANGR_MAX_TOTAL_SYMBOLIC_BYTES: + return f"total symbolic input may not exceed {ANGR_MAX_TOTAL_SYMBOLIC_BYTES} bytes" + return "" + +def validate_solver_caps( + constraints_json: str = "", + eval_memory_json: str = "", + eval_stdin_bytes: int = 0, +) -> str: + try: + constraints = parse_optional_json(constraints_json, "constraints_json") + except ValueError as e: + return str(e) + if isinstance(constraints, dict): + constraints = constraints.get("constraints", []) + if constraints is None: + constraints = [] + if not isinstance(constraints, list): + return "constraints_json constraints must be a list" + if len(constraints) > ANGR_MAX_SYMBOLIC_REGIONS * 2: + return f"constraints_json may contain at most {ANGR_MAX_SYMBOLIC_REGIONS * 2} entries" + + eval_memory, error = parse_capped_json_map( + eval_memory_json, + "eval_memory_json", + ANGR_MAX_SYMBOLIC_REGIONS, + ) + if error: + return error + for addr, length in eval_memory.items(): + _byte_len, error = symbolic_length(length, f"eval_memory_json[{addr!r}]") + if error: + return error + + if eval_stdin_bytes: + _eval_stdin_bytes, error = bounded_int( + eval_stdin_bytes, + "eval_stdin_bytes", + 0, + ANGR_MAX_SYMBOLIC_BYTES, + ) + if error: + return error + return "" + +def validate_angryghidra_options(options: dict) -> str: + encoded = json.dumps(options) + if len(encoded) > ANGR_OPTIONS_JSON_MAX_CHARS: + return f"AngryGhidra options exceed {ANGR_OPTIONS_JSON_MAX_CHARS} characters" + + arguments = options.get("arguments", {}) + if arguments and not isinstance(arguments, dict): + return "AngryGhidra arguments must be a JSON object" + if len(arguments) > ANGR_MAX_SYMBOLIC_ARGS: + return f"AngryGhidra arguments may contain at most {ANGR_MAX_SYMBOLIC_ARGS} entries" + total_symbolic = 0 + for key, value in arguments.items(): + byte_len, error = symbolic_length(value, f"arguments[{key!r}]") + if error: + return error + total_symbolic += byte_len + + vectors = options.get("vectors", {}) + if vectors and not isinstance(vectors, dict): + return "AngryGhidra vectors must be a JSON object" + if len(vectors) > ANGR_MAX_SYMBOLIC_REGIONS: + return f"AngryGhidra vectors may contain at most {ANGR_MAX_SYMBOLIC_REGIONS} entries" + for addr, length in vectors.items(): + byte_len, error = symbolic_length(length, f"vectors[{addr!r}]") + if error: + return error + total_symbolic += byte_len + + mem_store = options.get("mem_store", {}) + if mem_store and not isinstance(mem_store, dict): + return "AngryGhidra mem_store must be a JSON object" + if len(mem_store) > ANGR_MAX_SYMBOLIC_REGIONS: + return f"AngryGhidra mem_store may contain at most {ANGR_MAX_SYMBOLIC_REGIONS} entries" + for addr, value in mem_store.items(): + text_value = str(value) + if len(text_value.removeprefix("0x")) > ANGR_MAX_SYMBOLIC_BYTES * 2: + return f"mem_store[{addr!r}] may contain at most {ANGR_MAX_SYMBOLIC_BYTES} bytes" + + regs_vals = options.get("regs_vals", {}) + if regs_vals and not isinstance(regs_vals, dict): + return "AngryGhidra regs_vals must be a JSON object" + if len(regs_vals) > ANGR_MAX_SYMBOLIC_REGIONS: + return f"AngryGhidra regs_vals may contain at most {ANGR_MAX_SYMBOLIC_REGIONS} entries" + for reg_name, value in regs_vals.items(): + if isinstance(value, str) and value.startswith("sv"): + byte_len, error = symbolic_length( + value[2:], + f"regs_vals[{reg_name!r}]", + ANGR_MAX_SYMBOLIC_REGISTER_BYTES, + ) + if error: + return error + total_symbolic += byte_len + + hooks = options.get("hooks", []) + if hooks and not isinstance(hooks, list): + return "AngryGhidra hooks must be a JSON array" + if len(hooks) > ANGR_MAX_HOOKS: + return f"AngryGhidra hooks may contain at most {ANGR_MAX_HOOKS} entries" + for index, hook in enumerate(hooks): + if not isinstance(hook, dict): + return f"AngryGhidra hooks[{index}] must be a JSON object" + for _address, register_updates in hook.items(): + if not isinstance(register_updates, dict): + return f"AngryGhidra hooks[{index}] values must be JSON objects" + for reg_name, value in register_updates.items(): + if isinstance(value, str) and value.startswith("sv"): + byte_len, error = symbolic_length( + value[2:], + f"hooks[{index}][{reg_name!r}]", + ANGR_MAX_SYMBOLIC_REGISTER_BYTES, + ) + if error: + return error + total_symbolic += byte_len + + if total_symbolic > ANGR_MAX_TOTAL_SYMBOLIC_BYTES: + return f"total AngryGhidra symbolic input may not exceed {ANGR_MAX_TOTAL_SYMBOLIC_BYTES} bytes" + return "" + def resolve_angr_defaults(binary_path: str = "", pcode_language: str = "") -> tuple[str, str]: program_info = {} if not binary_path or not pcode_language: @@ -267,24 +547,32 @@ def run_angryghidra_options(options: dict, timeout: int) -> str: script = find_angryghidra_script() if not script: return angryghidra_missing_message() + validation_error = validate_angryghidra_options(options) + if validation_error: + return validation_error python = os.environ.get("ANGRYGHIDRA_PYTHON") or os.environ.get("GHIDRA_MCP_ANGR_PYTHON") or default_angr_python() options_path = "" + effective_timeout = max(1, min(timeout, TIMEOUT_DECOMPILE_MAX)) try: with tempfile.NamedTemporaryFile("w", suffix="-angryghidra.json", delete=False) as options_file: json.dump(options, options_file) options_path = options_file.name - completed = subprocess.run( - [python, script, options_path], - capture_output=True, - text=True, - timeout=max(1, timeout), - check=False, - ) + with tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stdout_file, \ + tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stderr_file: + completed = subprocess.run( + [python, script, options_path], + stdout=stdout_file, + stderr=stderr_file, + timeout=effective_timeout, + check=False, + ) + output, output_truncated = read_limited_stream(stdout_file, ANGR_HELPER_OUTPUT_MAX_CHARS) + errors, errors_truncated = read_limited_stream(stderr_file, ANGR_HELPER_OUTPUT_MAX_CHARS) except FileNotFoundError as e: return f"Failed to start AngryGhidra: {e}" except subprocess.TimeoutExpired: - return f"AngryGhidra timed out after {timeout} seconds" + return f"AngryGhidra timed out after {effective_timeout} seconds" finally: if options_path: try: @@ -292,8 +580,12 @@ def run_angryghidra_options(options: dict, timeout: int) -> str: except OSError: pass - output = completed.stdout.strip() - errors = completed.stderr.strip() + output = output.strip() + errors = errors.strip() + if output_truncated: + output += truncation_note("AngryGhidra stdout", ANGR_HELPER_OUTPUT_MAX_CHARS) + if errors_truncated: + errors += truncation_note("AngryGhidra stderr", ANGR_HELPER_OUTPUT_MAX_CHARS) if completed.returncode == 0: return output if output else "(AngryGhidra returned no solution)" @@ -644,6 +936,19 @@ def angr_symbolic_find( requested_engine = engine.lower().strip() if requested_engine not in {"auto", "angryghidra", "core"}: return 'engine must be one of: "auto", "angryghidra", "core"' + _max_steps, error = bounded_int(max_steps, "max_steps", 1, ANGR_MAX_STEPS) + if error: + return error + max_steps = _max_steps + error = validate_symbolic_input_caps( + stdin_bytes=stdin_bytes, + argv_bytes=argv_bytes, + symbolic_memory_json=symbolic_memory_json, + memory_json=memory_json, + registers_json=registers_json, + ) + if error: + return error program_info = {} if not binary_path or not pcode_language or not base_address: @@ -689,7 +994,7 @@ def angr_symbolic_find( args = [ "--binary", binary_path, "--symbolic-find", normalize_ghidra_address(find_address), - "--max-steps", str(max(0, max_steps)), + "--max-steps", str(max_steps), ] if start_address: args.extend(["--start-address", normalize_ghidra_address(start_address)]) @@ -739,19 +1044,27 @@ def angr_annotate_symbolic_path( comment_kind: str = "disasm", comment_prefix: str = "angr symbolic path", max_comments: int = 100, + apply: bool = False, + overwrite_existing: bool = False, timeout: int = 120, max_steps: int = 10000, ) -> str: """ Run a symbolic path search and write path comments into the Ghidra program. - This is an explicit write endpoint. comment_kind may be "disasm", + This endpoint previews by default. Set apply=True and + overwrite_existing=True to write comments, because the current Ghidra + comment endpoints replace existing comments. comment_kind may be "disasm", "decomp", or "both"; comments are applied only when a trace/path is found. - The underlying path search prefers AngryGhidra in engine="auto" when the - request fits AngryGhidra's native script. """ if comment_kind not in {"disasm", "decomp", "both"}: return 'comment_kind must be one of: "disasm", "decomp", "both"' + _max_comments, error = bounded_int(max_comments, "max_comments", 1, ANGR_MAX_COMMENTS) + if error: + return error + max_comments = _max_comments + if len(comment_prefix) > ANGR_MAX_COMMENT_PREFIX_CHARS: + return f"comment_prefix may contain at most {ANGR_MAX_COMMENT_PREFIX_CHARS} characters" result = angr_symbolic_find( find_address=find_address, @@ -776,6 +1089,27 @@ def angr_annotate_symbolic_path( if not trace_addresses: return f"{result}\n\nNo trace addresses found; no comments were written." + total = len(trace_addresses) + normalized_target = normalize_ghidra_address(find_address) + preview = [ + f"{address}: {comment_prefix}: step {index}/{total} toward {normalized_target}" + for index, address in enumerate(trace_addresses, start=1) + ] + + if not apply: + return ( + f"{result}\n\n" + f"Preview only: {len(trace_addresses)} trace comment(s) would be written. " + "Call with apply=True and overwrite_existing=True to write them.\n" + + "\n".join(preview) + ) + if not overwrite_existing: + return ( + f"{result}\n\n" + "Refusing to write comments because Ghidra's comment endpoints replace existing comments. " + "Call with overwrite_existing=True to confirm." + ) + endpoints = [] if comment_kind in {"disasm", "both"}: endpoints.append("set_disassembly_comment") @@ -783,8 +1117,6 @@ def angr_annotate_symbolic_path( endpoints.append("set_decompiler_comment") writes = [] - total = len(trace_addresses) - normalized_target = normalize_ghidra_address(find_address) for index, address in enumerate(trace_addresses, start=1): comment = f"{comment_prefix}: step {index}/{total} toward {normalized_target}" for endpoint in endpoints: @@ -829,11 +1161,31 @@ def angr_solve_constraints_at( missing = require_binary_path(binary_path) if missing: return missing + _max_steps, error = bounded_int(max_steps, "max_steps", 1, ANGR_MAX_STEPS) + if error: + return error + max_steps = _max_steps + error = validate_symbolic_input_caps( + stdin_bytes=stdin_bytes, + argv_bytes=argv_bytes, + symbolic_memory_json=symbolic_memory_json, + memory_json=memory_json, + registers_json=registers_json, + ) + if error: + return error + error = validate_solver_caps( + constraints_json=constraints_json, + eval_memory_json=eval_memory_json, + eval_stdin_bytes=eval_stdin_bytes, + ) + if error: + return error args = [ "--binary", binary_path, "--solve-at", normalize_ghidra_address(address), - "--max-steps", str(max(0, max_steps)), + "--max-steps", str(max_steps), ] if start_address: args.extend(["--start-address", normalize_ghidra_address(start_address)]) @@ -884,6 +1236,10 @@ def angr_reachability( missing = require_binary_path(binary_path) if missing: return missing + _summary_limit, error = bounded_int(summary_limit, "summary_limit", 1, ANGR_MAX_SUMMARY_LIMIT) + if error: + return error + summary_limit = _summary_limit args = [ "--binary", binary_path, @@ -915,6 +1271,10 @@ def angr_cfg_summary( missing = require_binary_path(binary_path) if missing: return missing + _summary_limit, error = bounded_int(summary_limit, "summary_limit", 1, ANGR_MAX_SUMMARY_LIMIT) + if error: + return error + summary_limit = _summary_limit args = [ "--binary", binary_path, @@ -944,6 +1304,10 @@ def angr_callgraph_summary( missing = require_binary_path(binary_path) if missing: return missing + _summary_limit, error = bounded_int(summary_limit, "summary_limit", 1, ANGR_MAX_SUMMARY_LIMIT) + if error: + return error + summary_limit = _summary_limit args = [ "--binary", binary_path, @@ -975,6 +1339,16 @@ def angr_lift_block( return missing if lift_format not in {"vex", "ail", "both"}: return "lift_format must be one of: vex, ail, both" + if block_size: + _block_size, error = bounded_int(block_size, "block_size", 1, ANGR_MAX_BLOCK_SIZE) + if error: + return error + block_size = _block_size + if num_inst: + _num_inst, error = bounded_int(num_inst, "num_inst", 1, ANGR_MAX_NUM_INST) + if error: + return error + num_inst = _num_inst args = [ "--binary", binary_path, @@ -1008,6 +1382,10 @@ def angr_compare_decompilers( missing = require_binary_path(binary_path) if missing: return missing + _max_functions, error = bounded_int(max_functions, "max_functions", 1, ANGR_MAX_COMPARE_FUNCTIONS) + if error: + return error + max_functions = _max_functions selected_addresses = split_addresses(addresses, max_functions) if not selected_addresses: diff --git a/tests/test_bridge.py b/tests/test_bridge.py index 1c172c7c..3e8efbe6 100644 --- a/tests/test_bridge.py +++ b/tests/test_bridge.py @@ -363,7 +363,33 @@ def test_angr_symbolic_find_forced_angryghidra_missing_is_clear( assert "AngryGhidra is not installed or configured" in out - def test_angr_annotate_symbolic_path_writes_trace_comments( + def test_angr_annotate_symbolic_path_previews_by_default( + self, bridge_module, monkeypatch): + monkeypatch.setattr( + bridge_module, + "angr_symbolic_find", + lambda **_kwargs: "engine: AngryGhidra\nt:0x401000\nt:0x401020") + + out = bridge_module.angr_annotate_symbolic_path(find_address="0x401020") + + assert "Preview only: 2 trace comment(s) would be written" in out + assert "0x401000: angr symbolic path: step 1/2 toward 0x401020" in out + + def test_angr_annotate_symbolic_path_requires_overwrite_confirmation( + self, bridge_module, monkeypatch): + monkeypatch.setattr( + bridge_module, + "angr_symbolic_find", + lambda **_kwargs: "engine: AngryGhidra\nt:0x401000") + + out = bridge_module.angr_annotate_symbolic_path( + find_address="0x401020", + apply=True) + + assert "Refusing to write comments" in out + assert "overwrite_existing=True" in out + + def test_angr_annotate_symbolic_path_writes_trace_comments_with_confirmation( self, bridge_module, httpx_mock, monkeypatch): monkeypatch.setattr( bridge_module, @@ -388,11 +414,31 @@ def test_angr_annotate_symbolic_path_writes_trace_comments( out = bridge_module.angr_annotate_symbolic_path( find_address="0x401020", - comment_kind="disasm") + comment_kind="disasm", + apply=True, + overwrite_existing=True) assert "Annotated 2 trace address(es)" in out assert "set_disassembly_comment 0x401000: Comment set successfully" in out + def test_angr_symbolic_find_rejects_unbounded_max_steps( + self, bridge_module): + out = bridge_module.angr_symbolic_find( + find_address="0x401020", + binary_path="/tmp/a.out", + max_steps=0) + + assert "max_steps must be between 1" in out + + def test_angr_symbolic_find_rejects_huge_symbolic_stdin( + self, bridge_module): + out = bridge_module.angr_symbolic_find( + find_address="0x401020", + binary_path="/tmp/a.out", + stdin_bytes=999999) + + assert "stdin_bytes must be between 0 and 4096" in out + def test_angr_solve_constraints_at_builds_rich_solver_args( self, bridge_module, httpx_mock, monkeypatch): httpx_mock.add_response( @@ -581,6 +627,15 @@ def fake_run(args, timeout): "--pcode-language", "eBPF:LE:64:default", ], 12) + def test_angr_compare_decompilers_rejects_excessive_batch( + self, bridge_module): + out = bridge_module.angr_compare_decompilers( + "0x120", + binary_path="/tmp/eternal.so", + max_functions=999) + + assert "max_functions must be between 1 and 25" in out + def test_angryghidra_check_setup_missing_is_clear( self, bridge_module, monkeypatch): monkeypatch.setattr(bridge_module, "find_angryghidra_script", lambda: "") From 992a7063b8eed6021bad2acf132c12f2e790b4ee Mon Sep 17 00:00:00 2001 From: rustopian <96253492+rustopian@users.noreply.github.com> Date: Sat, 23 May 2026 14:37:18 +0100 Subject: [PATCH 4/6] Deduplicate angr subprocess handling --- bridge_mcp_ghidra.py | 91 ++++++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 50 deletions(-) diff --git a/bridge_mcp_ghidra.py b/bridge_mcp_ghidra.py index 1f5d6b29..222fe8a6 100644 --- a/bridge_mcp_ghidra.py +++ b/bridge_mcp_ghidra.py @@ -165,37 +165,16 @@ def run_angr_helper(args: list[str], timeout: int) -> str: helper = os.environ.get("GHIDRA_MCP_ANGR_HELPER", DEFAULT_ANGR_HELPER) python = os.environ.get("GHIDRA_MCP_ANGR_PYTHON", default_angr_python()) cmd = [python, helper, *args] - effective_timeout = max(1, min(timeout, TIMEOUT_DECOMPILE_MAX)) - try: - with tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stdout_file, \ - tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stderr_file: - completed = subprocess.run( - cmd, - stdout=stdout_file, - stderr=stderr_file, - timeout=effective_timeout, - check=False, - ) - output, output_truncated = read_limited_stream(stdout_file, ANGR_HELPER_OUTPUT_MAX_CHARS) - errors, errors_truncated = read_limited_stream(stderr_file, ANGR_HELPER_OUTPUT_MAX_CHARS) - except FileNotFoundError as e: - return f"Failed to start angr helper: {e}" - except subprocess.TimeoutExpired: - return f"angr helper timed out after {effective_timeout} seconds" - - errors = errors.strip() - output = output.strip() - if output_truncated: - output += truncation_note("angr stdout", ANGR_HELPER_OUTPUT_MAX_CHARS) - if errors_truncated: - errors += truncation_note("angr stderr", ANGR_HELPER_OUTPUT_MAX_CHARS) - if completed.returncode == 0: + returncode, output, errors, error = run_limited_subprocess(cmd, "angr helper", timeout) + if error: + return error + if returncode == 0: return output if output else "(angr returned no output)" details = output if errors: details = f"{details}\n\nstderr:\n{errors}" if details else f"stderr:\n{errors}" - return f"angr helper failed with exit code {completed.returncode}\n\n{details}".strip() + return f"angr helper failed with exit code {returncode}\n\n{details}".strip() def find_angryghidra_script() -> str: candidates = [ @@ -248,6 +227,33 @@ def read_limited_stream(stream, limit: int) -> tuple[str, bool]: def truncation_note(label: str, limit: int) -> str: return f"\n\n[{label} truncated after {limit} characters]" +def run_limited_subprocess(cmd: list[str], label: str, timeout: int) -> tuple[int | None, str, str, str]: + effective_timeout = max(1, min(timeout, TIMEOUT_DECOMPILE_MAX)) + try: + with tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stdout_file, \ + tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stderr_file: + completed = subprocess.run( + cmd, + stdout=stdout_file, + stderr=stderr_file, + timeout=effective_timeout, + check=False, + ) + output, output_truncated = read_limited_stream(stdout_file, ANGR_HELPER_OUTPUT_MAX_CHARS) + errors, errors_truncated = read_limited_stream(stderr_file, ANGR_HELPER_OUTPUT_MAX_CHARS) + except FileNotFoundError as e: + return None, "", "", f"Failed to start {label}: {e}" + except subprocess.TimeoutExpired: + return None, "", "", f"{label} timed out after {effective_timeout} seconds" + + output = output.strip() + errors = errors.strip() + if output_truncated: + output += truncation_note(f"{label} stdout", ANGR_HELPER_OUTPUT_MAX_CHARS) + if errors_truncated: + errors += truncation_note(f"{label} stderr", ANGR_HELPER_OUTPUT_MAX_CHARS) + return completed.returncode, output, errors, "" + def parse_capped_csv_ints(value: str, field_name: str, max_items: int, max_value: int) -> tuple[list[int], str]: if not value: return [], "" @@ -553,26 +559,15 @@ def run_angryghidra_options(options: dict, timeout: int) -> str: python = os.environ.get("ANGRYGHIDRA_PYTHON") or os.environ.get("GHIDRA_MCP_ANGR_PYTHON") or default_angr_python() options_path = "" - effective_timeout = max(1, min(timeout, TIMEOUT_DECOMPILE_MAX)) try: with tempfile.NamedTemporaryFile("w", suffix="-angryghidra.json", delete=False) as options_file: json.dump(options, options_file) options_path = options_file.name - with tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stdout_file, \ - tempfile.TemporaryFile("w+", encoding="utf-8", errors="replace") as stderr_file: - completed = subprocess.run( - [python, script, options_path], - stdout=stdout_file, - stderr=stderr_file, - timeout=effective_timeout, - check=False, - ) - output, output_truncated = read_limited_stream(stdout_file, ANGR_HELPER_OUTPUT_MAX_CHARS) - errors, errors_truncated = read_limited_stream(stderr_file, ANGR_HELPER_OUTPUT_MAX_CHARS) - except FileNotFoundError as e: - return f"Failed to start AngryGhidra: {e}" - except subprocess.TimeoutExpired: - return f"AngryGhidra timed out after {effective_timeout} seconds" + returncode, output, errors, error = run_limited_subprocess( + [python, script, options_path], + "AngryGhidra", + timeout, + ) finally: if options_path: try: @@ -580,19 +575,15 @@ def run_angryghidra_options(options: dict, timeout: int) -> str: except OSError: pass - output = output.strip() - errors = errors.strip() - if output_truncated: - output += truncation_note("AngryGhidra stdout", ANGR_HELPER_OUTPUT_MAX_CHARS) - if errors_truncated: - errors += truncation_note("AngryGhidra stderr", ANGR_HELPER_OUTPUT_MAX_CHARS) - if completed.returncode == 0: + if error: + return error + if returncode == 0: return output if output else "(AngryGhidra returned no solution)" details = output if errors: details = f"{details}\n\nstderr:\n{errors}" if details else f"stderr:\n{errors}" - return f"AngryGhidra failed with exit code {completed.returncode}\n\n{details}".strip() + return f"AngryGhidra failed with exit code {returncode}\n\n{details}".strip() def build_angryghidra_symbolic_options( find_address: str, From 0fde5d9ada22111a88be3e9e50bd2d247aff1d26 Mon Sep 17 00:00:00 2001 From: rustopian <96253492+rustopian@users.noreply.github.com> Date: Sat, 23 May 2026 14:52:33 +0100 Subject: [PATCH 5/6] Require preview token before angr annotations --- bridge_mcp_ghidra.py | 147 ++++++++++++++++++++++++++++++++++++++----- tests/conftest.py | 2 + tests/test_bridge.py | 50 ++++++++++++++- 3 files changed, 182 insertions(+), 17 deletions(-) diff --git a/bridge_mcp_ghidra.py b/bridge_mcp_ghidra.py index 222fe8a6..fa2670a5 100644 --- a/bridge_mcp_ghidra.py +++ b/bridge_mcp_ghidra.py @@ -16,6 +16,8 @@ import logging import subprocess import tempfile +import time +import hashlib from urllib.parse import urljoin from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type @@ -65,6 +67,10 @@ ANGR_MAX_COMPARE_FUNCTIONS = 25 ANGR_MAX_COMMENTS = 100 ANGR_MAX_COMMENT_PREFIX_CHARS = 200 +ANGR_PREVIEW_CONFIRMATION_TTL_SECONDS = 600 +ANGR_PREVIEW_CONFIRMATION_MAX_ENTRIES = 128 + +_angr_annotation_previews: dict[str, float] = {} def get_http_client(): global _http_client @@ -680,6 +686,59 @@ def extract_trace_addresses(output: str) -> list[str]: seen.add(address) return deduped +def annotation_preview_token(payload: dict) -> str: + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:24] + +def prune_annotation_previews(now: float | None = None) -> None: + if now is None: + now = time.time() + expired = [ + token for token, timestamp in _angr_annotation_previews.items() + if now - timestamp > ANGR_PREVIEW_CONFIRMATION_TTL_SECONDS + ] + for token in expired: + _angr_annotation_previews.pop(token, None) + while len(_angr_annotation_previews) > ANGR_PREVIEW_CONFIRMATION_MAX_ENTRIES: + oldest = min(_angr_annotation_previews, key=_angr_annotation_previews.get) + _angr_annotation_previews.pop(oldest, None) + +def remember_annotation_preview(payload: dict) -> str: + now = time.time() + prune_annotation_previews(now) + token = annotation_preview_token(payload) + _angr_annotation_previews[token] = now + return token + +def validate_annotation_preview(payload: dict, preview_token: str) -> str: + prune_annotation_previews() + expected_token = annotation_preview_token(payload) + if not preview_token: + return missing_annotation_preview_message() + if preview_token != expected_token: + return ( + "Refusing to write comments because preview_token does not match this exact " + "annotation request. Re-run the preview with the same arguments you want to " + "apply, then use the preview_token from that response." + ) + if expected_token not in _angr_annotation_previews: + return ( + "Refusing to write comments because no recent preview was found for this exact " + "annotation request. Preview confirmations expire after " + f"{ANGR_PREVIEW_CONFIRMATION_TTL_SECONDS // 60} minutes; re-run the preview " + "and retry with the new preview_token." + ) + _angr_annotation_previews.pop(expected_token, None) + return "" + +def missing_annotation_preview_message() -> str: + return ( + "Refusing to write comments because no matching preview token was provided. " + "First call angr_annotate_symbolic_path with the same arguments and " + "apply=False, overwrite_existing=False. Review the preview, then retry " + "with apply=True, overwrite_existing=True, and the preview_token from that response." + ) + @mcp.tool() def list_methods(offset: int = 0, limit: int = 100) -> list: """ @@ -1037,16 +1096,19 @@ def angr_annotate_symbolic_path( max_comments: int = 100, apply: bool = False, overwrite_existing: bool = False, + preview_token: str = "", timeout: int = 120, max_steps: int = 10000, ) -> str: """ Run a symbolic path search and write path comments into the Ghidra program. - This endpoint previews by default. Set apply=True and - overwrite_existing=True to write comments, because the current Ghidra - comment endpoints replace existing comments. comment_kind may be "disasm", - "decomp", or "both"; comments are applied only when a trace/path is found. + This endpoint previews by default. To write comments, first review a preview + response, then call again with the same arguments, apply=True, + overwrite_existing=True, and the preview_token from that preview. The token + check is required because the current Ghidra comment endpoints replace + existing comments. comment_kind may be "disasm", "decomp", or "both"; + comments are applied only when a trace/path is found. """ if comment_kind not in {"disasm", "decomp", "both"}: return 'comment_kind must be one of: "disasm", "decomp", "both"' @@ -1056,6 +1118,13 @@ def angr_annotate_symbolic_path( max_comments = _max_comments if len(comment_prefix) > ANGR_MAX_COMMENT_PREFIX_CHARS: return f"comment_prefix may contain at most {ANGR_MAX_COMMENT_PREFIX_CHARS} characters" + if apply and not overwrite_existing: + return ( + "Refusing to write comments because Ghidra's comment endpoints replace existing comments. " + "Call with overwrite_existing=True to confirm after reviewing a preview." + ) + if apply and overwrite_existing and not preview_token: + return missing_annotation_preview_message() result = angr_symbolic_find( find_address=find_address, @@ -1086,12 +1155,60 @@ def angr_annotate_symbolic_path( f"{address}: {comment_prefix}: step {index}/{total} toward {normalized_target}" for index, address in enumerate(trace_addresses, start=1) ] + endpoints = [] + if comment_kind in {"disasm", "both"}: + endpoints.append("set_disassembly_comment") + if comment_kind in {"decomp", "both"}: + endpoints.append("set_decompiler_comment") + planned_writes = [ + { + "endpoint": endpoint, + "address": address, + "comment": f"{comment_prefix}: step {index}/{total} toward {normalized_target}", + } + for index, address in enumerate(trace_addresses, start=1) + for endpoint in endpoints + ] + preview_payload = { + "tool": "angr_annotate_symbolic_path", + "request": { + "find_address": normalize_ghidra_address(find_address), + "binary_path": binary_path, + "start_address": normalize_ghidra_address(start_address), + "avoid_addresses": ",".join( + normalize_ghidra_address(address) + for address in avoid_addresses.split(",") + if address.strip() + ), + "pcode_language": pcode_language, + "base_address": normalize_ghidra_address(base_address), + "raw_binary_arch": raw_binary_arch, + "auto_load_libs": auto_load_libs, + "stdin_bytes": stdin_bytes, + "argv_bytes": argv_bytes, + "symbolic_memory_json": symbolic_memory_json, + "memory_json": memory_json, + "registers_json": registers_json, + "engine": engine, + "comment_kind": comment_kind, + "comment_prefix": comment_prefix, + "max_comments": max_comments, + "timeout": max(1, timeout), + "max_steps": max_steps, + }, + "trace_addresses": trace_addresses, + "planned_writes": planned_writes, + } if not apply: + token = remember_annotation_preview(preview_payload) return ( f"{result}\n\n" f"Preview only: {len(trace_addresses)} trace comment(s) would be written. " - "Call with apply=True and overwrite_existing=True to write them.\n" + "Review the preview, then call with the same arguments, apply=True, " + f"overwrite_existing=True, and preview_token=\"{token}\" to write them. " + f"Preview tokens expire after {ANGR_PREVIEW_CONFIRMATION_TTL_SECONDS // 60} minutes.\n" + f"preview_token: {token}\n" + "\n".join(preview) ) if not overwrite_existing: @@ -1100,19 +1217,17 @@ def angr_annotate_symbolic_path( "Refusing to write comments because Ghidra's comment endpoints replace existing comments. " "Call with overwrite_existing=True to confirm." ) - - endpoints = [] - if comment_kind in {"disasm", "both"}: - endpoints.append("set_disassembly_comment") - if comment_kind in {"decomp", "both"}: - endpoints.append("set_decompiler_comment") + preview_error = validate_annotation_preview(preview_payload, preview_token) + if preview_error: + return f"{result}\n\n{preview_error}" writes = [] - for index, address in enumerate(trace_addresses, start=1): - comment = f"{comment_prefix}: step {index}/{total} toward {normalized_target}" - for endpoint in endpoints: - response = safe_post(endpoint, {"address": address, "comment": comment}) - writes.append(f"{endpoint} {address}: {response}") + for planned_write in planned_writes: + response = safe_post( + planned_write["endpoint"], + {"address": planned_write["address"], "comment": planned_write["comment"]}, + ) + writes.append(f"{planned_write['endpoint']} {planned_write['address']}: {response}") return ( f"{result}\n\n" diff --git a/tests/conftest.py b/tests/conftest.py index d233420a..55a2ddfa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,6 +40,7 @@ def _reset_http_client(bridge_module): Per-test: clear the cached singleton so pytest-httpx intercepts cleanly. """ bridge_module._http_client = None + bridge_module._angr_annotation_previews.clear() yield if bridge_module._http_client is not None: try: @@ -47,3 +48,4 @@ def _reset_http_client(bridge_module): except Exception: pass bridge_module._http_client = None + bridge_module._angr_annotation_previews.clear() diff --git a/tests/test_bridge.py b/tests/test_bridge.py index 3e8efbe6..b90d436a 100644 --- a/tests/test_bridge.py +++ b/tests/test_bridge.py @@ -22,6 +22,13 @@ def _url(endpoint: str) -> str: return BASE_URL + endpoint +def _preview_token(output: str) -> str: + for line in output.splitlines(): + if line.startswith("preview_token: "): + return line.split(": ", 1)[1] + raise AssertionError("preview_token not found") + + # --------------------------------------------------------------------------- # Listing endpoints (GET ?offset=&limit=) # --------------------------------------------------------------------------- @@ -373,6 +380,7 @@ def test_angr_annotate_symbolic_path_previews_by_default( out = bridge_module.angr_annotate_symbolic_path(find_address="0x401020") assert "Preview only: 2 trace comment(s) would be written" in out + assert "preview_token: " in out assert "0x401000: angr symbolic path: step 1/2 toward 0x401020" in out def test_angr_annotate_symbolic_path_requires_overwrite_confirmation( @@ -389,6 +397,41 @@ def test_angr_annotate_symbolic_path_requires_overwrite_confirmation( assert "Refusing to write comments" in out assert "overwrite_existing=True" in out + def test_angr_annotate_symbolic_path_requires_matching_preview_token( + self, bridge_module, monkeypatch): + monkeypatch.setattr( + bridge_module, + "angr_symbolic_find", + lambda **_kwargs: "engine: AngryGhidra\nt:0x401000") + + out = bridge_module.angr_annotate_symbolic_path( + find_address="0x401020", + apply=True, + overwrite_existing=True) + + assert "no matching preview token" in out + assert "First call angr_annotate_symbolic_path" in out + + def test_angr_annotate_symbolic_path_rejects_changed_call_after_preview( + self, bridge_module, monkeypatch): + monkeypatch.setattr( + bridge_module, + "angr_symbolic_find", + lambda **_kwargs: "engine: AngryGhidra\nt:0x401000") + preview = bridge_module.angr_annotate_symbolic_path( + find_address="0x401020", + comment_prefix="previewed path") + token = _preview_token(preview) + + out = bridge_module.angr_annotate_symbolic_path( + find_address="0x401020", + comment_prefix="changed path", + apply=True, + overwrite_existing=True, + preview_token=token) + + assert "preview_token does not match this exact annotation request" in out + def test_angr_annotate_symbolic_path_writes_trace_comments_with_confirmation( self, bridge_module, httpx_mock, monkeypatch): monkeypatch.setattr( @@ -412,11 +455,16 @@ def test_angr_annotate_symbolic_path_writes_trace_comments_with_confirmation( ), text="Comment set successfully") + preview = bridge_module.angr_annotate_symbolic_path( + find_address="0x401020", + comment_kind="disasm") + token = _preview_token(preview) out = bridge_module.angr_annotate_symbolic_path( find_address="0x401020", comment_kind="disasm", apply=True, - overwrite_existing=True) + overwrite_existing=True, + preview_token=token) assert "Annotated 2 trace address(es)" in out assert "set_disassembly_comment 0x401000: Comment set successfully" in out From 3220abd2fd46bebee238434a24a8124dff907fc4 Mon Sep 17 00:00:00 2001 From: rustopian <96253492+rustopian@users.noreply.github.com> Date: Sat, 23 May 2026 14:58:16 +0100 Subject: [PATCH 6/6] Show overwritten comments in angr previews --- README.md | 10 +- bridge_mcp_ghidra.py | 98 ++++++++++++++++++- .../java/com/lauriewired/GhidraMCPPlugin.java | 91 +++++++++++++++++ tests/test_bridge.py | 81 ++++++++++++++- 4 files changed, 268 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 194acdac..02e0ef62 100644 --- a/README.md +++ b/README.md @@ -60,10 +60,12 @@ tools. reachability; `angr_cfg_summary` and `angr_callgraph_summary` summarize recovered graph structure; `angr_lift_block` lifts a block to VEX/AIL; and `angr_compare_decompilers` batches Ghidra-vs-Oxidizer decompiler output. -- `angr_annotate_symbolic_path` previews by default. To write the recovered - trace as Ghidra disassembly and/or decompiler comments, call it with - `apply=true` and `overwrite_existing=true`; the underlying Ghidra comment - endpoints replace existing comments. +- `angr_annotate_symbolic_path` previews by default and shows the current + comment that each planned annotation would overwrite alongside the pending + comment. To write the recovered trace as Ghidra disassembly and/or decompiler + comments, call it again with the same arguments, `apply=true`, + `overwrite_existing=true`, and the preview token from the reviewed dry run; + the underlying Ghidra comment endpoints replace existing comments. - angr/AngryGhidra execution is bounded by conservative limits on helper output, symbolic input sizes, symbolic steps, summary output, lift size, and batch comparison size. diff --git a/bridge_mcp_ghidra.py b/bridge_mcp_ghidra.py index fa2670a5..cb34b724 100644 --- a/bridge_mcp_ghidra.py +++ b/bridge_mcp_ghidra.py @@ -106,6 +106,36 @@ def safe_get(endpoint: str, params: dict = None, timeout: float = 30.0) -> list: except Exception as e: return [f"Request failed: {str(e)}"] +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=10), + retry=retry_if_exception_type((httpx.ConnectError, httpx.ConnectTimeout)), + reraise=True, +) +def safe_get_json(endpoint: str, params: dict = None, timeout: float = 30.0) -> dict: + """ + Perform a GET request and parse a JSON object response. + """ + if params is None: + params = {} + + url = urljoin(ghidra_server_url, endpoint) + + try: + response = get_http_client().get(url, params=params, timeout=timeout) + response.encoding = 'utf-8' + if response.status_code != 200: + return {"error": f"Error {response.status_code}: {response.text.strip()}"} + try: + parsed = response.json() + except ValueError as e: + return {"error": f"Invalid JSON response from {endpoint}: {e}"} + if not isinstance(parsed, dict): + return {"error": f"Invalid JSON response from {endpoint}: expected object"} + return parsed + except Exception as e: + return {"error": f"Request failed: {str(e)}"} + @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10), @@ -739,6 +769,48 @@ def missing_annotation_preview_message() -> str: "with apply=True, overwrite_existing=True, and the preview_token from that response." ) +def comment_kind_for_write_endpoint(endpoint: str) -> str: + if endpoint == "set_disassembly_comment": + return "disasm" + if endpoint == "set_decompiler_comment": + return "decomp" + return "" + +def read_current_comment(address: str, endpoint: str) -> tuple[str, str]: + kind = comment_kind_for_write_endpoint(endpoint) + if not kind: + return "", f"Cannot read current comment for unsupported endpoint {endpoint!r}" + response = safe_get_json("get_comment", {"address": address, "kind": kind}, timeout=10.0) + error = response.get("error") + if error: + return "", str(error) + return str(response.get("comment", "")), "" + +def attach_current_comments(planned_writes: list[dict]) -> str: + for planned_write in planned_writes: + current_comment, error = read_current_comment( + planned_write["address"], + planned_write["endpoint"], + ) + if error: + return error + planned_write["current_comment"] = current_comment + return "" + +def format_preview_comment(value: str) -> str: + if not value: + return "" + return value.replace("\n", "\n ") + +def format_planned_comment_preview(planned_writes: list[dict]) -> str: + lines = ["Planned comment writes:"] + for planned_write in planned_writes: + kind = comment_kind_for_write_endpoint(planned_write["endpoint"]) or planned_write["endpoint"] + lines.append(f"- {kind} {planned_write['address']}") + lines.append(f" current: {format_preview_comment(planned_write.get('current_comment', ''))}") + lines.append(f" pending: {format_preview_comment(planned_write['comment'])}") + return "\n".join(lines) + @mcp.tool() def list_methods(offset: int = 0, limit: int = 100) -> list: """ @@ -1151,10 +1223,6 @@ def angr_annotate_symbolic_path( total = len(trace_addresses) normalized_target = normalize_ghidra_address(find_address) - preview = [ - f"{address}: {comment_prefix}: step {index}/{total} toward {normalized_target}" - for index, address in enumerate(trace_addresses, start=1) - ] endpoints = [] if comment_kind in {"disasm", "both"}: endpoints.append("set_disassembly_comment") @@ -1169,6 +1237,14 @@ def angr_annotate_symbolic_path( for index, address in enumerate(trace_addresses, start=1) for endpoint in endpoints ] + current_comment_error = attach_current_comments(planned_writes) + if current_comment_error: + return ( + f"{result}\n\n" + "Refusing to create an annotation preview token because the current " + "comments that would be overwritten could not be read. Update and " + f"restart the GhidraMCP extension, then retry. Details: {current_comment_error}" + ) preview_payload = { "tool": "angr_annotate_symbolic_path", "request": { @@ -1209,7 +1285,7 @@ def angr_annotate_symbolic_path( f"overwrite_existing=True, and preview_token=\"{token}\" to write them. " f"Preview tokens expire after {ANGR_PREVIEW_CONFIRMATION_TTL_SECONDS // 60} minutes.\n" f"preview_token: {token}\n" - + "\n".join(preview) + + format_planned_comment_preview(planned_writes) ) if not overwrite_existing: return ( @@ -1682,6 +1758,18 @@ def disassemble_function(address: str) -> list: """ return safe_get("disassemble_function", {"address": address}) +@mcp.tool() +def get_comment(address: str, kind: str = "disasm") -> dict: + """ + Read the current comment that set_disasm_comment or set_decomp_comment would replace. + + kind must be "disasm" for the disassembly EOL comment or "decomp" for the + decompiler/pre comment. + """ + if kind not in {"disasm", "decomp"}: + return {"error": 'kind must be "disasm" or "decomp"'} + return safe_get_json("get_comment", {"address": address, "kind": kind}) + @mcp.tool() def set_decomp_comment(address: str, comment: str) -> str: """ diff --git a/src/main/java/com/lauriewired/GhidraMCPPlugin.java b/src/main/java/com/lauriewired/GhidraMCPPlugin.java index ff2b1a2a..5928932c 100644 --- a/src/main/java/com/lauriewired/GhidraMCPPlugin.java +++ b/src/main/java/com/lauriewired/GhidraMCPPlugin.java @@ -64,6 +64,7 @@ import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; @PluginInfo( status = PluginStatus.RELEASED, @@ -400,6 +401,33 @@ private void startServer() throws IOException { sendResponse(exchange, disassembleFunction(address)); }); + server.createContext("/get_comment", exchange -> { + Map qparams = parseQueryParams(exchange); + String address = qparams.get("address"); + String kind = qparams.getOrDefault("kind", "disasm"); + Integer commentType = commentTypeForKind(kind); + if (address == null || address.isEmpty()) { + sendJsonResponse(exchange, 400, "{\"error\":\"address is required\"}"); + return; + } + if (commentType == null) { + sendJsonResponse(exchange, 400, "{\"error\":\"kind must be disasm or decomp\"}"); + return; + } + + CommentLookupResult result = getCommentAtAddress(address, commentType); + if (result.errorMessage != null) { + sendJsonResponse(exchange, 400, "{\"error\":\"" + jsonEscape(result.errorMessage) + "\"}"); + return; + } + + String json = "{\"address\":\"" + jsonEscape(address) + "\"," + + "\"kind\":\"" + jsonEscape(kind) + "\"," + + "\"exists\":" + (result.comment != null) + "," + + "\"comment\":\"" + jsonEscape(result.comment == null ? "" : result.comment) + "\"}"; + sendJsonResponse(exchange, json); + }); + server.createContext("/set_decompiler_comment", exchange -> { Map params = parsePostParams(exchange); String address = params.get("address"); @@ -1394,6 +1422,61 @@ private boolean setDisassemblyComment(String addressStr, String comment) { return setCommentAtAddress(addressStr, comment, CodeUnit.EOL_COMMENT, "Set disassembly comment"); } + private Integer commentTypeForKind(String kind) { + if ("disasm".equals(kind)) { + return CodeUnit.EOL_COMMENT; + } + if ("decomp".equals(kind)) { + return CodeUnit.PRE_COMMENT; + } + return null; + } + + private static class CommentLookupResult { + private final String comment; + private final String errorMessage; + + CommentLookupResult(String comment, String errorMessage) { + this.comment = comment; + this.errorMessage = errorMessage; + } + } + + /** + * Read a comment that a set_*_comment endpoint would replace. + */ + private CommentLookupResult getCommentAtAddress(String addressStr, int commentType) { + Program program = getCurrentProgram(); + if (program == null) { + return new CommentLookupResult(null, "No current program"); + } + if (addressStr == null || addressStr.isEmpty()) { + return new CommentLookupResult(null, "address is required"); + } + + AtomicReference comment = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + try { + SwingUtilities.invokeAndWait(() -> { + try { + Address addr = program.getAddressFactory().getAddress(addressStr); + if (addr == null) { + error.set("Invalid address: " + addressStr); + return; + } + comment.set(program.getListing().getComment(commentType, addr)); + } catch (Exception e) { + error.set(e.getMessage()); + Msg.error(this, "Error getting comment", e); + } + }); + } catch (InterruptedException | InvocationTargetException e) { + error.set(e.getMessage()); + Msg.error(this, "Failed to execute get comment on Swing thread", e); + } + return new CommentLookupResult(comment.get(), error.get()); + } + /** * Class to hold the result of a prototype setting operation */ @@ -2809,6 +2892,14 @@ private void sendResponse(HttpExchange exchange, String response) throws IOExcep os.write(bytes); } } + + private static String jsonEscape(String s) { + if (s == null) return ""; + return s.replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r"); + } private void sendJsonResponse(HttpExchange exchange, int code, String json) throws IOException { byte[] bytes = json.getBytes(StandardCharsets.UTF_8); diff --git a/tests/test_bridge.py b/tests/test_bridge.py index b90d436a..5acc6bae 100644 --- a/tests/test_bridge.py +++ b/tests/test_bridge.py @@ -29,6 +29,19 @@ def _preview_token(output: str) -> str: raise AssertionError("preview_token not found") +def _mock_comment(httpx_mock, address: str, kind: str = "disasm", comment: str = "", count: int = 1): + for _ in range(count): + httpx_mock.add_response( + url=_url(f"get_comment?address={address}&kind={kind}"), + json={ + "address": address, + "kind": kind, + "exists": bool(comment), + "comment": comment, + }, + ) + + # --------------------------------------------------------------------------- # Listing endpoints (GET ?offset=&limit=) # --------------------------------------------------------------------------- @@ -371,17 +384,23 @@ def test_angr_symbolic_find_forced_angryghidra_missing_is_clear( assert "AngryGhidra is not installed or configured" in out def test_angr_annotate_symbolic_path_previews_by_default( - self, bridge_module, monkeypatch): + self, bridge_module, httpx_mock, monkeypatch): monkeypatch.setattr( bridge_module, "angr_symbolic_find", lambda **_kwargs: "engine: AngryGhidra\nt:0x401000\nt:0x401020") + _mock_comment(httpx_mock, "0x401000", comment="existing branch note") + _mock_comment(httpx_mock, "0x401020") out = bridge_module.angr_annotate_symbolic_path(find_address="0x401020") assert "Preview only: 2 trace comment(s) would be written" in out assert "preview_token: " in out - assert "0x401000: angr symbolic path: step 1/2 toward 0x401020" in out + assert "- disasm 0x401000" in out + assert "current: existing branch note" in out + assert "pending: angr symbolic path: step 1/2 toward 0x401020" in out + assert "- disasm 0x401020" in out + assert "current: " in out def test_angr_annotate_symbolic_path_requires_overwrite_confirmation( self, bridge_module, monkeypatch): @@ -413,11 +432,12 @@ def test_angr_annotate_symbolic_path_requires_matching_preview_token( assert "First call angr_annotate_symbolic_path" in out def test_angr_annotate_symbolic_path_rejects_changed_call_after_preview( - self, bridge_module, monkeypatch): + self, bridge_module, httpx_mock, monkeypatch): monkeypatch.setattr( bridge_module, "angr_symbolic_find", lambda **_kwargs: "engine: AngryGhidra\nt:0x401000") + _mock_comment(httpx_mock, "0x401000", count=2) preview = bridge_module.angr_annotate_symbolic_path( find_address="0x401020", comment_prefix="previewed path") @@ -432,12 +452,49 @@ def test_angr_annotate_symbolic_path_rejects_changed_call_after_preview( assert "preview_token does not match this exact annotation request" in out + def test_angr_annotate_symbolic_path_rejects_changed_current_comment( + self, bridge_module, httpx_mock, monkeypatch): + monkeypatch.setattr( + bridge_module, + "angr_symbolic_find", + lambda **_kwargs: "engine: AngryGhidra\nt:0x401000") + _mock_comment(httpx_mock, "0x401000", comment="old") + _mock_comment(httpx_mock, "0x401000", comment="changed") + preview = bridge_module.angr_annotate_symbolic_path(find_address="0x401020") + token = _preview_token(preview) + + out = bridge_module.angr_annotate_symbolic_path( + find_address="0x401020", + apply=True, + overwrite_existing=True, + preview_token=token) + + assert "preview_token does not match this exact annotation request" in out + + def test_angr_annotate_symbolic_path_requires_comment_read_before_token( + self, bridge_module, httpx_mock, monkeypatch): + monkeypatch.setattr( + bridge_module, + "angr_symbolic_find", + lambda **_kwargs: "engine: AngryGhidra\nt:0x401000") + httpx_mock.add_response( + url=_url("get_comment?address=0x401000&kind=disasm"), + status_code=404, + text="missing") + + out = bridge_module.angr_annotate_symbolic_path(find_address="0x401020") + + assert "could not be read" in out + assert "preview_token:" not in out + def test_angr_annotate_symbolic_path_writes_trace_comments_with_confirmation( self, bridge_module, httpx_mock, monkeypatch): monkeypatch.setattr( bridge_module, "angr_symbolic_find", lambda **_kwargs: "engine: AngryGhidra\nt:0x401000\nt:0x401020") + _mock_comment(httpx_mock, "0x401000", comment="existing branch note", count=2) + _mock_comment(httpx_mock, "0x401020", count=2) httpx_mock.add_response( method="POST", url=_url("set_disassembly_comment"), @@ -785,6 +842,24 @@ def test_set_decomp_comment(self, bridge_module, httpx_mock): text="Comment set successfully") bridge_module.set_decomp_comment("0x120", "hello") + def test_get_comment(self, bridge_module, httpx_mock): + httpx_mock.add_response( + url=_url("get_comment?address=0x120&kind=decomp"), + json={ + "address": "0x120", + "kind": "decomp", + "exists": True, + "comment": "existing note", + }, + ) + + assert bridge_module.get_comment("0x120", "decomp") == { + "address": "0x120", + "kind": "decomp", + "exists": True, + "comment": "existing note", + } + def test_set_disasm_comment(self, bridge_module, httpx_mock): httpx_mock.add_response( method="POST", url=_url("set_disassembly_comment"),