From 6439a2a8b1af5d74e65794507015a3bb083aa609 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Tue, 5 May 2026 17:31:36 -0700 Subject: [PATCH 1/3] Generic numeric debugging Differential Revision: D103956056 --- devtools/inspector/_inspector.py | 126 +++++++++ devtools/intermediate_output_tap/TARGETS | 85 ++++++ devtools/intermediate_output_tap/__init__.py | 99 +++++++ .../intermediate_output_tap/_convenience.py | 221 ++++++++++++++++ devtools/intermediate_output_tap/_reducers.py | 179 +++++++++++++ .../intermediate_output_tap/_selectors.py | 132 ++++++++++ devtools/intermediate_output_tap/_spec.py | 60 +++++ .../intermediate_output_tap/_strip_pass.py | 164 ++++++++++++ devtools/intermediate_output_tap/_tap_pass.py | 249 ++++++++++++++++++ .../intermediate_output_tap/custom_ops_lib.py | 43 +++ .../intermediate_output_tap/tests/TARGETS | 74 ++++++ .../tests/test_inspector_integration.py | 146 ++++++++++ .../tests/test_reducers.py | 51 ++++ .../tests/test_selectors.py | 121 +++++++++ .../tests/test_strip_pass.py | 118 +++++++++ .../tests/test_tap_pass.py | 159 +++++++++++ .../tests/test_xnnpack_e2e.py | 127 +++++++++ 17 files changed, 2154 insertions(+) create mode 100644 devtools/intermediate_output_tap/TARGETS create mode 100644 devtools/intermediate_output_tap/__init__.py create mode 100644 devtools/intermediate_output_tap/_convenience.py create mode 100644 devtools/intermediate_output_tap/_reducers.py create mode 100644 devtools/intermediate_output_tap/_selectors.py create mode 100644 devtools/intermediate_output_tap/_spec.py create mode 100644 devtools/intermediate_output_tap/_strip_pass.py create mode 100644 devtools/intermediate_output_tap/_tap_pass.py create mode 100644 devtools/intermediate_output_tap/custom_ops_lib.py create mode 100644 devtools/intermediate_output_tap/tests/TARGETS create mode 100644 devtools/intermediate_output_tap/tests/test_inspector_integration.py create mode 100644 devtools/intermediate_output_tap/tests/test_reducers.py create mode 100644 devtools/intermediate_output_tap/tests/test_selectors.py create mode 100644 devtools/intermediate_output_tap/tests/test_strip_pass.py create mode 100644 devtools/intermediate_output_tap/tests/test_tap_pass.py create mode 100644 devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index e9fbc4778f5..969906a7c10 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1649,3 +1649,129 @@ def get_stacktraces_for_row(aot_ops: List[str]) -> Dict[str, Optional[str]]: df["stacktraces"] = df["aot_ops"].apply(get_stacktraces_for_row) return df + + def calculate_numeric_gap_from_taps( + self, + flat_runtime_outputs: Sequence, + tap_specs: Sequence, + distance: Union[str, NumericalComparatorBase] = "MSE", + reference_graph: Optional[str] = None, + disable_debug_handle_valdiation: bool = False, + ) -> pd.DataFrame: + """ + Compares AOT intermediate outputs (from the ETRecord, captured by + IntermediateOutputCapturer) with runtime tap values exposed as + USER_OUTPUTs by the `intermediate_output_tap` package. + + Unlike `calculate_numeric_gap`, this method works through delegates + with no backend-side support: the runtime values come from extra + outputs the AOT pass added to the ExportedProgram before lowering. + + IMPORTANT: ETRecord serialization regenerates `debug_handle`s during + roundtrip, so the handles in `tap_specs` (set at AOT-pass time) are + stale. 
Each spec's `reducer_node_name` (set by + `strip_taps_(edge, tap_specs=specs)`) is used to look up the + post-roundtrip `debug_handle` in the AOT reference graph. + + Args: + flat_runtime_outputs: The flat output tuple/list returned by + running the lowered program (e.g. `Method.execute(inputs)`). + tap_specs: The list of TapSpec returned by + `strip_taps_(edge, tap_specs=specs)` — these carry + `reducer_node_name` for alignment. + distance: "MSE", "L1", "SNR", or a `NumericalComparatorBase`. + reference_graph: AOT graph to use as the golden. See + `calculate_numeric_gap` for valid values. + disable_debug_handle_valdiation: Bypass debug handle validation. + + Returns: + DataFrame with one row per (aot_handle, runtime_handle) pair. + Same shape produced by `calculate_numeric_gap`. + """ + reference_graph_module, _resolved_graph_name = self._resolve_reference_graph( + reference_graph, + disable_debug_handle_valdiation, + ) + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( + self._get_aot_intermediate_outputs_and_op_names(reference_graph_module) + ) + if len(aot_intermediate_outputs) == 0: + raise ValueError( + "No AOT intermediate outputs were captured. ETRecord must be " + "provided with representative_inputs for tap-based comparison." + ) + + spec_handles = _lookup_handles_by_name(reference_graph_module, tap_specs) + + runtime_intermediate_outputs: Dict[DebugHandle, Tuple[Any, int]] = {} + runtime_debug_handle_to_op_names: Dict[DebugHandle, List[str]] = {} + for spec, dh in zip(tap_specs, spec_handles): + if dh is None: + continue + key: DebugHandle = (int(dh),) + runtime_intermediate_outputs[key] = ( + flat_runtime_outputs[spec.output_index], + 1, + ) + runtime_debug_handle_to_op_names[key] = [spec.op_target] + + if not runtime_intermediate_outputs: + raise ValueError( + "Could not recover any post-roundtrip handles for tap_specs. " + "Verify that strip_taps_(edge, tap_specs=specs) was called and " + "the returned specs were passed to this method, and that " + "generate_etrecord ran AFTER strip_taps_." + ) + + mapping = map_runtime_aot_intermediate_outputs( + aot_intermediate_outputs, runtime_intermediate_outputs + ) + + if isinstance(distance, NumericalComparatorBase): + comparator = distance + if comparator.inspector is None: + comparator.inspector = self + else: + metric = distance.strip().upper() + if metric == "MSE": + comparator = MSEComparator(inspector=self) + elif metric == "L1": + comparator = L1Comparator(inspector=self) + elif metric == "SNR": + comparator = SNRComparator(inspector=self) + else: + raise ValueError(f"Unsupported distance metric {distance!r}") + + return comparator.compare( + mapping, + aot_debug_handle_to_op_names, + runtime_debug_handle_to_op_names, + ) + + +def _lookup_handles_by_name( + reference_graph_module, + tap_specs: Sequence, +) -> List[Optional[int]]: + """ + For each TapSpec, return the post-roundtrip `debug_handle` of the FX node + whose name equals `spec.reducer_node_name`. Returns None for any spec + without a `reducer_node_name` set or whose name is not found. + + `reducer_node_name` is set by `strip_taps_(edge, tap_specs=specs)`. FX + node names survive ETRecord serialization roundtrip, so this lookup is + stable. 
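+
+    A hypothetical sketch of the mapping this performs (node names and
+    handle values made up for illustration):
+
+        tap_specs[i].reducer_node_name   ->  "aten_stack_default"
+        graph node "aten_stack_default"  ->  meta["debug_handle"] == 7
+        returned[i]                      ->  7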
+ """ + name_to_handle: Dict[str, Optional[int]] = {} + for n in reference_graph_module.graph.nodes: + h = n.meta.get("debug_handle") + name_to_handle[n.name] = int(h) if isinstance(h, int) else None + + out: List[Optional[int]] = [] + for spec in tap_specs: + rn = getattr(spec, "reducer_node_name", None) + if rn is None: + out.append(None) + continue + out.append(name_to_handle.get(rn)) + return out diff --git a/devtools/intermediate_output_tap/TARGETS b/devtools/intermediate_output_tap/TARGETS new file mode 100644 index 00000000000..5ae2ddc8380 --- /dev/null +++ b/devtools/intermediate_output_tap/TARGETS @@ -0,0 +1,85 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "spec", + srcs = ["_spec.py"], +) + +runtime.python_library( + name = "custom_ops_lib", + srcs = ["custom_ops_lib.py"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "selectors", + srcs = ["_selectors.py"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "reducers", + srcs = ["_reducers.py"], + deps = [ + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) + +runtime.python_library( + name = "tap_pass", + srcs = ["_tap_pass.py"], + deps = [ + "//caffe2:torch", + ":custom_ops_lib", + ":reducers", + ":selectors", + ":spec", + ], +) + +runtime.python_library( + name = "strip_pass", + srcs = ["_strip_pass.py"], + deps = [ + "//caffe2:torch", + ":reducers", + ":tap_pass", + ], +) + +runtime.python_library( + name = "convenience", + srcs = ["_convenience.py"], + deps = [ + "fbsource//third-party/pypi/pandas:pandas", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ":reducers", + ":selectors", + ":spec", + ":strip_pass", + ":tap_pass", + ], +) + +runtime.python_library( + name = "lib", + srcs = ["__init__.py"], + deps = [ + ":convenience", + ":custom_ops_lib", + ":reducers", + ":selectors", + ":spec", + ":strip_pass", + ":tap_pass", + ], +) diff --git a/devtools/intermediate_output_tap/__init__.py b/devtools/intermediate_output_tap/__init__.py new file mode 100644 index 00000000000..c1f85509409 --- /dev/null +++ b/devtools/intermediate_output_tap/__init__.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Public API for the ExecuTorch numerical debugger. 
+ +Backend-agnostic intermediate-value tap that complements the existing +Inspector framework: + +- AOT side : `IntermediateOutputCapturer` (existing) +- Runtime side : ETDump intermediate output events (existing, opaque inside delegates) +- Runtime side : USER_OUTPUT taps (this module — works through delegates without + any backend-side changes) + +Typical usage: + + from executorch.devtools.intermediate_output_tap import ( + tap_intermediate_outputs, strip_taps_, DEFAULT_STATS, + ) + + ep = export(model, example_inputs) + ep_tapped, specs = tap_intermediate_outputs(ep, reducer=DEFAULT_STATS) + edge = to_edge_transform_and_lower(ep_tapped, partitioner=[XnnpackPartitioner()]) + strip_taps_(edge) + et_program = edge.to_executorch() + + flat_outputs = runtime.forward(*example_inputs) + df = inspector.calculate_numeric_gap_from_taps(flat_outputs, specs) +""" + +from executorch.devtools.intermediate_output_tap import ( + custom_ops_lib, # noqa: F401 ensures torch.ops.executorch_devtools.tap is registered +) +from executorch.devtools.intermediate_output_tap._convenience import ( + format_tap_dataframe, + specs_to_dataframe, + tap_all_and_run, +) +from executorch.devtools.intermediate_output_tap._reducers import ( + ABS_MAX_ONLY, + DEFAULT_STATS, + FULL_TENSOR, + get_reducer, + MIN_MAX_MEAN, + StatReducer, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all, + select_all_call_function, + select_any, + select_by_meta_tag, + select_by_module_path, + select_by_op_type, + select_not, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + find_tap_nodes, + is_tap_node, + tap_intermediate_outputs, +) + + +__all__ = [ + # Core API + "tap_intermediate_outputs", + "strip_taps_", + "TapSpec", + # Convenience + "tap_all_and_run", + "specs_to_dataframe", + "format_tap_dataframe", + # Reducers + "StatReducer", + "FULL_TENSOR", + "ABS_MAX_ONLY", + "MIN_MAX_MEAN", + "DEFAULT_STATS", + "get_reducer", + # Selectors + "NodeSelector", + "select_all_call_function", + "select_by_op_type", + "select_by_module_path", + "select_by_meta_tag", + "select_any", + "select_all", + "select_not", + # Helpers + "find_tap_nodes", + "is_tap_node", +] diff --git a/devtools/intermediate_output_tap/_convenience.py b/devtools/intermediate_output_tap/_convenience.py new file mode 100644 index 00000000000..2d4481b9bb2 --- /dev/null +++ b/devtools/intermediate_output_tap/_convenience.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +One-line convenience wrapper for the most common smoke-test workflow: + + df = tap_all_and_run(model, example_inputs, partitioner=[XnnpackPartitioner()]) + +Exports `model`, taps every call_function, lowers with the user's partitioner, +runs through the ExecuTorch runtime, and returns a pandas DataFrame of one row +per tap (one column per stat field). No Inspector setup, no ETRecord. For +AOT-vs-runtime numerical comparison, use Inspector.calculate_numeric_gap_from_taps, +then `format_tap_dataframe(df, tap_specs)` to get a friendly view. 
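+
+Example (a sketch; assumes `model` and `example_inputs` are already defined
+and the graph lowers cleanly without a partitioner):
+
+    df = tap_all_and_run(model, example_inputs, reducer="MIN_MAX_MEAN")
+    print(df[["node_name", "op_target", "min", "max", "mean"]])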
+""" + +from __future__ import annotations + +import os +import tempfile +from collections.abc import Sequence +from typing import Any + +import pandas as pd +import torch +from executorch.devtools.intermediate_output_tap._reducers import StatReducer +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all_call_function, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + tap_intermediate_outputs, +) + + +def tap_all_and_run( + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + partitioner: list | None = None, + reducer: str | StatReducer = "DEFAULT_STATS", + selector: NodeSelector | None = None, + skip_if_no_debug_handle: bool = True, +) -> pd.DataFrame: + """ + Export -> tap -> lower -> strip -> to_executorch -> run -> DataFrame. + + Returns a DataFrame indexed by tap with columns: + node_name, op_target, debug_handle, output_index, reducer_name, plus + one column per reducer field (or `value` for FULL_TENSOR). + """ + from executorch.exir import to_edge_transform_and_lower + + selector = selector or select_all_call_function() + ep = torch.export.export(model, example_inputs, strict=True) + ep_tapped, specs = tap_intermediate_outputs( + ep, + selector=selector, + reducer=reducer, + skip_if_no_debug_handle=skip_if_no_debug_handle, + ) + edge = to_edge_transform_and_lower( + ep_tapped, partitioner=partitioner or [] + ) + strip_taps_(edge) + et_program = edge.to_executorch() + + flat_outputs = _run_pte(et_program, example_inputs) + return specs_to_dataframe(specs, flat_outputs) + + +def _run_pte(et_program, example_inputs: tuple[Any, ...]) -> Sequence[Any]: + from executorch.runtime import Runtime, Verification + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + return method.execute(example_inputs) + + +def specs_to_dataframe( + specs: Sequence[TapSpec], + flat_outputs: Sequence[Any], +) -> pd.DataFrame: + """Build a per-tap DataFrame from the tap_specs + flat output tuple.""" + rows = [] + for spec in specs: + runtime_value = flat_outputs[spec.output_index] + row: dict[str, Any] = { + "node_name": spec.node_name, + "op_target": spec.op_target, + "debug_handle": spec.debug_handle, + "output_index": spec.output_index, + "reducer_name": spec.reducer_name, + } + if spec.fields: + tensor_vals = ( + runtime_value.detach().cpu().tolist() + if isinstance(runtime_value, torch.Tensor) + else list(runtime_value) + ) + for i, field in enumerate(spec.fields): + row[field] = tensor_vals[i] if i < len(tensor_vals) else None + else: + row["value"] = runtime_value + rows.append(row) + return pd.DataFrame(rows) + + +def format_tap_dataframe( + df: pd.DataFrame, + tap_specs: Sequence[TapSpec], +) -> pd.DataFrame: + """ + Reshape the raw DataFrame returned by + `Inspector.calculate_numeric_gap_from_taps` into a friendlier per-tap, + per-field view. + + The raw DataFrame uses the existing Inspector comparator format, which + packs the reducer's stat tensor into a list of 0-d tensors and labels + rows by the post-strip reducer node name (e.g. `aten_stack_default`). 
+ This helper: + - matches each raw row to a TapSpec (by reducer_node_name) + - renames `aot_ops`/`runtime_ops` columns to a single `node_name` (the + original source node name, e.g. `linear`, `linear_1`) + - expands the reducer stat tensor into one column per field + (e.g. `aot_min`, `rt_min`, `aot_max`, `rt_max`, ...) + - flattens the gap to a single float + - drops the verbose `aot_intermediate_output` / `runtime_intermediate_output` + list columns + + Returns a DataFrame with columns: + node_name, op_target, reducer_name, + gap, + aot_, rt_, aot_, rt_, ... + """ + # Map reducer_node_name -> spec for quick lookup. + name_to_spec: dict[str, TapSpec] = { + s.reducer_node_name: s + for s in tap_specs + if s.reducer_node_name is not None + } + + rows = [] + for _, row in df.iterrows(): + aot_ops = row.get("aot_ops", []) + spec = None + for op in aot_ops or []: + if op in name_to_spec: + spec = name_to_spec[op] + break + if spec is None: + # Couldn't match — keep a thin row with whatever we have. + rows.append( + { + "node_name": ",".join(aot_ops or []), + "op_target": "?", + "reducer_name": "?", + "gap": _flatten_gap(row.get("gap")), + } + ) + continue + + new_row: dict[str, Any] = { + "node_name": spec.node_name, + "op_target": spec.op_target, + "reducer_name": spec.reducer_name, + "gap": _flatten_gap(row.get("gap")), + } + aot_vals = _to_float_list(row.get("aot_intermediate_output")) + rt_vals = _to_float_list(row.get("runtime_intermediate_output")) + for i, field in enumerate(spec.fields): + new_row[f"aot_{field}"] = aot_vals[i] if i < len(aot_vals) else None + new_row[f"rt_{field}"] = rt_vals[i] if i < len(rt_vals) else None + if not spec.fields: # FULL_TENSOR + new_row["aot_value"] = row.get("aot_intermediate_output") + new_row["rt_value"] = row.get("runtime_intermediate_output") + rows.append(new_row) + return pd.DataFrame(rows) + + +def _flatten_gap(g: Any) -> float | None: + if g is None: + return None + if isinstance(g, list): + if not g: + return None + g = g[0] + if isinstance(g, torch.Tensor): + return float(g) + try: + return float(g) + except (TypeError, ValueError): + return None + + +def _to_float_list(v: Any) -> list[float]: + if v is None: + return [] + if isinstance(v, torch.Tensor): + return v.detach().cpu().tolist() + if isinstance(v, list): + out: list[float] = [] + for x in v: + if isinstance(x, torch.Tensor): + out.append(float(x)) + else: + try: + out.append(float(x)) + except (TypeError, ValueError): + out.append(float("nan")) + return out + return [] diff --git a/devtools/intermediate_output_tap/_reducers.py b/devtools/intermediate_output_tap/_reducers.py new file mode 100644 index 00000000000..45a4f62578e --- /dev/null +++ b/devtools/intermediate_output_tap/_reducers.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Stat reducers used by `tap_intermediate_outputs`. + +A `StatReducer` is a small specification consumed by `strip_taps_` (after +`to_backend`) to materialise a portable reducer subgraph in place of the +`executorch_devtools::tap.Tensor` placeholder. + +`emit(graph, src_node) -> fx.Node` builds the reducer subgraph in `graph` +just before the placeholder, using the source tensor `src_node` as input, +and returns the final node whose output replaces the placeholder's output. 
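+
+For illustration, a custom reducer (hypothetical `MEAN_ONLY`, not one of the
+built-ins) could be defined as:
+
+    def _emit_mean(graph, src):
+        # Cast to fp32, then take the full-tensor mean (a 0-d tensor).
+        f = graph.call_function(
+            exir_ops.edge.aten._to_copy.default,
+            args=(src,),
+            kwargs={"dtype": torch.float32},
+        )
+        return graph.call_function(exir_ops.edge.aten.mean.default, args=(f,))
+
+    MEAN_ONLY = StatReducer(name="MEAN_ONLY", fields=("mean",), emit=_emit_mean)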
+ +The emit functions cast to fp32 first for cross-backend numerical stability +and use full-tensor reductions (no `dim=`) so the result is a stable shape +regardless of the source tensor's rank. + +For v1 we ship: FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS. +HISTOGRAM_64 is deferred (`aten.histc` has restricted edge support). +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + + +if TYPE_CHECKING: + import torch.fx as fx + + +# --- Reducer dataclass --------------------------------------------------- + +EmitFn = Callable[["fx.Graph", "fx.Node"], "fx.Node"] + + +@dataclass(frozen=True) +class StatReducer: + """ + A reducer specification. `emit` is invoked by `strip_taps_` to materialise + the reducer subgraph in the post-lowering graph. + + `name` is what the user types and what's stored on each TapSpec. + `fields` enumerates the columns of the 1-D output tensor (empty for + FULL_TENSOR, which preserves the original tensor shape). + """ + + name: str + fields: tuple[str, ...] + emit: EmitFn + + +# --- Helpers ------------------------------------------------------------- + + +def _cast_fp32(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + """Insert a fp32 cast (no-op semantically if already fp32).""" + # exir_ops.edge.dim_order_ops._to_dim_order_copy.default exists for edge dialect, + # but the simpler aten._to_copy variant is broadly supported. + return graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(x,), + kwargs={"dtype": torch.float32}, + ) + + +def _scalar_node(graph: "fx.Graph", op, x: "fx.Node") -> "fx.Node": + """Call a full-reduction op (amin/amax/mean/sum) producing a 0-d tensor.""" + return graph.call_function(op, args=(x,)) + + +def _stack(graph: "fx.Graph", scalars: list["fx.Node"]) -> "fx.Node": + """Stack a list of 0-d tensors into a 1-D tensor.""" + return graph.call_function( + exir_ops.edge.aten.stack.default, + args=(scalars,), + kwargs={"dim": 0}, + ) + + +# --- Built-in reducers --------------------------------------------------- + + +def _emit_full_tensor(_graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + """Identity — return the source node directly. strip_taps_ will splice.""" + return src + + +FULL_TENSOR: StatReducer = StatReducer( + name="FULL_TENSOR", + fields=(), + emit=_emit_full_tensor, +) + + +def _emit_abs_max(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + f = _cast_fp32(graph, src) + abs_x = graph.call_function(exir_ops.edge.aten.abs.default, args=(f,)) + return _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_x) + + +ABS_MAX_ONLY: StatReducer = StatReducer( + name="ABS_MAX_ONLY", + fields=("abs_max",), + emit=_emit_abs_max, +) + + +def _emit_min_max_mean(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + f = _cast_fp32(graph, src) + mn = _scalar_node(graph, exir_ops.edge.aten.amin.default, f) + mx = _scalar_node(graph, exir_ops.edge.aten.amax.default, f) + me = _scalar_node(graph, exir_ops.edge.aten.mean.default, f) + return _stack(graph, [mn, mx, me]) + + +MIN_MAX_MEAN: StatReducer = StatReducer( + name="MIN_MAX_MEAN", + fields=("min", "max", "mean"), + emit=_emit_min_max_mean, +) + + +def _emit_default_stats(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + """ + Default stats: (min, max, mean, abs_max) — 4 floats. 
+ + NOTE: nan_count/inf_count/std are intentionally excluded because the + underlying portable kernels (`isnan`, `isinf`, `sum.dtype`, `std.*`) + don't all have out variants registered in ExecuTorch's default runtime + op table, which fails memory planning or runtime method-load. If you + need them, supply a custom StatReducer. + """ + f = _cast_fp32(graph, src) + mn = _scalar_node(graph, exir_ops.edge.aten.amin.default, f) + mx = _scalar_node(graph, exir_ops.edge.aten.amax.default, f) + me = _scalar_node(graph, exir_ops.edge.aten.mean.default, f) + + abs_x = graph.call_function(exir_ops.edge.aten.abs.default, args=(f,)) + abs_max = _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_x) + + return _stack(graph, [mn, mx, me, abs_max]) + + +DEFAULT_STATS: StatReducer = StatReducer( + name="DEFAULT_STATS", + fields=("min", "max", "mean", "abs_max"), + emit=_emit_default_stats, +) + + +# --- Registry ------------------------------------------------------------- + +_BUILTIN_REDUCERS: dict[str, StatReducer] = { + r.name: r + for r in (FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS) +} + + +def get_reducer(name_or_reducer: str | StatReducer) -> StatReducer: + """Look up a built-in by name, or return a user-supplied StatReducer as-is.""" + if isinstance(name_or_reducer, StatReducer): + return name_or_reducer + if name_or_reducer not in _BUILTIN_REDUCERS: + raise ValueError( + f"Unknown reducer {name_or_reducer!r}; " + f"available: {sorted(_BUILTIN_REDUCERS)}" + ) + return _BUILTIN_REDUCERS[name_or_reducer] diff --git a/devtools/intermediate_output_tap/_selectors.py b/devtools/intermediate_output_tap/_selectors.py new file mode 100644 index 00000000000..004c057df13 --- /dev/null +++ b/devtools/intermediate_output_tap/_selectors.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +Predicates for selecting which FX nodes to tap. + +A `NodeSelector` is just `Callable[[fx.Node], bool]`. The provided builders +let you compose them by op type, by `nn_module_stack` path, by arbitrary meta +tag, and via boolean combinators. + +Examples: + selector = select_any( + select_by_op_type("aten.linear.default", "aten.matmul.default"), + select_by_module_path("layers.*.attention"), + ) + selector = select_all(selector, select_not(select_by_op_type("aten.view.default"))) +""" + +from __future__ import annotations + +import fnmatch +from collections.abc import Callable +from typing import Any + +import torch.fx as fx + + +NodeSelector = Callable[[fx.Node], bool] + + +def select_all_call_function( + exclude: tuple[str, ...] = ("getitem",), +) -> NodeSelector: + """Match every `call_function` node whose target name is not in `exclude`.""" + excluded = set(exclude) + + def predicate(n: fx.Node) -> bool: + if n.op != "call_function": + return False + target_name = getattr(n.target, "__name__", str(n.target)) + # `getitem` shows up as the builtin name; also normalise common aten suffixes. + return target_name not in excluded and str(n.target) not in excluded + + return predicate + + +def select_by_op_type(*op_targets: str) -> NodeSelector: + """ + Match nodes whose `str(node.target)` ends with any of `op_targets`. + + The "ends with" rule lets the user write either the short name + ("aten.linear.default") or a fully-qualified name and have it match. 
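+
+    Example (assuming `ep` is an ExportedProgram):
+        sel = select_by_op_type("aten.linear.default", "aten.relu.default")
+        tapped = [n for n in ep.graph_module.graph.nodes if sel(n)]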
+ """ + if not op_targets: + raise ValueError("select_by_op_type requires at least one op target") + suffixes = tuple(op_targets) + + def predicate(n: fx.Node) -> bool: + if n.op != "call_function": + return False + target_str = str(n.target) + return any(target_str.endswith(s) or target_str == s for s in suffixes) + + return predicate + + +def select_by_module_path(pattern: str) -> NodeSelector: + """ + Match nodes whose `nn_module_stack` (the chain of nn.Module attribute names + walked to reach this op during tracing) contains a path matching `pattern`. + `pattern` is a shell-glob (fnmatch) — e.g. "layers.*", "layers.0.attention", + "*.attention.*". + """ + + def predicate(n: fx.Node) -> bool: + stack = n.meta.get("nn_module_stack") + if not stack: + return False + # nn_module_stack is an OrderedDict: id -> (qualified_path, module_type) + for entry in stack.values(): + path = entry[0] if isinstance(entry, tuple) else entry + if fnmatch.fnmatchcase(path, pattern): + return True + return False + + return predicate + + +# Sentinel: matches when the meta key exists at all, regardless of value. +_ANY_VALUE: object = object() + + +def select_by_meta_tag(key: str, value: Any = _ANY_VALUE) -> NodeSelector: + """ + Match nodes that carry `node.meta[key]`. If `value` is provided, also + requires `node.meta[key] == value`. + """ + + def predicate(n: fx.Node) -> bool: + if key not in n.meta: + return False + if value is _ANY_VALUE: + return True + return n.meta[key] == value + + return predicate + + +def select_any(*selectors: NodeSelector) -> NodeSelector: + """Match if ANY of `selectors` matches.""" + if not selectors: + return lambda _n: False + sels = tuple(selectors) + return lambda n: any(s(n) for s in sels) + + +def select_all(*selectors: NodeSelector) -> NodeSelector: + """Match if ALL of `selectors` match.""" + if not selectors: + return lambda _n: True + sels = tuple(selectors) + return lambda n: all(s(n) for s in sels) + + +def select_not(selector: NodeSelector) -> NodeSelector: + """Match if `selector` does NOT match.""" + return lambda n: not selector(n) diff --git a/devtools/intermediate_output_tap/_spec.py b/devtools/intermediate_output_tap/_spec.py new file mode 100644 index 00000000000..71b3c7f526e --- /dev/null +++ b/devtools/intermediate_output_tap/_spec.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TapSpec records one tap inserted by `tap_intermediate_outputs(...)`. + +A list of TapSpecs is returned to the user from the AOT pass; the user passes +that same list to `Inspector.calculate_numeric_gap_from_taps(...)` at runtime +to demux the flat output tuple back into per-op intermediate values. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TapSpec: + """ + Metadata about a single tap. + + Attributes: + node_name: The FX node name of the *source* node (the tapped op) at the + time the AOT pass ran. Useful for debugging / pretty-printing. + op_target: `str(node.target)` of the source node, e.g. + "aten.linear.default". + debug_handle: `node.meta["debug_handle"]` of the source node, or None + if the source had no debug handle. Set at AOT-pass time. NOT used + by the Inspector integration directly — the serializer regenerates + debug_handles, so Inspector aligns by `reducer_node_name` instead. 
+ output_index: 0-based index into the runtime program's flat output + tuple where this tap's value lands. Computed at AOT time and stable + through `to_edge` / `to_backend` / `to_executorch` because we only + ever *append* to the output node and `OutputSpec`. + reducer_name: Name of the StatReducer used (e.g. "DEFAULT_STATS"). + fields: Names of the per-element fields in the reducer's output tensor + (e.g. ("min", "max", "abs_max")). Empty tuple for FULL_TENSOR. + stack_trace: `node.meta["stack_trace"]` of the source node if present, + for human-readable error messages. + reducer_node_name: The FX node name of the post-strip reducer terminal + node — i.e., the node whose value is surfaced as the runtime tap + output. Populated by `strip_taps_` when `tap_specs` is passed. + FX node names survive ETRecord serialization roundtrip, so this + is the stable bridge `Inspector.calculate_numeric_gap_from_taps` + uses to find the post-roundtrip handle for alignment. + """ + + node_name: str + op_target: str + debug_handle: int | None + output_index: int + reducer_name: str + fields: tuple[str, ...] + stack_trace: str | None = None + reducer_node_name: str | None = None diff --git a/devtools/intermediate_output_tap/_strip_pass.py b/devtools/intermediate_output_tap/_strip_pass.py new file mode 100644 index 00000000000..5047dba0cff --- /dev/null +++ b/devtools/intermediate_output_tap/_strip_pass.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Post-`to_backend` pass: replace each `executorch_devtools::tap.Tensor` node +with either an identity edge (FULL_TENSOR) or a portable reducer subgraph +(DEFAULT_STATS, MIN_MAX_MEAN, ABS_MAX_ONLY). + +Pattern stolen from `remove_graph_break_` in +`executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. + +This pass MUST run *after* `to_edge_transform_and_lower(...)` and *before* +`to_executorch()`. Running it before partitioning would defeat the whole +mechanism (the reducer ops would be eligible for delegation). + +When called with the `tap_specs` from `tap_intermediate_outputs`, this pass +also populates `TapSpec.reducer_node_name` for each spec — the FX node name +of the post-strip reducer terminal. This is the bridge +`Inspector.calculate_numeric_gap_from_taps` uses to recover the +post-ETRecord-roundtrip `debug_handle` for alignment. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import replace as _dataclass_replace + +import torch.fx as fx +from executorch.devtools.intermediate_output_tap._reducers import get_reducer +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._tap_pass import find_tap_nodes + + +def strip_taps_( + edge_manager, + tap_specs: Sequence[TapSpec] | None = None, +) -> list[TapSpec] | None: + """ + Replace every `tap.Tensor(src, reducer_name, debug_handle)` node in every + method of `edge_manager` with the materialised reducer subgraph, in place. + + For FULL_TENSOR the placeholder is collapsed (the source node's value + flows directly to whatever consumed the placeholder). + + Args: + edge_manager: An EdgeProgramManager (post-`to_edge_transform_and_lower`). + tap_specs: Optional. 
If provided, the pass returns a NEW list of + TapSpecs with `reducer_node_name` populated for each spec — the + FX name of the post-strip reducer terminal node. This list must + be passed to `Inspector.calculate_numeric_gap_from_taps` for + alignment to work. + + Returns: + Updated tap_specs list if `tap_specs` was provided, else None. + """ + # Walk in graph order; tap nodes appear in the same order they were + # created by `tap_intermediate_outputs`, which is the same order as + # `tap_specs`. Track each tap's replacement node so we can update the + # corresponding spec. + replacement_names: list[str | None] = [] + for method_name in edge_manager.methods: + ep = edge_manager.exported_program(method_name) + gm = ep.graph_module + for replacement_node in _strip_taps_in_graph_module(gm): + replacement_names.append( + replacement_node.name if replacement_node is not None else None + ) + + if tap_specs is None: + return None + + if len(tap_specs) != len(replacement_names): + raise RuntimeError( + f"strip_taps_: tap_specs length ({len(tap_specs)}) does not match " + f"the number of tap nodes found in the edge_manager " + f"({len(replacement_names)}). The strip pass cannot align specs " + f"to reducer nodes. Did you call strip_taps_ on a different " + f"edge_manager than the one produced from the tapped EP?" + ) + + return [ + _dataclass_replace(spec, reducer_node_name=name) + for spec, name in zip(tap_specs, replacement_names) + ] + + +def _strip_taps_in_graph_module(gm: fx.GraphModule) -> list[fx.Node | None]: + """ + Strip taps in a single GraphModule. Returns the list of replacement nodes + in tap-creation order (same as graph order). For FULL_TENSOR taps the + "replacement" is the source node itself (since the tap collapses to + identity). + """ + graph = gm.graph + tap_nodes = find_tap_nodes(gm) + if not tap_nodes: + return [] + + output_node = graph.output_node() + replacements: list[fx.Node | None] = [] + + # Compute next available debug_handle so each reducer terminal gets a + # unique one (necessary so Inspector can look it up by node name and find + # a non-None handle in the post-roundtrip graph). + existing_handles = [ + n.meta.get("debug_handle") + for n in graph.nodes + if isinstance(n.meta.get("debug_handle"), int) + ] + next_handle = (max(existing_handles) + 1) if existing_handles else 1 + + for tap in tap_nodes: + # tap.args = (src_node, reducer_name, debug_handle) + src, reducer_name, dh = tap.args[0], tap.args[1], tap.args[2] + reducer = get_reducer(str(reducer_name)) + + if reducer.name == "FULL_TENSOR": + # Identity: re-route all consumers to the source. The "reducer + # terminal" is the source itself. + tap.replace_all_uses_with(src) + replacements.append(src if isinstance(src, fx.Node) else None) + continue + + # Build the reducer subgraph (reads from src). + with graph.inserting_before(tap): + replacement = reducer.emit(graph, src) + # Always assign a debug_handle to the reducer terminal so Inspector + # can find it post-roundtrip. Prefer the source's pre-tap handle if + # available (carries semantic meaning); otherwise use next_handle. 
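+            # (`dh` is the int baked into the tap.Tensor args at AOT time;
+            # the tap pass writes 0 as its "no handle" sentinel, so plain
+            # truthiness is the right check below.)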
+ if dh: + replacement.meta["debug_handle"] = dh + else: + replacement.meta["debug_handle"] = next_handle + next_handle += 1 + replacement.meta["is_tap"] = True + replacement.meta["source_node"] = ( + src.name if isinstance(src, fx.Node) else None + ) + + # `tap` may have ended up in the data path during to_edge's re-trace + # (because CompositeExplicitAutograd preserves the op as an identity + # node, and re-traced consumers point at it instead of `src`). So: + # - the OUTPUT-node use becomes the reducer (the value we want + # surfaced as a tap). + # - every OTHER use is rewritten back to `src` (identity passthrough), + # restoring the original data path. + for use_node in list(tap.users.keys()): + if use_node is output_node: + new_outs = tuple( + replacement if a is tap else a for a in output_node.args[0] + ) + output_node.args = (new_outs,) + else: + use_node.replace_input_with(tap, src) + replacements.append(replacement) + + graph.eliminate_dead_code() + gm.recompile() + return replacements diff --git a/devtools/intermediate_output_tap/_tap_pass.py b/devtools/intermediate_output_tap/_tap_pass.py new file mode 100644 index 00000000000..bb7b96661d3 --- /dev/null +++ b/devtools/intermediate_output_tap/_tap_pass.py @@ -0,0 +1,249 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +AOT pass: insert `tap.Tensor` placeholders after selected nodes and surface +them as additional USER_OUTPUTs of the ExportedProgram. + +Pattern stolen from `executorch/exir/passes/weights_to_outputs_pass.py`: +- find existing output node +- build new output args (existing + new tap nodes) +- create new output node, replace_all_uses_with, erase old +- append OutputSpec(USER_OUTPUT) entries to gs.output_specs +- eliminate_dead_code() + recompile() +""" + +from __future__ import annotations + +import copy +from collections.abc import Callable + +import torch +import torch.fx as fx +from executorch.devtools.intermediate_output_tap import custom_ops_lib # noqa: F401 registers tap.Tensor +from executorch.devtools.intermediate_output_tap._reducers import ( + DEFAULT_STATS, + get_reducer, + StatReducer, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all_call_function, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from torch.export import ExportedProgram +from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument + + +# Don't ever tap our own tap nodes if a user runs the pass twice. +# `tap.Tensor` is already an OpOverload (not a packet) since "Tensor" is the +# overload name — same convention as torch.ops.executorch_utils.graph_break.Tensor. +_TAP_TARGET = torch.ops.executorch_devtools.tap.Tensor + + +def _is_tap_node(n: fx.Node) -> bool: + return n.op == "call_function" and n.target is _TAP_TARGET + + +def tap_intermediate_outputs( + ep: ExportedProgram, + selector: NodeSelector | None = None, + reducer: str | StatReducer = DEFAULT_STATS, + *, + tap_name_prefix: str = "tap_", + skip_if_no_debug_handle: bool = False, + max_taps: int | None = None, + inplace: bool = False, +) -> tuple[ExportedProgram, list[TapSpec]]: + """ + Rewrite `ep` so each node matching `selector` has its output appended to + the program outputs (wrapped in a `tap.Tensor` placeholder that survives + partitioning). 
Returns the new ExportedProgram and a list of TapSpecs. + + The returned EP is safe to feed to + `to_edge_transform_and_lower(...).to_executorch()` *after* calling + `strip_taps_(edge_manager)` to replace the placeholders with their + reducer subgraphs (or identities, for FULL_TENSOR). + + Args: + ep: The ExportedProgram to tap. + selector: A predicate over fx.Node. Defaults to + `select_all_call_function()`. Tap nodes themselves are always + excluded so re-running the pass is idempotent. + reducer: Either a built-in reducer name ("DEFAULT_STATS", + "MIN_MAX_MEAN", "ABS_MAX_ONLY", "FULL_TENSOR") or a custom + StatReducer instance. + tap_name_prefix: Prefix for the tap nodes' names. Helps when + grepping the dumped graph. + skip_if_no_debug_handle: If True, only tap nodes that already + carry `node.meta["debug_handle"]`. Recommended for Inspector + integration since handle-less taps cannot be aligned with + AOT outputs. + max_taps: Optional cap on number of taps. Helps avoid OOM for + very large models. + inplace: If False (default), deep-copy `ep` before mutating. + """ + if selector is None: + selector = select_all_call_function() + reducer_obj = get_reducer(reducer) + + if not inplace: + ep = copy.deepcopy(ep) + + gs = ep.graph_signature + gm = ep.graph_module + graph = gm.graph + output_node = graph.output_node() + existing_outputs = list(output_node.args[0]) + + # Snapshot before we start mutating the graph. + candidate_nodes = [n for n in graph.nodes if not _is_tap_node(n)] + + specs: list[TapSpec] = [] + new_tap_nodes: list[fx.Node] = [] + + for node in candidate_nodes: + if node.op != "call_function" or not selector(node): + continue + debug_handle = node.meta.get("debug_handle") + if skip_if_no_debug_handle and debug_handle is None: + continue + if max_taps is not None and len(specs) >= max_taps: + break + + # tap.Tensor's int arg cannot be None; sentinel 0 means "no handle". + dh_arg = int(debug_handle) if isinstance(debug_handle, int) else 0 + + with graph.inserting_after(node): + tap_node = graph.call_function( + _TAP_TARGET, + args=(node, reducer_obj.name, dh_arg), + ) + # Don't override the auto-assigned name — FX guarantees uniqueness. + # Stash the prefixed-source-name in meta for human-readable logs. + tap_node.meta["tap_label"] = f"{tap_name_prefix}{node.name}" + # Preserve provenance for Inspector's `propagate_back_debug_handle` + # and for users that pretty-print the graph. + if debug_handle is not None: + tap_node.meta["debug_handle"] = debug_handle + if "from_node" in node.meta: + tap_node.meta["from_node"] = node.meta["from_node"] + if "stack_trace" in node.meta: + tap_node.meta["stack_trace"] = node.meta["stack_trace"] + if "nn_module_stack" in node.meta: + tap_node.meta["nn_module_stack"] = node.meta["nn_module_stack"] + tap_node.meta["is_tap"] = True + tap_node.meta["source_node"] = node.name + + new_tap_nodes.append(tap_node) + specs.append( + TapSpec( + node_name=node.name, + op_target=str(node.target), + debug_handle=debug_handle if isinstance(debug_handle, int) else None, + output_index=len(existing_outputs) + len(specs), + reducer_name=reducer_obj.name, + fields=reducer_obj.fields, + stack_trace=node.meta.get("stack_trace"), + ) + ) + + if not new_tap_nodes: + return ep, [] + + # Splice new outputs into the graph (mirror weights_to_outputs_pass). 
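+    # (Create the replacement output node, re-point all users of the old one
+    # at it, then erase the old node. Taps are appended after the existing
+    # outputs, so pre-existing output indices are preserved.)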
+ new_output_args = tuple(existing_outputs + new_tap_nodes) + with graph.inserting_before(output_node): + new_output = graph.output(new_output_args) + output_node.replace_all_uses_with(new_output) + graph.erase_node(output_node) + + # Append OutputSpec entries so the EP's signature matches the graph. + for tap_node in new_tap_nodes: + gs.output_specs.append( + OutputSpec( + kind=OutputKind.USER_OUTPUT, + arg=TensorArgument(name=tap_node.name), + target=None, + ) + ) + + # Update each ModuleCallSignature's out_spec so `to_edge`'s re-trace can + # unflatten the new flat output list. The "" (root) entry holds the + # user-facing forward output structure; we wrap it in a tuple alongside + # the new tap leaves and re-derive the spec. + _extend_module_call_graph_outputs(ep, new_tap_nodes) + + graph.eliminate_dead_code() + gm.recompile() + return ep, specs + + +def _extend_module_call_graph_outputs( + ep: ExportedProgram, + new_tap_nodes: list[fx.Node], +) -> None: + """ + Append `len(new_tap_nodes)` extra leaves to the root module-call entry's + `out_spec` so the pytree unflatten step in `run_decompositions` works. + Also extends the entry's `outputs: list[ArgumentSpec]`. + + NOTE: We append TensorArgument(name="") for each new tap output. Empty + names are *skipped* by `_verify_exported_program_module_call_graph` (its + check is `if arg.name and arg.name not in nodes`). We can't use the + pre-trace tap node names because `to_edge`'s re-trace renames nodes via + `from_node` chains, and our tap nodes' provenance wouldn't update them + correctly — leading to "Output X does not exist in the graph" errors. + The verifier's name check is metadata-only; the actual pytree unflatten + only needs `out_spec` to have the correct number of leaves. + """ + import torch.utils._pytree as pytree + from torch.export.exported_program import TensorArgument as _TensorArgument + + n_new = len(new_tap_nodes) + if n_new == 0: + return + + for entry in ep._module_call_graph: + if entry.fqn != "": + continue + sig = entry.signature + if sig is None: + continue + old_spec = sig.out_spec + # Build a dummy structure matching the old spec, then wrap with N new + # leaves and re-derive the spec. This handles arbitrary pytree shapes. + old_dummy = pytree.tree_unflatten([0] * old_spec.num_leaves, old_spec) + if isinstance(old_dummy, tuple): + new_dummy = (*old_dummy, *([0] * n_new)) + else: + new_dummy = (old_dummy, *([0] * n_new)) + sig.out_spec = pytree.tree_structure(new_dummy) + for _ in range(n_new): + sig.outputs.append(_TensorArgument(name="")) + break + + +def find_tap_nodes(gm: fx.GraphModule) -> list[fx.Node]: + """Helper: enumerate tap.Tensor nodes in a GraphModule (any dialect).""" + out: list[fx.Node] = [] + for n in gm.graph.nodes: + if n.op != "call_function": + continue + # Match across dialects: + # pre-edge: torch.ops.executorch_devtools.tap.Tensor — str ends with name + # post-edge: : schema = ... + # so substring-match the qualified name. + if "executorch_devtools.tap.Tensor" in str(n.target) or n.target is _TAP_TARGET: + out.append(n) + return out + + +# Re-export the predicate so callers can identify tap nodes without importing +# torch.ops directly. +is_tap_node: Callable[[fx.Node], bool] = _is_tap_node diff --git a/devtools/intermediate_output_tap/custom_ops_lib.py b/devtools/intermediate_output_tap/custom_ops_lib.py new file mode 100644 index 00000000000..93c4636ad2c --- /dev/null +++ b/devtools/intermediate_output_tap/custom_ops_lib.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Custom op registration for the intermediate-output tap mechanism. + +The op `executorch_devtools::tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor` +is an identity op whose sole job is to be an unknown-to-every-partitioner FX node +that "uses" a tapped tensor `x`. Because `x` now has a user outside any partition, +every ExecuTorch partitioner must surface `x` as a partition output (this is the +canonical contract enforced in `executorch/exir/lowered_backend_module.py`). + +After `to_edge_transform_and_lower(...)` the tap.Tensor node still exists in the +parent graph; `strip_taps_` (see `_strip_pass.py`) replaces it with either an +identity edge (FULL_TENSOR) or a small reducer subgraph of portable aten ops. + +The dispatch key MUST be `CompositeExplicitAutograd` (not `CompositeImplicitAutograd`) +so the op survives tracing/decomposition; otherwise it would inline at export time +and disappear before partitioning. This mirrors the pattern in +`executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. + +`reducer_name` and `debug_handle` are stored as op arguments (not just node.meta) +so they survive any meta-stripping pass between `to_edge` and `strip_taps_`. +""" + +from __future__ import annotations + +from torch.library import impl, Library + +# Library namespace verified collision-free across fbsource as of Nov 2025. +lib: Library = Library("executorch_devtools", "DEF") + +lib.define("tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor") + + +@impl(lib, "tap.Tensor", "CompositeExplicitAutograd") +def tap_tensor_impl(x, reducer_name, debug_handle): # noqa: ARG001 + return x diff --git a/devtools/intermediate_output_tap/tests/TARGETS b/devtools/intermediate_output_tap/tests/TARGETS new file mode 100644 index 00000000000..79344627b17 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/TARGETS @@ -0,0 +1,74 @@ +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("@fbsource//tools/target_determinator/macros:ci.bzl", "ci") + +oncall("executorch") + +python_unittest( + name = "test_selectors", + srcs = ["test_selectors.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:selectors", + ], +) + +python_unittest( + name = "test_reducers", + srcs = ["test_reducers.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + ], +) + +python_unittest( + name = "test_tap_pass", + srcs = ["test_tap_pass.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + "//executorch/devtools/intermediate_output_tap:selectors", + "//executorch/devtools/intermediate_output_tap:tap_pass", + ], +) + +python_unittest( + name = "test_strip_pass", + srcs = ["test_strip_pass.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + "//executorch/devtools/intermediate_output_tap:selectors", + "//executorch/devtools/intermediate_output_tap:strip_pass", + "//executorch/devtools/intermediate_output_tap:tap_pass", + "//executorch/exir:lib", + ], +) + +python_unittest( + name = "test_xnnpack_e2e", + srcs = ["test_xnnpack_e2e.py"], + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/devtools/intermediate_output_tap:lib", + 
"//executorch/exir:lib", + "//executorch/runtime:runtime", + ], +) + +python_unittest( + name = "test_inspector_integration", + srcs = ["test_inspector_integration.py"], + labels = ci.labels( + ci.buckconfig("executorch.event_tracer_enabled", "true"), + ), + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/devtools:lib", + "//executorch/devtools/intermediate_output_tap:lib", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ], +) diff --git a/devtools/intermediate_output_tap/tests/test_inspector_integration.py b/devtools/intermediate_output_tap/tests/test_inspector_integration.py new file mode 100644 index 00000000000..b48dcfe4e0a --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_inspector_integration.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Integration test: run the full pipeline (export -> tap -> lower with XNNPACK +-> strip -> generate_etrecord -> to_executorch -> runtime) and feed the flat +runtime outputs + (post-strip) TapSpecs to +Inspector.calculate_numeric_gap_from_taps. Verify the returned DataFrame has +rows aligned by debug_handle. + +KEY DESIGN POINTS: +1. ETRecord generation MUST happen AFTER `strip_taps_` so the snapshot of the + edge program contains no `tap.Tensor` nodes (which the EXIR serializer + can't handle). +2. `strip_taps_(edge, tap_specs=specs)` returns updated specs whose + `reducer_node_name` is set to the post-strip reducer terminal node name. + Inspector uses that name to look up the post-roundtrip `debug_handle` — + FX node names survive ETRecord serialization, debug_handle values do not. +""" + +import os +import sys +import tempfile +import unittest + +import torch +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, +) +from executorch.devtools import generate_etrecord, Inspector +from executorch.devtools.intermediate_output_tap import ( + DEFAULT_STATS, + format_tap_dataframe, + select_by_op_type, + strip_taps_, + tap_intermediate_outputs, +) +from executorch.exir import to_edge_transform_and_lower +from executorch.runtime import Runtime, Verification +from torch.export import export + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 4) + + def forward(self, x): + return self.l2(self.l1(x).relu()) + + +@unittest.skipIf(sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows") +class InspectorIntegrationTest(unittest.TestCase): + def test_calculate_numeric_gap_from_taps(self): + model = _MLP() + example_inputs = (torch.randn(2, 8),) + + ep = export(model, example_inputs, strict=True) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=DEFAULT_STATS, + ) + # Do NOT pass generate_etrecord=True — we'd snapshot the EP while it + # still has tap.Tensor nodes (unserializable). + edge = to_edge_transform_and_lower( + ep_t, + partitioner=[XnnpackPartitioner()], + ) + # strip_taps_ with tap_specs returns updated specs whose + # reducer_node_name points at the post-strip reducer terminal node. 
+ specs = strip_taps_(edge, tap_specs=specs) + et_program = edge.to_executorch() + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + + # ETRecord generated AFTER strip — the edge program is now + # serializable. Don't pass exported_program: Inspector falls back + # to the edge dialect program for AOT capture. + etrecord_path = os.path.join(temp_dir, "etrecord.bin") + generate_etrecord( + etrecord_path, + edge_dialect_program=edge, + executorch_program=et_program, + ) + + rt = Runtime.get() + program = rt.load_program( + pte_path, + verification=Verification.Minimal, + enable_etdump=True, + debug_buffer_size=1024 * 1024, + ) + method = program.load_method("forward") + flat_outputs = method.execute(list(example_inputs)) + + etdump_path = os.path.join(temp_dir, "etdump.etdp") + debug_buffer_path = os.path.join(temp_dir, "debug_buffer.bin") + program.write_etdump_result_to_file(etdump_path, debug_buffer_path) + if not os.path.exists(etdump_path): + self.skipTest( + "Event tracer not enabled. Run with " + "--config executorch.event_tracer_enabled=true" + ) + + inspector = Inspector( + etdump_path=etdump_path, + etrecord=etrecord_path, + debug_buffer_path=debug_buffer_path, + ) + inspector._etrecord._representative_inputs = list(example_inputs) + df = inspector.calculate_numeric_gap_from_taps( + flat_runtime_outputs=flat_outputs, + tap_specs=specs, + distance="MSE", + ) + # Print friendly per-tap view to stdout (visible via --print-passing-details). + friendly = format_tap_dataframe(df, specs) + import pandas as _pd + with _pd.option_context( + "display.max_columns", None, + "display.width", 240, + "display.max_colwidth", 30, + "display.float_format", "{:.4g}".format, + ): + print("\n=== Inspector.calculate_numeric_gap_from_taps (friendly) ===") + print(friendly.to_string()) + + self.assertGreater(len(df), 0, "expected at least one tap row in DataFrame") + for col in ("aot_ops", "runtime_ops", "gap"): + self.assertIn(col, df.columns) + for _, row in df.iterrows(): + self.assertIsNotNone(row["aot_ops"]) + self.assertIsNotNone(row["runtime_ops"]) + gap = row["gap"] + if isinstance(gap, list): + gap = gap[0] if gap else 0.0 + self.assertGreaterEqual(float(gap), 0.0) diff --git a/devtools/intermediate_output_tap/tests/test_reducers.py b/devtools/intermediate_output_tap/tests/test_reducers.py new file mode 100644 index 00000000000..e36762be5ac --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_reducers.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import unittest + +from executorch.devtools.intermediate_output_tap._reducers import ( + ABS_MAX_ONLY, + DEFAULT_STATS, + FULL_TENSOR, + get_reducer, + MIN_MAX_MEAN, + StatReducer, +) + + +class ReducersTest(unittest.TestCase): + def test_get_reducer_by_name(self): + self.assertIs(get_reducer("DEFAULT_STATS"), DEFAULT_STATS) + self.assertIs(get_reducer("FULL_TENSOR"), FULL_TENSOR) + self.assertIs(get_reducer("MIN_MAX_MEAN"), MIN_MAX_MEAN) + self.assertIs(get_reducer("ABS_MAX_ONLY"), ABS_MAX_ONLY) + + def test_get_reducer_passthrough(self): + custom = StatReducer(name="X", fields=("a",), emit=lambda g, n: n) + self.assertIs(get_reducer(custom), custom) + + def test_get_reducer_unknown_raises(self): + with self.assertRaises(ValueError): + get_reducer("DOES_NOT_EXIST") + + def test_reducer_field_counts(self): + self.assertEqual(FULL_TENSOR.fields, ()) + self.assertEqual(ABS_MAX_ONLY.fields, ("abs_max",)) + self.assertEqual(MIN_MAX_MEAN.fields, ("min", "max", "mean")) + self.assertEqual( + DEFAULT_STATS.fields, + ("min", "max", "mean", "abs_max"), + ) + + def test_reducer_names_unique(self): + names = {r.name for r in (FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS)} + self.assertEqual(len(names), 4) + + def test_default_stats_eager_correctness(self): + """Confirm DEFAULT_STATS spec has 4 fields (std/nan_count/inf_count excluded).""" + self.assertEqual(len(DEFAULT_STATS.fields), 4) diff --git a/devtools/intermediate_output_tap/tests/test_selectors.py b/devtools/intermediate_output_tap/tests/test_selectors.py new file mode 100644 index 00000000000..495e184de52 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_selectors.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._selectors import ( + select_all, + select_all_call_function, + select_any, + select_by_meta_tag, + select_by_module_path, + select_by_op_type, + select_not, +) +from torch.export import export + + +class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x).relu() + + +class _Outer(torch.nn.Module): + def __init__(self): + super().__init__() + self.inner = _Inner() + self.head = torch.nn.Linear(4, 2) + + def forward(self, x): + return self.head(self.inner(x)) + + +def _exported_graph(): + ep = export(_Outer(), (torch.randn(2, 4),), strict=True) + return ep.graph_module.graph + + +class SelectorsTest(unittest.TestCase): + def setUp(self): + self.graph = _exported_graph() + self.call_nodes = [n for n in self.graph.nodes if n.op == "call_function"] + + def test_select_all_call_function_excludes_getitem(self): + sel = select_all_call_function() + for n in self.call_nodes: + if "getitem" in str(n.target): + self.assertFalse(sel(n)) + else: + self.assertTrue(sel(n)) + + def test_select_by_op_type_matches_suffix(self): + sel = select_by_op_type("aten.linear.default", "aten.relu.default") + matched = [n for n in self.call_nodes if sel(n)] + # Two linears + one relu in the model. 
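+        # (Lower bound rather than an exact count: decompositions across
+        # export versions can change the precise op mix.)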
+ self.assertGreaterEqual(len(matched), 2) + for n in matched: + self.assertTrue( + str(n.target).endswith("aten.linear.default") + or str(n.target).endswith("aten.relu.default") + ) + + def test_select_by_op_type_requires_target(self): + with self.assertRaises(ValueError): + select_by_op_type() + + def test_select_by_module_path(self): + sel = select_by_module_path("inner.*") + matched = [n for n in self.call_nodes if sel(n)] + # inner contains a linear and a relu. + self.assertGreater(len(matched), 0) + for n in matched: + stack = n.meta.get("nn_module_stack") or {} + paths = [ + v[0] if isinstance(v, tuple) else v for v in stack.values() + ] + self.assertTrue(any(p.startswith("inner") for p in paths)) + + def test_select_by_meta_tag_presence(self): + for n in self.call_nodes[:1]: + n.meta["debug_me"] = "yes" + sel = select_by_meta_tag("debug_me") + self.assertTrue(sel(self.call_nodes[0])) + self.assertFalse(sel(self.call_nodes[1])) + + def test_select_by_meta_tag_value(self): + self.call_nodes[0].meta["color"] = "blue" + self.call_nodes[1].meta["color"] = "red" + sel = select_by_meta_tag("color", "blue") + self.assertTrue(sel(self.call_nodes[0])) + self.assertFalse(sel(self.call_nodes[1])) + + def test_select_combinators(self): + a = select_by_op_type("aten.linear.default") + b = select_by_op_type("aten.relu.default") + any_sel = select_any(a, b) + all_sel = select_all(a, b) + not_sel = select_not(a) + + for n in self.call_nodes: + if a(n) or b(n): + self.assertTrue(any_sel(n)) + self.assertEqual(all_sel(n), a(n) and b(n)) + self.assertEqual(not_sel(n), not a(n)) + + def test_select_any_empty(self): + for n in self.call_nodes: + self.assertFalse(select_any()(n)) + + def test_select_all_empty(self): + for n in self.call_nodes: + self.assertTrue(select_all()(n)) diff --git a/devtools/intermediate_output_tap/tests/test_strip_pass.py b/devtools/intermediate_output_tap/tests/test_strip_pass.py new file mode 100644 index 00000000000..2ff0d9154b1 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_strip_pass.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._reducers import ( + DEFAULT_STATS, + FULL_TENSOR, + MIN_MAX_MEAN, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + select_by_op_type, +) +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + find_tap_nodes, + tap_intermediate_outputs, +) +from executorch.exir import to_edge +from torch.export import export + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 8) + self.l2 = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.l2(self.l1(x).relu()) + + +def _tapped_edge(reducer): + ep = export(_MLP(), (torch.randn(2, 8),), strict=True) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=reducer, + ) + return to_edge(ep_t), specs + + +class StripPassTest(unittest.TestCase): + def test_strip_removes_all_tap_nodes_full_tensor(self): + edge, _ = _tapped_edge(FULL_TENSOR) + # Pre-strip: tap nodes present. 
+ for method_name in edge.methods: + ep = edge.exported_program(method_name) + self.assertGreater(len(find_tap_nodes(ep.graph_module)), 0) + + strip_taps_(edge) + + # Post-strip: no tap nodes. + for method_name in edge.methods: + ep = edge.exported_program(method_name) + self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0) + + def test_strip_full_tensor_routes_source_to_output(self): + edge, specs = _tapped_edge(FULL_TENSOR) + strip_taps_(edge) + # Output node should still have all the user outputs + tap outputs. + for method_name in edge.methods: + ep = edge.exported_program(method_name) + outs = list(ep.graph_module.graph.output_node().args[0]) + # Original outputs + 2 linears tapped. + self.assertGreaterEqual(len(outs), len(specs)) + + def test_strip_min_max_mean_emits_subgraph(self): + edge, specs = _tapped_edge(MIN_MAX_MEAN) + strip_taps_(edge) + for method_name in edge.methods: + ep = edge.exported_program(method_name) + self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0) + # Some reduction op (amin/amax/mean) should now be in the graph. + # Substring match because EdgeOpOverload's str() looks like + # ": schema = ..." (no clean + # endswith). + targets = {str(n.target) for n in ep.graph_module.graph.nodes} + self.assertTrue( + any( + "aten.amin" in t or "aten.amax" in t or "aten.mean" in t + for t in targets + ), + f"expected reducer ops in graph, got {targets}", + ) + + def test_strip_default_stats_preserves_debug_handle(self): + edge, specs = _tapped_edge(DEFAULT_STATS) + # Take a known debug_handle from one of the tap specs. + known_handles = {s.debug_handle for s in specs if s.debug_handle is not None} + if not known_handles: + self.skipTest("Test model produced no debug_handle on tap sources") + + strip_taps_(edge) + + post_handles: set = set() + for method_name in edge.methods: + ep = edge.exported_program(method_name) + for n in ep.graph_module.graph.nodes: + if n.meta.get("is_tap"): + post_handles.add(n.meta.get("debug_handle")) + # At least one tapped debug handle should still be present. + self.assertTrue(known_handles & post_handles) + + def test_strip_idempotent(self): + edge, _ = _tapped_edge(FULL_TENSOR) + strip_taps_(edge) + # Second call should be a no-op. + strip_taps_(edge) + for method_name in edge.methods: + ep = edge.exported_program(method_name) + self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0) diff --git a/devtools/intermediate_output_tap/tests/test_tap_pass.py b/devtools/intermediate_output_tap/tests/test_tap_pass.py new file mode 100644 index 00000000000..9e01ffe61d7 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_tap_pass.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import copy +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._reducers import ( + DEFAULT_STATS, + FULL_TENSOR, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + select_by_op_type, +) +from executorch.devtools.intermediate_output_tap._tap_pass import ( + is_tap_node, + tap_intermediate_outputs, +) +from torch.export import export +from torch.export.exported_program import OutputKind + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 8) + self.l3 = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.l3(self.l2(self.l1(x).relu()).relu()) + + +def _export(): + return export(_MLP(), (torch.randn(2, 8),), strict=True) + + +class TapPassTest(unittest.TestCase): + def test_inserts_tap_per_selected_node(self): + ep = _export() + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + # MLP has 3 linears. + self.assertEqual(len(specs), 3) + tap_nodes = [n for n in ep_t.graph_module.graph.nodes if is_tap_node(n)] + self.assertEqual(len(tap_nodes), 3) + + def test_appends_user_outputs(self): + ep = _export() + original_user_outs = sum( + 1 for s in ep.graph_signature.output_specs if s.kind == OutputKind.USER_OUTPUT + ) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + new_user_outs = sum( + 1 + for s in ep_t.graph_signature.output_specs + if s.kind == OutputKind.USER_OUTPUT + ) + self.assertEqual(new_user_outs, original_user_outs + len(specs)) + + def test_output_indices_contiguous_after_user_outputs(self): + ep = _export() + original_user_outs = sum( + 1 for s in ep.graph_signature.output_specs if s.kind == OutputKind.USER_OUTPUT + ) + _, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + for i, spec in enumerate(specs): + self.assertEqual(spec.output_index, original_user_outs + i) + + def test_default_reducer_is_default_stats(self): + ep = _export() + _, specs = tap_intermediate_outputs( + ep, selector=select_by_op_type("aten.linear.default") + ) + for s in specs: + self.assertEqual(s.reducer_name, DEFAULT_STATS.name) + self.assertEqual(s.fields, DEFAULT_STATS.fields) + + def test_inplace_false_does_not_mutate_original(self): + ep = _export() + before_outs = len(list(ep.graph_module.graph.output_node().args[0])) + before_specs = len(ep.graph_signature.output_specs) + _ = tap_intermediate_outputs( + ep, selector=select_by_op_type("aten.linear.default"), reducer=FULL_TENSOR + ) + after_outs = len(list(ep.graph_module.graph.output_node().args[0])) + after_specs = len(ep.graph_signature.output_specs) + self.assertEqual(before_outs, after_outs) + self.assertEqual(before_specs, after_specs) + + def test_max_taps(self): + ep = _export() + _, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + max_taps=2, + ) + self.assertEqual(len(specs), 2) + + def test_idempotent_does_not_tap_taps(self): + ep = _export() + ep_once, specs1 = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + # Running again should not add NEW taps for our existing tap nodes. 
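+        # (Per the _tap_pass contract, tap nodes are always excluded from
+        # selector matches — see `_is_tap_node` — so the second pass sees
+        # only the same three linears.)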
+        ep_twice, specs2 = tap_intermediate_outputs(
+            ep_once,
+            selector=select_by_op_type("aten.linear.default"),
+            reducer=FULL_TENSOR,
+        )
+        # Same number of linears matched; tap.Tensor itself is excluded.
+        self.assertEqual(len(specs2), len(specs1))
+
+    def test_no_match_returns_empty_specs(self):
+        ep = _export()
+        ep_t, specs = tap_intermediate_outputs(
+            ep,
+            selector=select_by_op_type("aten.does.not.exist"),
+            reducer=FULL_TENSOR,
+        )
+        self.assertEqual(specs, [])
+        # Original graph signature is unchanged.
+        self.assertEqual(
+            len(ep_t.graph_signature.output_specs),
+            len(ep.graph_signature.output_specs),
+        )
+
+    def test_skip_if_no_debug_handle(self):
+        ep = _export()
+        # Strip all debug handles to simulate a graph without them.
+        ep_clean = copy.deepcopy(ep)
+        for n in ep_clean.graph_module.graph.nodes:
+            n.meta.pop("debug_handle", None)
+        _, specs = tap_intermediate_outputs(
+            ep_clean,
+            selector=select_by_op_type("aten.linear.default"),
+            reducer=FULL_TENSOR,
+            skip_if_no_debug_handle=True,
+        )
+        self.assertEqual(specs, [])
diff --git a/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
new file mode 100644
index 00000000000..4967427a60a
--- /dev/null
+++ b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
@@ -0,0 +1,127 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+"""
+End-to-end test: prove that intermediate values surfaced as USER_OUTPUT taps
+flow through XNNPACK delegation and out of the runtime *with no
+XNNPACK-side support*. This is the central correctness claim of the design.
+""" + +import os +import sys +import tempfile +import unittest + +import torch +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, +) +from executorch.devtools.intermediate_output_tap import ( + ABS_MAX_ONLY, + DEFAULT_STATS, + FULL_TENSOR, + MIN_MAX_MEAN, + select_by_op_type, + strip_taps_, + tap_intermediate_outputs, +) +from executorch.exir import to_edge_transform_and_lower +from executorch.runtime import Runtime, Verification +from torch.export import export + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 4) + + def forward(self, x): + return self.l2(self.l1(x).relu()) + + +@unittest.skipIf(sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows") +class XnnpackEndToEndTest(unittest.TestCase): + def _run_pipeline(self, reducer): + model = _MLP() + example_inputs = (torch.randn(2, 8),) + + ep = export(model, example_inputs, strict=True) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=reducer, + ) + edge = to_edge_transform_and_lower( + ep_t, partitioner=[XnnpackPartitioner()] + ) + strip_taps_(edge) + et_program = edge.to_executorch() + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + flat_outputs = method.execute(list(example_inputs)) + + return specs, flat_outputs, model, example_inputs + + def test_full_tensor_taps_match_eager(self): + specs, flat, model, example_inputs = self._run_pipeline(FULL_TENSOR) + self.assertEqual(len(specs), 2) # two linears + + # The user output is at index 0; tap outputs follow. + for spec in specs: + tap_value = flat[spec.output_index] + self.assertIsInstance(tap_value, torch.Tensor) + # FULL_TENSOR preserves the source tensor's shape — so e.g. for + # the first linear, shape is (batch, l1.out_features). + self.assertGreater(tap_value.numel(), 0) + + def test_abs_max_only_returns_scalar(self): + specs, flat, _, _ = self._run_pipeline(ABS_MAX_ONLY) + self.assertEqual(len(specs), 2) + for spec in specs: + tap_value = flat[spec.output_index] + self.assertIsInstance(tap_value, torch.Tensor) + # 0-dim scalar + self.assertEqual(tap_value.numel(), 1) + self.assertGreaterEqual(float(tap_value), 0.0) + + def test_min_max_mean_e2e(self): + specs, flat, _, _ = self._run_pipeline(MIN_MAX_MEAN) + self.assertEqual(len(specs), 2) + for spec in specs: + tap_value = flat[spec.output_index] + self.assertEqual(tap_value.numel(), 3) + + def test_default_stats_returns_seven_floats(self): + specs, flat, _, _ = self._run_pipeline(DEFAULT_STATS) + self.assertEqual(len(specs), 2) + for spec in specs: + tap_value = flat[spec.output_index] + self.assertIsInstance(tap_value, torch.Tensor) + self.assertEqual(tap_value.numel(), 4) + mn, mx, _, abs_max = tap_value.tolist() + self.assertLessEqual(mn, mx) + self.assertGreaterEqual(abs_max, max(abs(mn), abs(mx)) - 1e-5) + + def test_user_outputs_still_correct(self): + """Tap outputs must not corrupt the original user outputs.""" + specs, flat, model, example_inputs = self._run_pipeline(FULL_TENSOR) + + eager_out = model(*example_inputs) + # User output is at index 0 (one user output for our MLP). 
+        user_out = flat[0]
+        torch.testing.assert_close(user_out, eager_out, atol=1e-3, rtol=1e-3)
+        # Verify tap indices are non-overlapping with user-output index 0.
+        for spec in specs:
+            self.assertGreaterEqual(spec.output_index, 1)

From 4ef4ff554dfa5426ec9ae40ce35052ff07de7244 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Tue, 5 May 2026 18:34:34 -0700
Subject: [PATCH 2/3] up

---
 devtools/intermediate_output_tap/__init__.py  |  2 +
 .../intermediate_output_tap/_convenience.py   | 80 ++++++++++++++++++-
 devtools/intermediate_output_tap/_reducers.py | 36 ++++++++-
 devtools/intermediate_output_tap/_spec.py     |  1 +
 devtools/intermediate_output_tap/_tap_pass.py | 12 +++
 5 files changed, 128 insertions(+), 3 deletions(-)

diff --git a/devtools/intermediate_output_tap/__init__.py b/devtools/intermediate_output_tap/__init__.py
index c1f85509409..03221ad4c76 100644
--- a/devtools/intermediate_output_tap/__init__.py
+++ b/devtools/intermediate_output_tap/__init__.py
@@ -37,6 +37,7 @@
     custom_ops_lib,  # noqa: F401 ensures torch.ops.executorch_devtools.tap is registered
 )
 from executorch.devtools.intermediate_output_tap._convenience import (
+    compare_aot_runtime_dataframe,
     format_tap_dataframe,
     specs_to_dataframe,
     tap_all_and_run,
@@ -77,6 +78,7 @@
     "tap_all_and_run",
     "specs_to_dataframe",
     "format_tap_dataframe",
+    "compare_aot_runtime_dataframe",
     # Reducers
     "StatReducer",
     "FULL_TENSOR",
diff --git a/devtools/intermediate_output_tap/_convenience.py b/devtools/intermediate_output_tap/_convenience.py
index 2d4481b9bb2..a2a884b9ec9 100644
--- a/devtools/intermediate_output_tap/_convenience.py
+++ b/devtools/intermediate_output_tap/_convenience.py
@@ -27,7 +27,10 @@
 import pandas as pd
 import torch
 
-from executorch.devtools.intermediate_output_tap._reducers import StatReducer
+from executorch.devtools.intermediate_output_tap._reducers import (
+    get_reducer,
+    StatReducer,
+)
 from executorch.devtools.intermediate_output_tap._selectors import (
     NodeSelector,
     select_all_call_function,
@@ -219,3 +222,78 @@ def _to_float_list(v: Any) -> list[float]:
                 out.append(float("nan"))
         return out
     return []
+
+
+def _flat_floats(v: Any) -> list[float]:
+    """Flatten a tap value (tensor / list / scalar) to a flat list of floats."""
+    if isinstance(v, torch.Tensor):
+        return [float(x) for x in v.detach().to(torch.float32).cpu().reshape(-1).tolist()]
+    if isinstance(v, (list, tuple)):
+        out: list[float] = []
+        for x in v:
+            out.extend(_flat_floats(x))
+        return out
+    try:
+        return [float(v)]
+    except (TypeError, ValueError):
+        return []
+
+
+def compare_aot_runtime_dataframe(
+    specs: Sequence[TapSpec],
+    aot_flat: Sequence[Any],
+    rt_flat: Sequence[Any],
+) -> pd.DataFrame:
+    """
+    Build a side-by-side AOT-vs-runtime DataFrame from the flat outputs of
+    the *tapped* ExportedProgram (eager) and the post-strip runtime program.
+
+    AOT side:
+        `aot_flat[spec.output_index]` is the **raw** tapped tensor — at eager
+        time `tap.Tensor` is identity, so the output is the source op's
+        output. We apply the reducer's `eager` callable to reproduce what
+        `strip_taps_` materialises in the runtime graph.
+
+    Runtime side:
+        `rt_flat[spec.output_index]` already contains the reduced 1-D tensor
+        (or original tensor for FULL_TENSOR).
+
+    Returns one row per spec with columns:
+        node_name, op_target, reducer_name, output_index,
+        aot_<field1>, rt_<field1>, aot_<field2>, rt_<field2>, ...
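+
+    For example (illustrative), a DEFAULT_STATS tap contributes the value
+    columns aot_min, rt_min, aot_max, rt_max, aot_mean, rt_mean,
+    aot_abs_max, rt_abs_max.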
+ """ + rows: list[dict[str, Any]] = [] + for spec in specs: + aot_raw = aot_flat[spec.output_index] + rt_raw = rt_flat[spec.output_index] + + # AOT raw might be wrapped in a 1-tuple; unwrap first tensor. + if not isinstance(aot_raw, torch.Tensor) and isinstance( + aot_raw, (list, tuple) + ) and aot_raw: + aot_raw = aot_raw[0] + + reducer = get_reducer(spec.reducer_name) + if isinstance(aot_raw, torch.Tensor): + aot_reduced = reducer.eager(aot_raw) + else: + aot_reduced = aot_raw + + aot_vals = _flat_floats(aot_reduced) + rt_vals = _flat_floats(rt_raw) + + fields = list(spec.fields) if spec.fields else [ + f"v{i}" for i in range(max(len(aot_vals), len(rt_vals))) + ] + row: dict[str, Any] = { + "node_name": spec.node_name, + "module_path": spec.module_path, + "op_target": spec.op_target, + "reducer_name": spec.reducer_name, + "output_index": spec.output_index, + } + for i, f in enumerate(fields): + row[f"aot_{f}"] = aot_vals[i] if i < len(aot_vals) else float("nan") + row[f"rt_{f}"] = rt_vals[i] if i < len(rt_vals) else float("nan") + rows.append(row) + return pd.DataFrame(rows) diff --git a/devtools/intermediate_output_tap/_reducers.py b/devtools/intermediate_output_tap/_reducers.py index 45a4f62578e..b2d19c00c66 100644 --- a/devtools/intermediate_output_tap/_reducers.py +++ b/devtools/intermediate_output_tap/_reducers.py @@ -42,13 +42,18 @@ # --- Reducer dataclass --------------------------------------------------- EmitFn = Callable[["fx.Graph", "fx.Node"], "fx.Node"] +EagerFn = Callable[[torch.Tensor], torch.Tensor] @dataclass(frozen=True) class StatReducer: """ - A reducer specification. `emit` is invoked by `strip_taps_` to materialise - the reducer subgraph in the post-lowering graph. + A reducer specification. + + `emit` is invoked by `strip_taps_` to materialise the reducer subgraph + in the post-lowering FX graph. `eager` is the equivalent pure-torch + implementation, used by callers that want to reproduce what the runtime + will compute (e.g. AOT-vs-runtime comparisons without a debugger). `name` is what the user types and what's stored on each TapSpec. `fields` enumerates the columns of the 1-D output tensor (empty for @@ -58,6 +63,7 @@ class StatReducer: name: str fields: tuple[str, ...] 
emit: EmitFn + eager: EagerFn # --- Helpers ------------------------------------------------------------- @@ -96,10 +102,15 @@ def _emit_full_tensor(_graph: "fx.Graph", src: "fx.Node") -> "fx.Node": return src +def _eager_full_tensor(t: torch.Tensor) -> torch.Tensor: + return t.detach() + + FULL_TENSOR: StatReducer = StatReducer( name="FULL_TENSOR", fields=(), emit=_emit_full_tensor, + eager=_eager_full_tensor, ) @@ -109,10 +120,16 @@ def _emit_abs_max(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": return _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_x) +def _eager_abs_max(t: torch.Tensor) -> torch.Tensor: + f = t.detach().to(torch.float32) + return f.abs().amax() + + ABS_MAX_ONLY: StatReducer = StatReducer( name="ABS_MAX_ONLY", fields=("abs_max",), emit=_emit_abs_max, + eager=_eager_abs_max, ) @@ -124,10 +141,16 @@ def _emit_min_max_mean(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": return _stack(graph, [mn, mx, me]) +def _eager_min_max_mean(t: torch.Tensor) -> torch.Tensor: + f = t.detach().to(torch.float32) + return torch.stack([f.amin(), f.amax(), f.mean()], dim=0) + + MIN_MAX_MEAN: StatReducer = StatReducer( name="MIN_MAX_MEAN", fields=("min", "max", "mean"), emit=_emit_min_max_mean, + eager=_eager_min_max_mean, ) @@ -152,10 +175,19 @@ def _emit_default_stats(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": return _stack(graph, [mn, mx, me, abs_max]) +def _eager_default_stats(t: torch.Tensor) -> torch.Tensor: + f = t.detach().to(torch.float32) + return torch.stack( + [f.amin(), f.amax(), f.mean(), f.abs().amax()], + dim=0, + ) + + DEFAULT_STATS: StatReducer = StatReducer( name="DEFAULT_STATS", fields=("min", "max", "mean", "abs_max"), emit=_emit_default_stats, + eager=_eager_default_stats, ) diff --git a/devtools/intermediate_output_tap/_spec.py b/devtools/intermediate_output_tap/_spec.py index 71b3c7f526e..c043a142913 100644 --- a/devtools/intermediate_output_tap/_spec.py +++ b/devtools/intermediate_output_tap/_spec.py @@ -58,3 +58,4 @@ class TapSpec: fields: tuple[str, ...] stack_trace: str | None = None reducer_node_name: str | None = None + module_path: str | None = None diff --git a/devtools/intermediate_output_tap/_tap_pass.py b/devtools/intermediate_output_tap/_tap_pass.py index bb7b96661d3..aa4a30bf6c0 100644 --- a/devtools/intermediate_output_tap/_tap_pass.py +++ b/devtools/intermediate_output_tap/_tap_pass.py @@ -141,6 +141,17 @@ def tap_intermediate_outputs( tap_node.meta["source_node"] = node.name new_tap_nodes.append(tap_node) + # Leaf module FQN from nn_module_stack (e.g., "layers.0.attention.wqs.0"). 
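+    # nn_module_stack values are (fqn, module_class) pairs in recent torch
+    # releases, but older exports may carry bare strings — hence the tuple
+    # check below.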
+ module_path: str | None = None + stack = node.meta.get("nn_module_stack") + if stack: + try: + last_entry = list(stack.values())[-1] + module_path = ( + last_entry[0] if isinstance(last_entry, tuple) else str(last_entry) + ) + except Exception: + module_path = None specs.append( TapSpec( node_name=node.name, @@ -150,6 +161,7 @@ def tap_intermediate_outputs( reducer_name=reducer_obj.name, fields=reducer_obj.fields, stack_trace=node.meta.get("stack_trace"), + module_path=module_path, ) ) From 7ad210a3055725449885b29cc6f2df671f4fd378 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 6 May 2026 12:12:01 -0700 Subject: [PATCH 3/3] up --- devtools/intermediate_output_tap/__init__.py | 17 +- .../intermediate_output_tap/_convenience.py | 117 ++++---- devtools/intermediate_output_tap/_reducers.py | 242 ++++++++++++----- devtools/intermediate_output_tap/_spec.py | 2 +- .../intermediate_output_tap/_strip_pass.py | 6 +- devtools/intermediate_output_tap/_tap_pass.py | 29 +- .../intermediate_output_tap/custom_ops_lib.py | 53 ++-- .../tests/test_coreml_e2e.py | 252 ++++++++++++++++++ .../tests/test_inspector_integration.py | 25 +- .../tests/test_reducers.py | 112 ++++++-- .../tests/test_selectors.py | 4 +- .../tests/test_strip_pass.py | 4 +- .../tests/test_tap_pass.py | 12 +- .../tests/test_xnnpack_e2e.py | 61 ++--- 14 files changed, 701 insertions(+), 235 deletions(-) create mode 100644 devtools/intermediate_output_tap/tests/test_coreml_e2e.py diff --git a/devtools/intermediate_output_tap/__init__.py b/devtools/intermediate_output_tap/__init__.py index 03221ad4c76..870e43f1375 100644 --- a/devtools/intermediate_output_tap/__init__.py +++ b/devtools/intermediate_output_tap/__init__.py @@ -20,11 +20,11 @@ Typical usage: from executorch.devtools.intermediate_output_tap import ( - tap_intermediate_outputs, strip_taps_, DEFAULT_STATS, + tap_intermediate_outputs, strip_taps_, STATS, ) ep = export(model, example_inputs) - ep_tapped, specs = tap_intermediate_outputs(ep, reducer=DEFAULT_STATS) + ep_tapped, specs = tap_intermediate_outputs(ep, reducer=STATS) edge = to_edge_transform_and_lower(ep_tapped, partitioner=[XnnpackPartitioner()]) strip_taps_(edge) et_program = edge.to_executorch() @@ -33,9 +33,8 @@ df = inspector.calculate_numeric_gap_from_taps(flat_outputs, specs) """ -from executorch.devtools.intermediate_output_tap import ( - custom_ops_lib, # noqa: F401 ensures torch.ops.executorch_devtools.tap is registered -) +# Importing this module registers torch.ops.executorch_devtools.tap.Tensor. 
+from executorch.devtools.intermediate_output_tap import custom_ops_lib # noqa: F401 from executorch.devtools.intermediate_output_tap._convenience import ( compare_aot_runtime_dataframe, format_tap_dataframe, @@ -43,12 +42,10 @@ tap_all_and_run, ) from executorch.devtools.intermediate_output_tap._reducers import ( - ABS_MAX_ONLY, - DEFAULT_STATS, FULL_TENSOR, get_reducer, - MIN_MAX_MEAN, StatReducer, + STATS, ) from executorch.devtools.intermediate_output_tap._selectors import ( NodeSelector, @@ -82,9 +79,7 @@ # Reducers "StatReducer", "FULL_TENSOR", - "ABS_MAX_ONLY", - "MIN_MAX_MEAN", - "DEFAULT_STATS", + "STATS", "get_reducer", # Selectors "NodeSelector", diff --git a/devtools/intermediate_output_tap/_convenience.py b/devtools/intermediate_output_tap/_convenience.py index a2a884b9ec9..c95dacc8add 100644 --- a/devtools/intermediate_output_tap/_convenience.py +++ b/devtools/intermediate_output_tap/_convenience.py @@ -7,15 +7,19 @@ # pyre-unsafe """ -One-line convenience wrapper for the most common smoke-test workflow: - - df = tap_all_and_run(model, example_inputs, partitioner=[XnnpackPartitioner()]) - -Exports `model`, taps every call_function, lowers with the user's partitioner, -runs through the ExecuTorch runtime, and returns a pandas DataFrame of one row -per tap (one column per stat field). No Inspector setup, no ETRecord. For -AOT-vs-runtime numerical comparison, use Inspector.calculate_numeric_gap_from_taps, -then `format_tap_dataframe(df, tap_specs)` to get a friendly view. +Convenience helpers built on top of `tap_intermediate_outputs` / `strip_taps_`: + +* `tap_all_and_run`: one-shot smoke-test wrapper that exports a model, taps + every call_function, lowers with the user's partitioner, runs through the + ExecuTorch runtime, and returns a per-tap DataFrame. +* `specs_to_dataframe`: build a per-tap DataFrame from a tap_specs list and + the runtime's flat output tuple. +* `compare_aot_runtime_dataframe`: side-by-side AOT-vs-runtime DataFrame from + the flat outputs of the *tapped* ExportedProgram (eager) and the post-strip + runtime program. The simpler alternative to `Inspector.calculate_numeric_gap_from_taps` + when you don't need ETDump/ETRecord plumbing. +* `format_tap_dataframe`: reshape the raw DataFrame returned by + `Inspector.calculate_numeric_gap_from_taps` into a friendlier per-tap view. """ from __future__ import annotations @@ -27,10 +31,7 @@ import pandas as pd import torch -from executorch.devtools.intermediate_output_tap._reducers import ( - get_reducer, - StatReducer, -) +from executorch.devtools.intermediate_output_tap._reducers import StatReducer from executorch.devtools.intermediate_output_tap._selectors import ( NodeSelector, select_all_call_function, @@ -46,7 +47,7 @@ def tap_all_and_run( model: torch.nn.Module, example_inputs: tuple[Any, ...], partitioner: list | None = None, - reducer: str | StatReducer = "DEFAULT_STATS", + reducer: str | StatReducer = "STATS", selector: NodeSelector | None = None, skip_if_no_debug_handle: bool = True, ) -> pd.DataFrame: @@ -67,9 +68,7 @@ def tap_all_and_run( reducer=reducer, skip_if_no_debug_handle=skip_if_no_debug_handle, ) - edge = to_edge_transform_and_lower( - ep_tapped, partitioner=partitioner or [] - ) + edge = to_edge_transform_and_lower(ep_tapped, partitioner=partitioner or []) strip_taps_(edge) et_program = edge.to_executorch() @@ -147,9 +146,7 @@ def format_tap_dataframe( """ # Map reducer_node_name -> spec for quick lookup. 
     name_to_spec: dict[str, TapSpec] = {
-        s.reducer_node_name: s
-        for s in tap_specs
-        if s.reducer_node_name is not None
+        s.reducer_node_name: s for s in tap_specs if s.reducer_node_name is not None
     }
 
     rows = []
@@ -227,7 +224,9 @@ def _to_float_list(v: Any) -> list[float]:
 def _flat_floats(v: Any) -> list[float]:
     """Flatten a tap value (tensor / list / scalar) to a flat list of floats."""
     if isinstance(v, torch.Tensor):
-        return [float(x) for x in v.detach().to(torch.float32).cpu().reshape(-1).tolist()]
+        return [
+            float(x) for x in v.detach().to(torch.float32).cpu().reshape(-1).tolist()
+        ]
     if isinstance(v, (list, tuple)):
         out: list[float] = []
         for x in v:
@@ -239,6 +238,22 @@ def _flat_floats(v: Any) -> list[float]:
         return []
 
 
+def _sqnr_db(aot_vals: list[float], rt_vals: list[float]) -> float:
+    """Signal-to-quantization-noise ratio in dB. Higher is better.
+
+    Thin wrapper around `torch.ao.ns.fx.utils.compute_sqnr` (the canonical
+    implementation already used by `backends/test/harness/error_statistics.py`).
+    """
+    from torch.ao.ns.fx.utils import compute_sqnr
+
+    n = min(len(aot_vals), len(rt_vals))
+    if n == 0:
+        return float("nan")
+    aot_t = torch.tensor(aot_vals[:n], dtype=torch.float32)
+    rt_t = torch.tensor(rt_vals[:n], dtype=torch.float32)
+    return float(compute_sqnr(rt_t, aot_t))
+
+
 def compare_aot_runtime_dataframe(
     specs: Sequence[TapSpec],
     aot_flat: Sequence[Any],
@@ -248,43 +263,21 @@
     Build a side-by-side AOT-vs-runtime DataFrame from the flat outputs of
     the *tapped* ExportedProgram (eager) and the post-strip runtime program.
 
-    AOT side:
-        `aot_flat[spec.output_index]` is the **raw** tapped tensor — at eager
-        time `tap.Tensor` is identity, so the output is the source op's
-        output. We apply the reducer's `eager` callable to reproduce what
-        `strip_taps_` materialises in the runtime graph.
+    Both `aot_flat[spec.output_index]` and `rt_flat[spec.output_index]` already
+    contain the *reduced* tap value, since `tap.Tensor`'s eager impl applies
+    the named reducer (see `custom_ops_lib.py`).
 
-    Runtime side:
-        `rt_flat[spec.output_index]` already contains the reduced 1-D tensor
-        (or original tensor for FULL_TENSOR).
-
-    Returns one row per spec with columns:
-        node_name, op_target, reducer_name, output_index,
-        aot_<field1>, rt_<field1>, aot_<field2>, rt_<field2>, ...
+    Output columns per spec:
+      - For non-FULL_TENSOR reducers: one `aot_<field>` and `rt_<field>` per
+        reducer field (e.g. `aot_min`, `rt_min`, ...).
+      - For FULL_TENSOR: `sqnr_db` (signal-to-noise of aot vs rt over the
+        whole tensor, in dB).
     """
     rows: list[dict[str, Any]] = []
     for spec in specs:
-        aot_raw = aot_flat[spec.output_index]
-        rt_raw = rt_flat[spec.output_index]
-
-        # AOT raw might be wrapped in a 1-tuple; unwrap first tensor.
- if not isinstance(aot_raw, torch.Tensor) and isinstance( - aot_raw, (list, tuple) - ) and aot_raw: - aot_raw = aot_raw[0] - - reducer = get_reducer(spec.reducer_name) - if isinstance(aot_raw, torch.Tensor): - aot_reduced = reducer.eager(aot_raw) - else: - aot_reduced = aot_raw + aot_vals = _flat_floats(aot_flat[spec.output_index]) + rt_vals = _flat_floats(rt_flat[spec.output_index]) - aot_vals = _flat_floats(aot_reduced) - rt_vals = _flat_floats(rt_raw) - - fields = list(spec.fields) if spec.fields else [ - f"v{i}" for i in range(max(len(aot_vals), len(rt_vals))) - ] row: dict[str, Any] = { "node_name": spec.node_name, "module_path": spec.module_path, @@ -292,8 +285,20 @@ def compare_aot_runtime_dataframe( "reducer_name": spec.reducer_name, "output_index": spec.output_index, } - for i, f in enumerate(fields): - row[f"aot_{f}"] = aot_vals[i] if i < len(aot_vals) else float("nan") - row[f"rt_{f}"] = rt_vals[i] if i < len(rt_vals) else float("nan") + + if spec.reducer_name == "FULL_TENSOR": + row["sqnr_db"] = _sqnr_db(aot_vals, rt_vals) + row["aot_numel"] = len(aot_vals) + row["rt_numel"] = len(rt_vals) + else: + fields = ( + list(spec.fields) + if spec.fields + else [f"v{i}" for i in range(max(len(aot_vals), len(rt_vals)))] + ) + for i, f in enumerate(fields): + row[f"aot_{f}"] = aot_vals[i] if i < len(aot_vals) else float("nan") + row[f"rt_{f}"] = rt_vals[i] if i < len(rt_vals) else float("nan") + rows.append(row) return pd.DataFrame(rows) diff --git a/devtools/intermediate_output_tap/_reducers.py b/devtools/intermediate_output_tap/_reducers.py index b2d19c00c66..481dfb2f55e 100644 --- a/devtools/intermediate_output_tap/_reducers.py +++ b/devtools/intermediate_output_tap/_reducers.py @@ -17,16 +17,26 @@ just before the placeholder, using the source tensor `src_node` as input, and returns the final node whose output replaces the placeholder's output. +`eager(tensor) -> tensor` is the pure-torch equivalent that callers can use +to reproduce, in eager mode, what the runtime will compute. `tap.Tensor`'s +own dispatch impl uses this to produce the *reduced* value at AOT time, so +that `ep.module()(*inputs)` returns the same flat outputs as the runtime. + The emit functions cast to fp32 first for cross-backend numerical stability -and use full-tensor reductions (no `dim=`) so the result is a stable shape -regardless of the source tensor's rank. +and produce a fixed-shape output (0-D or 1-D) regardless of the source +tensor's rank, so callers don't need to track per-tap shapes. + +We ship two built-ins: -For v1 we ship: FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS. -HISTOGRAM_64 is deferred (`aten.histc` has restricted edge support). +* `FULL_TENSOR` — identity. The whole source tensor is surfaced. +* `STATS` — a comprehensive bundle of debugging-friendly scalars: + min, max, mean, abs_max, abs_mean, std, rms, l1_norm, l2_norm, + nan_count, inf_count, zero_count, p99_abs. """ from __future__ import annotations +import operator from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING @@ -57,7 +67,7 @@ class StatReducer: `name` is what the user types and what's stored on each TapSpec. `fields` enumerates the columns of the 1-D output tensor (empty for - FULL_TENSOR, which preserves the original tensor shape). + FULL_TENSOR which preserves a tensor of values). 
""" name: str @@ -71,8 +81,6 @@ class StatReducer: def _cast_fp32(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": """Insert a fp32 cast (no-op semantically if already fp32).""" - # exir_ops.edge.dim_order_ops._to_dim_order_copy.default exists for edge dialect, - # but the simpler aten._to_copy variant is broadly supported. return graph.call_function( exir_ops.edge.aten._to_copy.default, args=(x,), @@ -94,11 +102,37 @@ def _stack(graph: "fx.Graph", scalars: list["fx.Node"]) -> "fx.Node": ) -# --- Built-in reducers --------------------------------------------------- +def _abs(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.abs.default, args=(x,)) + + +def _square(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.pow.Tensor_Scalar, args=(x, 2.0)) + + +def _sqrt(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.sqrt.default, args=(x,)) + + +def _full_sum(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + """Full-tensor sum via aten.sum.dim_IntList(dim=[]) — portable + has out variant.""" + return graph.call_function(exir_ops.edge.aten.sum.dim_IntList, args=(x, [])) + + +def _bool_to_fp32_count(graph: "fx.Graph", mask: "fx.Node") -> "fx.Node": + """Sum of a bool mask cast to fp32 → a 0-d fp32 count.""" + casted = graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(mask,), + kwargs={"dtype": torch.float32}, + ) + return _full_sum(graph, casted) + + +# --- FULL_TENSOR --------------------------------------------------------- def _emit_full_tensor(_graph: "fx.Graph", src: "fx.Node") -> "fx.Node": - """Identity — return the source node directly. strip_taps_ will splice.""" return src @@ -114,89 +148,163 @@ def _eager_full_tensor(t: torch.Tensor) -> torch.Tensor: ) -def _emit_abs_max(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": - f = _cast_fp32(graph, src) - abs_x = graph.call_function(exir_ops.edge.aten.abs.default, args=(f,)) - return _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_x) - - -def _eager_abs_max(t: torch.Tensor) -> torch.Tensor: - f = t.detach().to(torch.float32) - return f.abs().amax() +# --- STATS --------------------------------------------------------------- -ABS_MAX_ONLY: StatReducer = StatReducer( - name="ABS_MAX_ONLY", - fields=("abs_max",), - emit=_emit_abs_max, - eager=_eager_abs_max, +_STATS_FIELDS: tuple[str, ...] = ( + "min", + "max", + "mean", + "abs_max", + "abs_mean", + "std", + "rms", + "l1_norm", + "l2_norm", + "nan_count", + "inf_count", + "zero_count", + "p99_abs", ) -def _emit_min_max_mean(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": +def _emit_stats(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": f = _cast_fp32(graph, src) - mn = _scalar_node(graph, exir_ops.edge.aten.amin.default, f) - mx = _scalar_node(graph, exir_ops.edge.aten.amax.default, f) - me = _scalar_node(graph, exir_ops.edge.aten.mean.default, f) - return _stack(graph, [mn, mx, me]) - - -def _eager_min_max_mean(t: torch.Tensor) -> torch.Tensor: - f = t.detach().to(torch.float32) - return torch.stack([f.amin(), f.amax(), f.mean()], dim=0) - - -MIN_MAX_MEAN: StatReducer = StatReducer( - name="MIN_MAX_MEAN", - fields=("min", "max", "mean"), - emit=_emit_min_max_mean, - eager=_eager_min_max_mean, -) - - -def _emit_default_stats(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": - """ - Default stats: (min, max, mean, abs_max) — 4 floats. 
+ abs_f = _abs(graph, f) + sq_f = _square(graph, f) - NOTE: nan_count/inf_count/std are intentionally excluded because the - underlying portable kernels (`isnan`, `isinf`, `sum.dtype`, `std.*`) - don't all have out variants registered in ExecuTorch's default runtime - op table, which fails memory planning or runtime method-load. If you - need them, supply a custom StatReducer. - """ - f = _cast_fp32(graph, src) mn = _scalar_node(graph, exir_ops.edge.aten.amin.default, f) mx = _scalar_node(graph, exir_ops.edge.aten.amax.default, f) me = _scalar_node(graph, exir_ops.edge.aten.mean.default, f) + abs_max = _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_f) + abs_mean = _scalar_node(graph, exir_ops.edge.aten.mean.default, abs_f) + + sum_sq = _full_sum(graph, sq_f) + mean_sq = _scalar_node(graph, exir_ops.edge.aten.mean.default, sq_f) + rms = _sqrt(graph, mean_sq) - abs_x = graph.call_function(exir_ops.edge.aten.abs.default, args=(f,)) - abs_max = _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_x) + # std = sqrt( E[x^2] - E[x]^2 ); avoids aten.var which lacks an out variant. + me_sq_scalar = graph.call_function( + exir_ops.edge.aten.pow.Tensor_Scalar, args=(me, 2.0) + ) + var = graph.call_function( + exir_ops.edge.aten.sub.Tensor, args=(mean_sq, me_sq_scalar) + ) + # Variance can be slightly negative due to fp roundoff; clamp at 0 via abs. + var = graph.call_function(exir_ops.edge.aten.abs.default, args=(var,)) + std = _sqrt(graph, var) + + l1 = _full_sum(graph, abs_f) + l2 = _sqrt(graph, sum_sq) + + nan_mask = graph.call_function(exir_ops.edge.aten.isnan.default, args=(f,)) + nan_count = _bool_to_fp32_count(graph, nan_mask) + + inf_mask = graph.call_function(exir_ops.edge.aten.isinf.default, args=(f,)) + inf_count = _bool_to_fp32_count(graph, inf_mask) + + zero_mask = graph.call_function(exir_ops.edge.aten.eq.Scalar, args=(f, 0.0)) + zero_count = _bool_to_fp32_count(graph, zero_mask) + + # p99_abs: use topk on flattened |x| to get the k-th largest, where + # k = max(1, ceil(numel * 0.01)). Numel is read from the source's + # FakeTensor at graph-build time. + fake = src.meta.get("val") + numel = int(fake.numel()) if fake is not None else 1 + k = max(1, (numel + 99) // 100) # ceil(numel/100) + abs_flat = graph.call_function( + exir_ops.edge.aten.view_copy.default, args=(abs_f, [-1]) + ) + topk_out = graph.call_function( + exir_ops.edge.aten.topk.default, + args=(abs_flat, k), + kwargs={"dim": -1, "largest": True, "sorted": True}, + ) + topk_values = graph.call_function(operator.getitem, args=(topk_out, 0)) + p99_abs = graph.call_function( + exir_ops.edge.aten.select_copy.int, args=(topk_values, 0, k - 1) + ) - return _stack(graph, [mn, mx, me, abs_max]) + return _stack( + graph, + [ + mn, + mx, + me, + abs_max, + abs_mean, + std, + rms, + l1, + l2, + nan_count, + inf_count, + zero_count, + p99_abs, + ], + ) -def _eager_default_stats(t: torch.Tensor) -> torch.Tensor: +def _eager_stats(t: torch.Tensor) -> torch.Tensor: f = t.detach().to(torch.float32) + abs_f = f.abs() + sq = f.pow(2.0) + + # std via E[x^2] - E[x]^2 (population variance) to match the emit subgraph. 
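+    # Worked example: x = [1., 3.] → E[x^2] = 5, E[x]^2 = 4, so var = 1 and
+    # std = 1 (population std, matching torch.std(x, correction=0)).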
+ if f.numel() > 0: + var = (sq.mean() - f.mean().pow(2)).abs() + std = var.sqrt() + else: + std = torch.tensor(0.0) + + sum_sq = sq.sum() + rms = sq.mean().sqrt() + l1 = abs_f.sum() + l2 = sum_sq.sqrt() + + nan_count = torch.isnan(f).to(torch.float32).sum() + inf_count = torch.isinf(f).to(torch.float32).sum() + zero_count = (f == 0).to(torch.float32).sum() + + numel = f.numel() + k = max(1, (numel + 99) // 100) + if numel > 0: + topk_vals = torch.topk(abs_f.reshape(-1), k=k, largest=True, sorted=True).values + p99_abs = topk_vals[k - 1] + else: + p99_abs = torch.tensor(float("nan")) + return torch.stack( - [f.amin(), f.amax(), f.mean(), f.abs().amax()], + [ + f.amin(), + f.amax(), + f.mean(), + abs_f.amax(), + abs_f.mean(), + std, + rms, + l1, + l2, + nan_count, + inf_count, + zero_count, + p99_abs, + ], dim=0, ) -DEFAULT_STATS: StatReducer = StatReducer( - name="DEFAULT_STATS", - fields=("min", "max", "mean", "abs_max"), - emit=_emit_default_stats, - eager=_eager_default_stats, +STATS: StatReducer = StatReducer( + name="STATS", + fields=_STATS_FIELDS, + emit=_emit_stats, + eager=_eager_stats, ) # --- Registry ------------------------------------------------------------- -_BUILTIN_REDUCERS: dict[str, StatReducer] = { - r.name: r - for r in (FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS) -} +_BUILTIN_REDUCERS: dict[str, StatReducer] = {r.name: r for r in (FULL_TENSOR, STATS)} def get_reducer(name_or_reducer: str | StatReducer) -> StatReducer: diff --git a/devtools/intermediate_output_tap/_spec.py b/devtools/intermediate_output_tap/_spec.py index c043a142913..db035b7c17e 100644 --- a/devtools/intermediate_output_tap/_spec.py +++ b/devtools/intermediate_output_tap/_spec.py @@ -37,7 +37,7 @@ class TapSpec: tuple where this tap's value lands. Computed at AOT time and stable through `to_edge` / `to_backend` / `to_executorch` because we only ever *append* to the output node and `OutputSpec`. - reducer_name: Name of the StatReducer used (e.g. "DEFAULT_STATS"). + reducer_name: Name of the StatReducer used (e.g. "STATS"). fields: Names of the per-element fields in the reducer's output tensor (e.g. ("min", "max", "abs_max")). Empty tuple for FULL_TENSOR. stack_trace: `node.meta["stack_trace"]` of the source node if present, diff --git a/devtools/intermediate_output_tap/_strip_pass.py b/devtools/intermediate_output_tap/_strip_pass.py index 5047dba0cff..1b1d58040b4 100644 --- a/devtools/intermediate_output_tap/_strip_pass.py +++ b/devtools/intermediate_output_tap/_strip_pass.py @@ -9,7 +9,7 @@ """ Post-`to_backend` pass: replace each `executorch_devtools::tap.Tensor` node with either an identity edge (FULL_TENSOR) or a portable reducer subgraph -(DEFAULT_STATS, MIN_MAX_MEAN, ABS_MAX_ONLY). +(STATS, or any user-supplied StatReducer). Pattern stolen from `remove_graph_break_` in `executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. 
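
Defining your own reducer means supplying the `emit`/`eager` pair described
above. A minimal sketch of a single-field RMS reducer — the names `RMS_ONLY`,
`_emit_rms`, and `_eager_rms` are illustrative, `ep` is assumed to be an
already-exported program, and the edge-dialect op handles mirror the ones the
built-ins in `_reducers.py` already use:

    import torch
    from executorch.devtools.intermediate_output_tap import (
        StatReducer,
        tap_intermediate_outputs,
    )
    from executorch.exir.dialects._ops import ops as exir_ops

    def _emit_rms(graph, src):
        # Build sqrt(mean(x^2)) in the post-lowering FX graph.
        f = graph.call_function(
            exir_ops.edge.aten._to_copy.default,
            args=(src,),
            kwargs={"dtype": torch.float32},
        )
        sq = graph.call_function(
            exir_ops.edge.aten.pow.Tensor_Scalar, args=(f, 2.0)
        )
        me = graph.call_function(exir_ops.edge.aten.mean.default, args=(sq,))
        return graph.call_function(exir_ops.edge.aten.sqrt.default, args=(me,))

    def _eager_rms(t: torch.Tensor) -> torch.Tensor:
        # Pure-torch twin of _emit_rms, for AOT-side reference values.
        return t.detach().to(torch.float32).pow(2.0).mean().sqrt()

    RMS_ONLY = StatReducer(
        name="RMS_ONLY", fields=("rms",), emit=_emit_rms, eager=_eager_rms
    )
    ep_t, specs = tap_intermediate_outputs(ep, reducer=RMS_ONLY)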
@@ -138,9 +138,7 @@ def _strip_taps_in_graph_module(gm: fx.GraphModule) -> list[fx.Node | None]: replacement.meta["debug_handle"] = next_handle next_handle += 1 replacement.meta["is_tap"] = True - replacement.meta["source_node"] = ( - src.name if isinstance(src, fx.Node) else None - ) + replacement.meta["source_node"] = src.name if isinstance(src, fx.Node) else None # `tap` may have ended up in the data path during to_edge's re-trace # (because CompositeExplicitAutograd preserves the op as an identity diff --git a/devtools/intermediate_output_tap/_tap_pass.py b/devtools/intermediate_output_tap/_tap_pass.py index aa4a30bf6c0..f94c30d4c9a 100644 --- a/devtools/intermediate_output_tap/_tap_pass.py +++ b/devtools/intermediate_output_tap/_tap_pass.py @@ -21,15 +21,18 @@ from __future__ import annotations import copy +import warnings from collections.abc import Callable import torch import torch.fx as fx -from executorch.devtools.intermediate_output_tap import custom_ops_lib # noqa: F401 registers tap.Tensor +from executorch.devtools.intermediate_output_tap import ( # noqa: F401 registers tap.Tensor + custom_ops_lib, +) from executorch.devtools.intermediate_output_tap._reducers import ( - DEFAULT_STATS, get_reducer, StatReducer, + STATS, ) from executorch.devtools.intermediate_output_tap._selectors import ( NodeSelector, @@ -50,15 +53,16 @@ def _is_tap_node(n: fx.Node) -> bool: return n.op == "call_function" and n.target is _TAP_TARGET -def tap_intermediate_outputs( +def tap_intermediate_outputs( # noqa: C901 ep: ExportedProgram, selector: NodeSelector | None = None, - reducer: str | StatReducer = DEFAULT_STATS, + reducer: str | StatReducer = STATS, *, tap_name_prefix: str = "tap_", skip_if_no_debug_handle: bool = False, max_taps: int | None = None, inplace: bool = False, + error_on_empty: bool = True, ) -> tuple[ExportedProgram, list[TapSpec]]: """ Rewrite `ep` so each node matching `selector` has its output appended to @@ -75,9 +79,8 @@ def tap_intermediate_outputs( selector: A predicate over fx.Node. Defaults to `select_all_call_function()`. Tap nodes themselves are always excluded so re-running the pass is idempotent. - reducer: Either a built-in reducer name ("DEFAULT_STATS", - "MIN_MAX_MEAN", "ABS_MAX_ONLY", "FULL_TENSOR") or a custom - StatReducer instance. + reducer: Either a built-in reducer name ("STATS", "FULL_TENSOR") + or a custom StatReducer instance. tap_name_prefix: Prefix for the tap nodes' names. Helps when grepping the dumped graph. skip_if_no_debug_handle: If True, only tap nodes that already @@ -87,6 +90,10 @@ def tap_intermediate_outputs( max_taps: Optional cap on number of taps. Helps avoid OOM for very large models. inplace: If False (default), deep-copy `ep` before mutating. + error_on_empty: If True (default), raise `ValueError` when no nodes + match the selector. Set to False to only emit a `UserWarning` + and return `(ep, [])` — handy when iterating on selector + patterns. """ if selector is None: selector = select_all_call_function() @@ -166,6 +173,14 @@ def tap_intermediate_outputs( ) if not new_tap_nodes: + msg = ( + "tap_intermediate_outputs: selector matched 0 nodes. " + "Double-check your selector predicates, " + "or pass `error_on_empty=False` to suppress this error." + ) + if error_on_empty: + raise ValueError(msg) + warnings.warn(msg, UserWarning, stacklevel=2) return ep, [] # Splice new outputs into the graph (mirror weights_to_outputs_pass). 
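
The new `error_on_empty` knob replaces the old silent no-op when a selector
matches nothing. A quick sketch of both modes — `ep` is again assumed to be an
already-exported program, and the op name is deliberately bogus:

    import warnings

    from executorch.devtools.intermediate_output_tap import (
        select_by_op_type,
        tap_intermediate_outputs,
    )

    # Default: a selector that matches nothing raises.
    try:
        tap_intermediate_outputs(
            ep, selector=select_by_op_type("aten.does.not.exist")
        )
    except ValueError as e:
        print(e)  # "tap_intermediate_outputs: selector matched 0 nodes. ..."

    # Opt-out: warn instead, and get back an unmodified program plus an
    # empty spec list.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ep_same, specs = tap_intermediate_outputs(
            ep,
            selector=select_by_op_type("aten.does.not.exist"),
            error_on_empty=False,
        )
    assert specs == [] and len(caught) == 1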
diff --git a/devtools/intermediate_output_tap/custom_ops_lib.py b/devtools/intermediate_output_tap/custom_ops_lib.py index 93c4636ad2c..79436915c27 100644 --- a/devtools/intermediate_output_tap/custom_ops_lib.py +++ b/devtools/intermediate_output_tap/custom_ops_lib.py @@ -10,22 +10,41 @@ Custom op registration for the intermediate-output tap mechanism. The op `executorch_devtools::tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor` -is an identity op whose sole job is to be an unknown-to-every-partitioner FX node -that "uses" a tapped tensor `x`. Because `x` now has a user outside any partition, -every ExecuTorch partitioner must surface `x` as a partition output (this is the -canonical contract enforced in `executorch/exir/lowered_backend_module.py`). - -After `to_edge_transform_and_lower(...)` the tap.Tensor node still exists in the -parent graph; `strip_taps_` (see `_strip_pass.py`) replaces it with either an -identity edge (FULL_TENSOR) or a small reducer subgraph of portable aten ops. - -The dispatch key MUST be `CompositeExplicitAutograd` (not `CompositeImplicitAutograd`) -so the op survives tracing/decomposition; otherwise it would inline at export time -and disappear before partitioning. This mirrors the pattern in +is a placeholder whose sole job is to be an unknown-to-every-partitioner FX +node that "uses" a tapped tensor `x`. Because `x` now has a user outside any +partition, every ExecuTorch partitioner must surface `x` as a partition output +(this is the canonical contract enforced in +`executorch/exir/lowered_backend_module.py`). + +After `to_edge_transform_and_lower(...)` the tap.Tensor node still exists in +the parent graph; `strip_taps_` (see `_strip_pass.py`) replaces it with either +an identity edge (FULL_TENSOR) or a small reducer subgraph of portable aten +ops. + +The op's eager impl computes the named reducer's eager equivalent (e.g. +`min/max/mean/abs_max/...` for STATS). Two reasons for this: + +1. **Re-trace safety.** `to_edge_transform_and_lower` re-traces the graph. + If `tap.Tensor` simply returned `x` literally, the re-traced FX graph + would treat the tap output as identical to `x` and re-route downstream + consumers (which would otherwise be reading `x`) through the tap node, + pulling it into a delegate's input list. Returning a *different* tensor + (different shape for non-FULL_TENSOR; `x.detach()` for FULL_TENSOR) + keeps consumers wired to `x` directly so the tap stays a host-only stub + with the FX `output` node as its sole consumer. +2. **AOT/runtime parity.** Calling `ep_t.module()(*inputs)` then returns the + same reduced values the runtime emits post-strip, removing the need for + callers to reapply the reducer themselves. + +The dispatch key MUST be `CompositeExplicitAutograd` (not +`CompositeImplicitAutograd`) so the op survives tracing/decomposition; +otherwise it would inline at export time and disappear before partitioning. +This mirrors the pattern in `executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. -`reducer_name` and `debug_handle` are stored as op arguments (not just node.meta) -so they survive any meta-stripping pass between `to_edge` and `strip_taps_`. +`reducer_name` and `debug_handle` are stored as op arguments (not just +node.meta) so they survive any meta-stripping pass between `to_edge` and +`strip_taps_`. 
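+
+For example (illustrative; assumes `STATS` is imported from the package),
+once this module is imported:
+
+    x = torch.randn(4, 4)
+    y = torch.ops.executorch_devtools.tap(x, "STATS", 0)
+    assert torch.equal(y, STATS.eager(x))  # the 13-element stats vector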
""" from __future__ import annotations @@ -40,4 +59,8 @@ @impl(lib, "tap.Tensor", "CompositeExplicitAutograd") def tap_tensor_impl(x, reducer_name, debug_handle): # noqa: ARG001 - return x + # Defer the import to break a module-import cycle (`_reducers` → torch → + # custom_ops_lib registration). + from executorch.devtools.intermediate_output_tap._reducers import get_reducer + + return get_reducer(reducer_name).eager(x) diff --git a/devtools/intermediate_output_tap/tests/test_coreml_e2e.py b/devtools/intermediate_output_tap/tests/test_coreml_e2e.py new file mode 100644 index 00000000000..36daeacef07 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_coreml_e2e.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Numerical-debugging tutorial for the ExecuTorch CoreML backend. + +This script walks through end-to-end use of the intermediate-output tap +infrastructure on the static-attention Llama from +`examples/apple/coreml/llama/`. It is a smoke test (random weights, tiny +ModelArgs — no checkpoint download required) and produces two tables: + +1. A delegation summary showing how many subgraphs ExecuTorch handed off + to CoreML and which operators ran on which side. + +2. An AOT-vs-runtime comparison of the tapped intermediate values, so you + can see numerical drift between eager-PyTorch and the CoreML runtime + at hand-picked points in the model. + +Run with: + python swift_play/test_inspector_coreml.py +""" + +import os +import tempfile + +import coremltools as ct +import pandas as pd +import torch +import torch.utils._pytree as pytree +from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.devtools.backend_debug import get_delegation_info +from executorch.devtools.intermediate_output_tap import ( + compare_aot_runtime_dataframe, + FULL_TENSOR, + select_all, + select_any, + select_by_module_path, + select_by_op_type, + STATS, + strip_taps_, + tap_intermediate_outputs, +) +import types + +from executorch.examples.apple.coreml.llama.export_static_llm_coreml import ( + _create_example_inputs, + _transform_eager_model, + remove_graph_break_, +) +from executorch.examples.models.llama.llama_transformer import construct_transformer +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.exir import to_edge_transform_and_lower +from executorch.runtime import Runtime, Verification +from torch.export import export + + +def _build_model() -> tuple[torch.nn.Module, ModelArgs]: + """Build a tiny static-attention Llama with random weights.""" + args = ModelArgs( + dim=64, + n_layers=2, + n_heads=4, + n_kv_heads=2, + vocab_size=128, + hidden_dim=128, + max_seq_len=64, + max_context_len=64, + generate_full_logits=True, + ) + args.attention_type = "static_mha" + args.attention_kwargs = {"decompose_sdpa_in_mha": True} + + model = construct_transformer(args) + transform_args = types.SimpleNamespace( + target_split_size=None, + max_splits=8, + embedding_quantize="", + linear_quantize="c4w", + no_graph_breaks=False, + ) + model = _transform_eager_model(model, transform_args, torch.float16) + return model, args + + +def main() -> None: + # ------------------------------------------------------------------ + # Step 1: Build and quantize the model. 
+ # ------------------------------------------------------------------ + model, model_args = _build_model() + + # ------------------------------------------------------------------ + # Step 2: Create example inputs and export. + # ------------------------------------------------------------------ + input_len = 8 + example_inputs, cache_len = _create_example_inputs( + model_args, + input_len=input_len, + max_context_len=model_args.max_context_len, + float_dtype=torch.float16, + ) + print(f"input_len={input_len} cache_len={cache_len}") + + with torch.no_grad(): + _ = model(*example_inputs) # eager sanity check + + print("Exporting...") + ep = export(model, example_inputs) + + # ------------------------------------------------------------------ + # Step 3: Pick which intermediate values to tap. + # + # We use two reducers in two passes: + # + # * FULL_TENSOR for `layers.1.attention.wvs.0` — surfaces the raw + # activation tensor; the comparison DataFrame computes SQNR over + # all elements. + # + # * STATS for everything else (`output`, all wqs/wks linears, layer 0's + # wvs, and layer 1's RMSNorm output mul) — gives a rich set of + # debugging scalars (min/max/mean/std/rms/l1/l2/abs_max/abs_mean/ + # nan_count/inf_count/zero_count/p99_abs). + # ------------------------------------------------------------------ + # Patterns use `*` between `layers..` and the inner module so they match + # both the bare path (`layers..attention...`) and the wrapped path + # (`layers..block.attention...`) that BlockWithGraphBreak introduces + # at the partition boundaries. + selector_full_tensor = select_any( + # Token-embedding output (one big tensor, before any transformer block). + select_by_op_type("aten.embedding.default"), + # First wvs linear in layer 1 — captures full activation post-Q/K/V. + select_all( + select_by_op_type("aten.linear.default"), + select_by_module_path("layers.1.*attention.wvs.*"), + ), + ) + selector_stats = select_any( + select_all( + select_by_op_type("aten.linear.default"), + select_any( + select_by_module_path("output"), + select_by_module_path("*.attention.wqs.*"), + select_by_module_path("*.attention.wks.*"), + select_by_module_path("layers.0.*attention.wvs.*"), + ), + ), + select_all( + select_by_op_type("aten.mul.Tensor"), + select_any( + select_by_module_path("layers.1.*attention_norm"), + select_by_module_path("layers.1.*attention_norm.*"), + select_by_module_path("layers.1.*ffn_norm"), + select_by_module_path("layers.1.*ffn_norm.*"), + ), + ), + ) + + ep_t, specs_full = tap_intermediate_outputs( + ep, selector=selector_full_tensor, reducer=FULL_TENSOR + ) + ep_t, specs_stats = tap_intermediate_outputs( + ep_t, selector=selector_stats, reducer=STATS + ) + specs = list(specs_full) + list(specs_stats) + print( + f"Inserted {len(specs)} tap(s) " + f"({len(specs_full)} FULL_TENSOR + {len(specs_stats)} STATS)." + ) + + # ------------------------------------------------------------------ + # Step 4: Capture the AOT-side reference values. + # + # `tap.Tensor`'s eager impl applies the reducer, so the flat outputs of + # the tapped EP already contain reduced values at the same positions + # the runtime will use. We pytree-flatten because the static-llama + # forward returns nested (logits, (k_caches, v_caches)). 
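+    #
+    # Illustration: (logits, (k_caches, v_caches)) plus the appended taps
+    # flattens to [logits, *k_caches, *v_caches, tap_0, tap_1, ...] — the
+    # same positions `spec.output_index` points at on the runtime side.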
+ # ------------------------------------------------------------------ + aot_out = ep_t.module()(*example_inputs) + aot_flat, _ = pytree.tree_flatten(aot_out) + + # ------------------------------------------------------------------ + # Step 5: Lower to CoreML, strip the taps, and show what got delegated. + # ------------------------------------------------------------------ + coreml_partitioner = CoreMLPartitioner( + compile_specs=CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT16, + compute_unit=ct.ComputeUnit.CPU_AND_NE, + ), + ) + print("Lowering to CoreML...") + edge = to_edge_transform_and_lower(ep_t, partitioner=[coreml_partitioner]) + # Drop the `executorch_utils::graph_break.Tensor` placeholders that + # `_transform_eager_model` inserted to force partition boundaries — they + # have no out-variant kernel, so they must not survive into the runtime + # program. + remove_graph_break_(edge) + specs = strip_taps_(edge, tap_specs=specs) + + delegation_info = get_delegation_info(edge.exported_program().graph_module) + print( + f"\n=== Delegation summary " + f"(num_delegated_subgraphs={delegation_info.num_delegated_subgraphs}) ===" + ) + print(delegation_info.get_summary()) + with pd.option_context( + "display.max_columns", None, + "display.width", 240, + "display.max_colwidth", 60, + ): + print( + delegation_info.get_operator_delegation_dataframe().to_string(index=False) + ) + + # ------------------------------------------------------------------ + # Step 6: Save the .pte and run it through the ExecuTorch runtime. + # ------------------------------------------------------------------ + et_program = edge.to_executorch() + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + print(f"\nSaved PTE: {pte_path} ({os.path.getsize(pte_path)} bytes)") + + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + # Runtime takes a flat tensor list — flatten the (tokens, options_dict) + # pytree the same way torch.export did. + flat_inputs, _ = pytree.tree_flatten(example_inputs) + rt_flat = list(method.execute(flat_inputs)) + + # ------------------------------------------------------------------ + # Step 7: Compare AOT vs runtime. 
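+    #
+    # How to read the table: FULL_TENSOR taps are compared element-wise
+    # (SQNR over all elements, per the Step 3 notes), while STATS taps
+    # compare each summary scalar directly. A drifting abs_max or a
+    # nonzero nan_count in a STATS row is usually the fastest pointer to
+    # the first misbehaving module.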
+ # ------------------------------------------------------------------ + df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat) + with pd.option_context( + "display.max_columns", None, + "display.width", 280, + "display.max_colwidth", 30, + "display.float_format", "{:.4g}".format, + ): + print(f"\n{len(specs)} tap(s) — AOT vs CoreML runtime:") + print(df.to_string(index=False)) + + +if __name__ == "__main__": + main() diff --git a/devtools/intermediate_output_tap/tests/test_inspector_integration.py b/devtools/intermediate_output_tap/tests/test_inspector_integration.py index b48dcfe4e0a..4247cd34360 100644 --- a/devtools/intermediate_output_tap/tests/test_inspector_integration.py +++ b/devtools/intermediate_output_tap/tests/test_inspector_integration.py @@ -29,14 +29,12 @@ import unittest import torch -from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackPartitioner, -) +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.devtools import generate_etrecord, Inspector from executorch.devtools.intermediate_output_tap import ( - DEFAULT_STATS, format_tap_dataframe, select_by_op_type, + STATS, strip_taps_, tap_intermediate_outputs, ) @@ -55,7 +53,9 @@ def forward(self, x): return self.l2(self.l1(x).relu()) -@unittest.skipIf(sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows") +@unittest.skipIf( + sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows" +) class InspectorIntegrationTest(unittest.TestCase): def test_calculate_numeric_gap_from_taps(self): model = _MLP() @@ -65,7 +65,7 @@ def test_calculate_numeric_gap_from_taps(self): ep_t, specs = tap_intermediate_outputs( ep, selector=select_by_op_type("aten.linear.default"), - reducer=DEFAULT_STATS, + reducer=STATS, ) # Do NOT pass generate_etrecord=True — we'd snapshot the EP while it # still has tap.Tensor nodes (unserializable). @@ -125,11 +125,16 @@ def test_calculate_numeric_gap_from_taps(self): # Print friendly per-tap view to stdout (visible via --print-passing-details). 
friendly = format_tap_dataframe(df, specs) import pandas as _pd + with _pd.option_context( - "display.max_columns", None, - "display.width", 240, - "display.max_colwidth", 30, - "display.float_format", "{:.4g}".format, + "display.max_columns", + None, + "display.width", + 240, + "display.max_colwidth", + 30, + "display.float_format", + "{:.4g}".format, ): print("\n=== Inspector.calculate_numeric_gap_from_taps (friendly) ===") print(friendly.to_string()) diff --git a/devtools/intermediate_output_tap/tests/test_reducers.py b/devtools/intermediate_output_tap/tests/test_reducers.py index e36762be5ac..9970da265d9 100644 --- a/devtools/intermediate_output_tap/tests/test_reducers.py +++ b/devtools/intermediate_output_tap/tests/test_reducers.py @@ -8,25 +8,27 @@ import unittest +import torch from executorch.devtools.intermediate_output_tap._reducers import ( - ABS_MAX_ONLY, - DEFAULT_STATS, FULL_TENSOR, get_reducer, - MIN_MAX_MEAN, StatReducer, + STATS, ) class ReducersTest(unittest.TestCase): def test_get_reducer_by_name(self): - self.assertIs(get_reducer("DEFAULT_STATS"), DEFAULT_STATS) self.assertIs(get_reducer("FULL_TENSOR"), FULL_TENSOR) - self.assertIs(get_reducer("MIN_MAX_MEAN"), MIN_MAX_MEAN) - self.assertIs(get_reducer("ABS_MAX_ONLY"), ABS_MAX_ONLY) + self.assertIs(get_reducer("STATS"), STATS) def test_get_reducer_passthrough(self): - custom = StatReducer(name="X", fields=("a",), emit=lambda g, n: n) + custom = StatReducer( + name="X", + fields=("a",), + emit=lambda g, n: n, + eager=lambda t: t, + ) self.assertIs(get_reducer(custom), custom) def test_get_reducer_unknown_raises(self): @@ -35,17 +37,95 @@ def test_get_reducer_unknown_raises(self): def test_reducer_field_counts(self): self.assertEqual(FULL_TENSOR.fields, ()) - self.assertEqual(ABS_MAX_ONLY.fields, ("abs_max",)) - self.assertEqual(MIN_MAX_MEAN.fields, ("min", "max", "mean")) self.assertEqual( - DEFAULT_STATS.fields, - ("min", "max", "mean", "abs_max"), + STATS.fields, + ( + "min", + "max", + "mean", + "abs_max", + "abs_mean", + "std", + "rms", + "l1_norm", + "l2_norm", + "nan_count", + "inf_count", + "zero_count", + "p99_abs", + ), ) def test_reducer_names_unique(self): - names = {r.name for r in (FULL_TENSOR, ABS_MAX_ONLY, MIN_MAX_MEAN, DEFAULT_STATS)} - self.assertEqual(len(names), 4) + names = {r.name for r in (FULL_TENSOR, STATS)} + self.assertEqual(len(names), 2) + + def test_full_tensor_eager_is_identity(self): + t = torch.randn(2, 3, 4) + out = FULL_TENSOR.eager(t) + self.assertEqual(out.shape, t.shape) + torch.testing.assert_close(out, t.detach()) + + def test_stats_eager_correctness(self): + torch.manual_seed(0) + t = torch.randn(64) + out = STATS.eager(t) + self.assertEqual(out.shape, (len(STATS.fields),)) + + f = t.to(torch.float32) + expected = { + "min": float(f.amin()), + "max": float(f.amax()), + "mean": float(f.mean()), + "abs_max": float(f.abs().amax()), + "abs_mean": float(f.abs().mean()), + "rms": float(f.pow(2).mean().sqrt()), + "l1_norm": float(f.abs().sum()), + "l2_norm": float(f.pow(2).sum().sqrt()), + "nan_count": 0.0, + "inf_count": 0.0, + "zero_count": float((f == 0).to(torch.float32).sum()), + } + for i, field in enumerate(STATS.fields): + if field in expected: + torch.testing.assert_close( + float(out[i]), expected[field], rtol=1e-4, atol=1e-5 + ) + # std uses E[x^2] - E[x]^2 (population variance); compare to that. 
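+        # (Note: torch.Tensor.std() defaults to the unbiased sample
+        # estimator with Bessel's correction, i.e. an N-1 divisor, which
+        # would not match; hence the explicit population formula below.)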
+        pop_var = float((f.pow(2).mean() - f.mean().pow(2)).abs())
+        torch.testing.assert_close(
+            float(out[STATS.fields.index("std")]) ** 2,
+            pop_var,
+            rtol=1e-4,
+            atol=1e-5,
+        )
 
-    def test_default_stats_eager_correctness(self):
-        """Confirm DEFAULT_STATS spec has 4 fields (std/nan_count/inf_count excluded)."""
-        self.assertEqual(len(DEFAULT_STATS.fields), 4)
+    def test_stats_p99_abs_matches_topk(self):
+        torch.manual_seed(0)
+        t = torch.randn(1000)
+        out = STATS.eager(t)
+        numel = t.numel()
+        k = max(1, (numel + 99) // 100)
+        expected = float(
+            torch.topk(t.abs().reshape(-1), k=k, largest=True, sorted=True).values[
+                k - 1
+            ]
+        )
+        torch.testing.assert_close(
+            float(out[STATS.fields.index("p99_abs")]),
+            expected,
+            rtol=1e-4,
+            atol=1e-5,
+        )
+
+    def test_stats_counts_nan_and_inf(self):
+        t = torch.tensor(
+            [1.0, float("nan"), 2.0, float("inf"), 0.0, -float("inf"), 0.0]
+        )
+        out = STATS.eager(t)
+        i_nan = STATS.fields.index("nan_count")
+        i_inf = STATS.fields.index("inf_count")
+        i_zero = STATS.fields.index("zero_count")
+        self.assertEqual(float(out[i_nan]), 1.0)
+        self.assertEqual(float(out[i_inf]), 2.0)
+        self.assertEqual(float(out[i_zero]), 2.0)
diff --git a/devtools/intermediate_output_tap/tests/test_selectors.py b/devtools/intermediate_output_tap/tests/test_selectors.py
index 495e184de52..eafbb290916 100644
--- a/devtools/intermediate_output_tap/tests/test_selectors.py
+++ b/devtools/intermediate_output_tap/tests/test_selectors.py
@@ -80,9 +80,7 @@ def test_select_by_module_path(self):
         self.assertGreater(len(matched), 0)
         for n in matched:
             stack = n.meta.get("nn_module_stack") or {}
-            paths = [
-                v[0] if isinstance(v, tuple) else v for v in stack.values()
-            ]
+            paths = [v[0] if isinstance(v, tuple) else v for v in stack.values()]
             self.assertTrue(any(p.startswith("inner") for p in paths))
 
     def test_select_by_meta_tag_presence(self):
diff --git a/devtools/intermediate_output_tap/tests/test_strip_pass.py b/devtools/intermediate_output_tap/tests/test_strip_pass.py
index 2ff0d9154b1..d1ac8773a06 100644
--- a/devtools/intermediate_output_tap/tests/test_strip_pass.py
+++ b/devtools/intermediate_output_tap/tests/test_strip_pass.py
@@ -14,9 +14,7 @@
     FULL_TENSOR,
     MIN_MAX_MEAN,
 )
-from executorch.devtools.intermediate_output_tap._selectors import (
-    select_by_op_type,
-)
+from executorch.devtools.intermediate_output_tap._selectors import select_by_op_type
 from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_
 from executorch.devtools.intermediate_output_tap._tap_pass import (
     find_tap_nodes,
diff --git a/devtools/intermediate_output_tap/tests/test_tap_pass.py b/devtools/intermediate_output_tap/tests/test_tap_pass.py
index 9e01ffe61d7..d42d1d90bde 100644
--- a/devtools/intermediate_output_tap/tests/test_tap_pass.py
+++ b/devtools/intermediate_output_tap/tests/test_tap_pass.py
@@ -14,9 +14,7 @@
     DEFAULT_STATS,
     FULL_TENSOR,
 )
-from executorch.devtools.intermediate_output_tap._selectors import (
-    select_by_op_type,
-)
+from executorch.devtools.intermediate_output_tap._selectors import select_by_op_type
 from executorch.devtools.intermediate_output_tap._tap_pass import (
     is_tap_node,
     tap_intermediate_outputs,
@@ -56,7 +54,9 @@ def test_inserts_tap_per_selected_node(self):
     def test_appends_user_outputs(self):
         ep = _export()
         original_user_outs = sum(
-            1 for s in ep.graph_signature.output_specs if s.kind == OutputKind.USER_OUTPUT
+            1
+            for s in ep.graph_signature.output_specs
+            if s.kind == OutputKind.USER_OUTPUT
         )
         ep_t, specs = tap_intermediate_outputs(
             ep,
@@ -73,7 +73,9 @@ def test_appends_user_outputs(self):
     def test_output_indices_contiguous_after_user_outputs(self):
         ep = _export()
         original_user_outs = sum(
-            1 for s in ep.graph_signature.output_specs if s.kind == OutputKind.USER_OUTPUT
+            1
+            for s in ep.graph_signature.output_specs
+            if s.kind == OutputKind.USER_OUTPUT
         )
         _, specs = tap_intermediate_outputs(
             ep,
diff --git a/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
index 4967427a60a..54e3d48e142 100644
--- a/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
+++ b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py
@@ -18,15 +18,11 @@
 import unittest
 
 import torch
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
-    XnnpackPartitioner,
-)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.devtools.intermediate_output_tap import (
-    ABS_MAX_ONLY,
-    DEFAULT_STATS,
     FULL_TENSOR,
-    MIN_MAX_MEAN,
     select_by_op_type,
+    STATS,
     strip_taps_,
     tap_intermediate_outputs,
 )
@@ -45,7 +41,9 @@ def forward(self, x):
         return self.l2(self.l1(x).relu())
 
 
-@unittest.skipIf(sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows")
+@unittest.skipIf(
+    sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows"
+)
 class XnnpackEndToEndTest(unittest.TestCase):
     def _run_pipeline(self, reducer):
         model = _MLP()
@@ -57,9 +55,7 @@ def _run_pipeline(self, reducer):
             selector=select_by_op_type("aten.linear.default"),
             reducer=reducer,
         )
-        edge = to_edge_transform_and_lower(
-            ep_t, partitioner=[XnnpackPartitioner()]
-        )
+        edge = to_edge_transform_and_lower(ep_t, partitioner=[XnnpackPartitioner()])
         strip_taps_(edge)
 
         et_program = edge.to_executorch()
@@ -75,44 +71,35 @@ def _run_pipeline(self, reducer):
         return specs, flat_outputs, model, example_inputs
 
     def test_full_tensor_taps_match_eager(self):
-        specs, flat, model, example_inputs = self._run_pipeline(FULL_TENSOR)
+        specs, flat, _, _ = self._run_pipeline(FULL_TENSOR)
         self.assertEqual(len(specs), 2)  # two linears
-
-        # The user output is at index 0; tap outputs follow.
+        # FULL_TENSOR preserves the source tensor's shape.
         for spec in specs:
             tap_value = flat[spec.output_index]
             self.assertIsInstance(tap_value, torch.Tensor)
-            # FULL_TENSOR preserves the source tensor's shape — so e.g. for
-            # the first linear, shape is (batch, l1.out_features).
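+            # (Clarifying note: spec.output_index already accounts for the
+            # user outputs that precede the taps in the flat output list.)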
self.assertGreater(tap_value.numel(), 0) - def test_abs_max_only_returns_scalar(self): - specs, flat, _, _ = self._run_pipeline(ABS_MAX_ONLY) - self.assertEqual(len(specs), 2) - for spec in specs: - tap_value = flat[spec.output_index] - self.assertIsInstance(tap_value, torch.Tensor) - # 0-dim scalar - self.assertEqual(tap_value.numel(), 1) - self.assertGreaterEqual(float(tap_value), 0.0) - - def test_min_max_mean_e2e(self): - specs, flat, _, _ = self._run_pipeline(MIN_MAX_MEAN) - self.assertEqual(len(specs), 2) - for spec in specs: - tap_value = flat[spec.output_index] - self.assertEqual(tap_value.numel(), 3) - - def test_default_stats_returns_seven_floats(self): - specs, flat, _, _ = self._run_pipeline(DEFAULT_STATS) + def test_stats_returns_thirteen_floats(self): + specs, flat, _, _ = self._run_pipeline(STATS) self.assertEqual(len(specs), 2) for spec in specs: tap_value = flat[spec.output_index] self.assertIsInstance(tap_value, torch.Tensor) - self.assertEqual(tap_value.numel(), 4) - mn, mx, _, abs_max = tap_value.tolist() + self.assertEqual(tap_value.numel(), len(STATS.fields)) + vals = tap_value.tolist() + field_idx = {f: i for i, f in enumerate(STATS.fields)} + mn = vals[field_idx["min"]] + mx = vals[field_idx["max"]] + abs_max = vals[field_idx["abs_max"]] + l2 = vals[field_idx["l2_norm"]] + l1 = vals[field_idx["l1_norm"]] self.assertLessEqual(mn, mx) - self.assertGreaterEqual(abs_max, max(abs(mn), abs(mx)) - 1e-5) + self.assertGreaterEqual(abs_max, max(abs(mn), abs(mx)) - 1e-3) + self.assertGreaterEqual(l1, 0.0) + self.assertGreaterEqual(l2, 0.0) + # No NaN/Inf in random fp32 — should be exactly zero. + self.assertEqual(vals[field_idx["nan_count"]], 0.0) + self.assertEqual(vals[field_idx["inf_count"]], 0.0) def test_user_outputs_still_correct(self): """Tap outputs must not corrupt the original user outputs."""
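+        # Taps are appended after the original user outputs, so those
+        # outputs keep their flat positions; this test guards that
+        # invariant.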