diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index e9fbc4778f5..969906a7c10 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1649,3 +1649,129 @@ def get_stacktraces_for_row(aot_ops: List[str]) -> Dict[str, Optional[str]]: df["stacktraces"] = df["aot_ops"].apply(get_stacktraces_for_row) return df + + def calculate_numeric_gap_from_taps( + self, + flat_runtime_outputs: Sequence, + tap_specs: Sequence, + distance: Union[str, NumericalComparatorBase] = "MSE", + reference_graph: Optional[str] = None, + disable_debug_handle_valdiation: bool = False, + ) -> pd.DataFrame: + """ + Compares AOT intermediate outputs (from the ETRecord, captured by + IntermediateOutputCapturer) with runtime tap values exposed as + USER_OUTPUTs by the `intermediate_output_tap` package. + + Unlike `calculate_numeric_gap`, this method works through delegates + with no backend-side support: the runtime values come from extra + outputs the AOT pass added to the ExportedProgram before lowering. + + IMPORTANT: ETRecord serialization regenerates `debug_handle`s during + roundtrip, so the handles in `tap_specs` (set at AOT-pass time) are + stale. Each spec's `reducer_node_name` (set by + `strip_taps_(edge, tap_specs=specs)`) is used to look up the + post-roundtrip `debug_handle` in the AOT reference graph. + + Args: + flat_runtime_outputs: The flat output tuple/list returned by + running the lowered program (e.g. `Method.execute(inputs)`). + tap_specs: The list of TapSpec returned by + `strip_taps_(edge, tap_specs=specs)` — these carry + `reducer_node_name` for alignment. + distance: "MSE", "L1", "SNR", or a `NumericalComparatorBase`. + reference_graph: AOT graph to use as the golden. See + `calculate_numeric_gap` for valid values. + disable_debug_handle_valdiation: Bypass debug handle validation. + + Returns: + DataFrame with one row per (aot_handle, runtime_handle) pair. + Same shape produced by `calculate_numeric_gap`. + """ + reference_graph_module, _resolved_graph_name = self._resolve_reference_graph( + reference_graph, + disable_debug_handle_valdiation, + ) + aot_intermediate_outputs, aot_debug_handle_to_op_names = ( + self._get_aot_intermediate_outputs_and_op_names(reference_graph_module) + ) + if len(aot_intermediate_outputs) == 0: + raise ValueError( + "No AOT intermediate outputs were captured. ETRecord must be " + "provided with representative_inputs for tap-based comparison." + ) + + spec_handles = _lookup_handles_by_name(reference_graph_module, tap_specs) + + runtime_intermediate_outputs: Dict[DebugHandle, Tuple[Any, int]] = {} + runtime_debug_handle_to_op_names: Dict[DebugHandle, List[str]] = {} + for spec, dh in zip(tap_specs, spec_handles): + if dh is None: + continue + key: DebugHandle = (int(dh),) + runtime_intermediate_outputs[key] = ( + flat_runtime_outputs[spec.output_index], + 1, + ) + runtime_debug_handle_to_op_names[key] = [spec.op_target] + + if not runtime_intermediate_outputs: + raise ValueError( + "Could not recover any post-roundtrip handles for tap_specs. " + "Verify that strip_taps_(edge, tap_specs=specs) was called and " + "the returned specs were passed to this method, and that " + "generate_etrecord ran AFTER strip_taps_." + ) + + mapping = map_runtime_aot_intermediate_outputs( + aot_intermediate_outputs, runtime_intermediate_outputs + ) + + if isinstance(distance, NumericalComparatorBase): + comparator = distance + if comparator.inspector is None: + comparator.inspector = self + else: + metric = distance.strip().upper() + if metric == "MSE": + comparator = MSEComparator(inspector=self) + elif metric == "L1": + comparator = L1Comparator(inspector=self) + elif metric == "SNR": + comparator = SNRComparator(inspector=self) + else: + raise ValueError(f"Unsupported distance metric {distance!r}") + + return comparator.compare( + mapping, + aot_debug_handle_to_op_names, + runtime_debug_handle_to_op_names, + ) + + +def _lookup_handles_by_name( + reference_graph_module, + tap_specs: Sequence, +) -> List[Optional[int]]: + """ + For each TapSpec, return the post-roundtrip `debug_handle` of the FX node + whose name equals `spec.reducer_node_name`. Returns None for any spec + without a `reducer_node_name` set or whose name is not found. + + `reducer_node_name` is set by `strip_taps_(edge, tap_specs=specs)`. FX + node names survive ETRecord serialization roundtrip, so this lookup is + stable. + """ + name_to_handle: Dict[str, Optional[int]] = {} + for n in reference_graph_module.graph.nodes: + h = n.meta.get("debug_handle") + name_to_handle[n.name] = int(h) if isinstance(h, int) else None + + out: List[Optional[int]] = [] + for spec in tap_specs: + rn = getattr(spec, "reducer_node_name", None) + if rn is None: + out.append(None) + continue + out.append(name_to_handle.get(rn)) + return out diff --git a/devtools/intermediate_output_tap/TARGETS b/devtools/intermediate_output_tap/TARGETS new file mode 100644 index 00000000000..5ae2ddc8380 --- /dev/null +++ b/devtools/intermediate_output_tap/TARGETS @@ -0,0 +1,85 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "spec", + srcs = ["_spec.py"], +) + +runtime.python_library( + name = "custom_ops_lib", + srcs = ["custom_ops_lib.py"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "selectors", + srcs = ["_selectors.py"], + deps = [ + "//caffe2:torch", + ], +) + +runtime.python_library( + name = "reducers", + srcs = ["_reducers.py"], + deps = [ + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ], +) + +runtime.python_library( + name = "tap_pass", + srcs = ["_tap_pass.py"], + deps = [ + "//caffe2:torch", + ":custom_ops_lib", + ":reducers", + ":selectors", + ":spec", + ], +) + +runtime.python_library( + name = "strip_pass", + srcs = ["_strip_pass.py"], + deps = [ + "//caffe2:torch", + ":reducers", + ":tap_pass", + ], +) + +runtime.python_library( + name = "convenience", + srcs = ["_convenience.py"], + deps = [ + "fbsource//third-party/pypi/pandas:pandas", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ":reducers", + ":selectors", + ":spec", + ":strip_pass", + ":tap_pass", + ], +) + +runtime.python_library( + name = "lib", + srcs = ["__init__.py"], + deps = [ + ":convenience", + ":custom_ops_lib", + ":reducers", + ":selectors", + ":spec", + ":strip_pass", + ":tap_pass", + ], +) diff --git a/devtools/intermediate_output_tap/__init__.py b/devtools/intermediate_output_tap/__init__.py new file mode 100644 index 00000000000..870e43f1375 --- /dev/null +++ b/devtools/intermediate_output_tap/__init__.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Public API for the ExecuTorch numerical debugger. + +Backend-agnostic intermediate-value tap that complements the existing +Inspector framework: + +- AOT side : `IntermediateOutputCapturer` (existing) +- Runtime side : ETDump intermediate output events (existing, opaque inside delegates) +- Runtime side : USER_OUTPUT taps (this module — works through delegates without + any backend-side changes) + +Typical usage: + + from executorch.devtools.intermediate_output_tap import ( + tap_intermediate_outputs, strip_taps_, STATS, + ) + + ep = export(model, example_inputs) + ep_tapped, specs = tap_intermediate_outputs(ep, reducer=STATS) + edge = to_edge_transform_and_lower(ep_tapped, partitioner=[XnnpackPartitioner()]) + strip_taps_(edge) + et_program = edge.to_executorch() + + flat_outputs = runtime.forward(*example_inputs) + df = inspector.calculate_numeric_gap_from_taps(flat_outputs, specs) +""" + +# Importing this module registers torch.ops.executorch_devtools.tap.Tensor. +from executorch.devtools.intermediate_output_tap import custom_ops_lib # noqa: F401 +from executorch.devtools.intermediate_output_tap._convenience import ( + compare_aot_runtime_dataframe, + format_tap_dataframe, + specs_to_dataframe, + tap_all_and_run, +) +from executorch.devtools.intermediate_output_tap._reducers import ( + FULL_TENSOR, + get_reducer, + StatReducer, + STATS, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all, + select_all_call_function, + select_any, + select_by_meta_tag, + select_by_module_path, + select_by_op_type, + select_not, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + find_tap_nodes, + is_tap_node, + tap_intermediate_outputs, +) + + +__all__ = [ + # Core API + "tap_intermediate_outputs", + "strip_taps_", + "TapSpec", + # Convenience + "tap_all_and_run", + "specs_to_dataframe", + "format_tap_dataframe", + "compare_aot_runtime_dataframe", + # Reducers + "StatReducer", + "FULL_TENSOR", + "STATS", + "get_reducer", + # Selectors + "NodeSelector", + "select_all_call_function", + "select_by_op_type", + "select_by_module_path", + "select_by_meta_tag", + "select_any", + "select_all", + "select_not", + # Helpers + "find_tap_nodes", + "is_tap_node", +] diff --git a/devtools/intermediate_output_tap/_convenience.py b/devtools/intermediate_output_tap/_convenience.py new file mode 100644 index 00000000000..c95dacc8add --- /dev/null +++ b/devtools/intermediate_output_tap/_convenience.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Convenience helpers built on top of `tap_intermediate_outputs` / `strip_taps_`: + +* `tap_all_and_run`: one-shot smoke-test wrapper that exports a model, taps + every call_function, lowers with the user's partitioner, runs through the + ExecuTorch runtime, and returns a per-tap DataFrame. +* `specs_to_dataframe`: build a per-tap DataFrame from a tap_specs list and + the runtime's flat output tuple. +* `compare_aot_runtime_dataframe`: side-by-side AOT-vs-runtime DataFrame from + the flat outputs of the *tapped* ExportedProgram (eager) and the post-strip + runtime program. The simpler alternative to `Inspector.calculate_numeric_gap_from_taps` + when you don't need ETDump/ETRecord plumbing. +* `format_tap_dataframe`: reshape the raw DataFrame returned by + `Inspector.calculate_numeric_gap_from_taps` into a friendlier per-tap view. +""" + +from __future__ import annotations + +import os +import tempfile +from collections.abc import Sequence +from typing import Any + +import pandas as pd +import torch +from executorch.devtools.intermediate_output_tap._reducers import StatReducer +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all_call_function, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + tap_intermediate_outputs, +) + + +def tap_all_and_run( + model: torch.nn.Module, + example_inputs: tuple[Any, ...], + partitioner: list | None = None, + reducer: str | StatReducer = "STATS", + selector: NodeSelector | None = None, + skip_if_no_debug_handle: bool = True, +) -> pd.DataFrame: + """ + Export -> tap -> lower -> strip -> to_executorch -> run -> DataFrame. + + Returns a DataFrame indexed by tap with columns: + node_name, op_target, debug_handle, output_index, reducer_name, plus + one column per reducer field (or `value` for FULL_TENSOR). + """ + from executorch.exir import to_edge_transform_and_lower + + selector = selector or select_all_call_function() + ep = torch.export.export(model, example_inputs, strict=True) + ep_tapped, specs = tap_intermediate_outputs( + ep, + selector=selector, + reducer=reducer, + skip_if_no_debug_handle=skip_if_no_debug_handle, + ) + edge = to_edge_transform_and_lower(ep_tapped, partitioner=partitioner or []) + strip_taps_(edge) + et_program = edge.to_executorch() + + flat_outputs = _run_pte(et_program, example_inputs) + return specs_to_dataframe(specs, flat_outputs) + + +def _run_pte(et_program, example_inputs: tuple[Any, ...]) -> Sequence[Any]: + from executorch.runtime import Runtime, Verification + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + return method.execute(example_inputs) + + +def specs_to_dataframe( + specs: Sequence[TapSpec], + flat_outputs: Sequence[Any], +) -> pd.DataFrame: + """Build a per-tap DataFrame from the tap_specs + flat output tuple.""" + rows = [] + for spec in specs: + runtime_value = flat_outputs[spec.output_index] + row: dict[str, Any] = { + "node_name": spec.node_name, + "op_target": spec.op_target, + "debug_handle": spec.debug_handle, + "output_index": spec.output_index, + "reducer_name": spec.reducer_name, + } + if spec.fields: + tensor_vals = ( + runtime_value.detach().cpu().tolist() + if isinstance(runtime_value, torch.Tensor) + else list(runtime_value) + ) + for i, field in enumerate(spec.fields): + row[field] = tensor_vals[i] if i < len(tensor_vals) else None + else: + row["value"] = runtime_value + rows.append(row) + return pd.DataFrame(rows) + + +def format_tap_dataframe( + df: pd.DataFrame, + tap_specs: Sequence[TapSpec], +) -> pd.DataFrame: + """ + Reshape the raw DataFrame returned by + `Inspector.calculate_numeric_gap_from_taps` into a friendlier per-tap, + per-field view. + + The raw DataFrame uses the existing Inspector comparator format, which + packs the reducer's stat tensor into a list of 0-d tensors and labels + rows by the post-strip reducer node name (e.g. `aten_stack_default`). + This helper: + - matches each raw row to a TapSpec (by reducer_node_name) + - renames `aot_ops`/`runtime_ops` columns to a single `node_name` (the + original source node name, e.g. `linear`, `linear_1`) + - expands the reducer stat tensor into one column per field + (e.g. `aot_min`, `rt_min`, `aot_max`, `rt_max`, ...) + - flattens the gap to a single float + - drops the verbose `aot_intermediate_output` / `runtime_intermediate_output` + list columns + + Returns a DataFrame with columns: + node_name, op_target, reducer_name, + gap, + aot_, rt_, aot_, rt_, ... + """ + # Map reducer_node_name -> spec for quick lookup. + name_to_spec: dict[str, TapSpec] = { + s.reducer_node_name: s for s in tap_specs if s.reducer_node_name is not None + } + + rows = [] + for _, row in df.iterrows(): + aot_ops = row.get("aot_ops", []) + spec = None + for op in aot_ops or []: + if op in name_to_spec: + spec = name_to_spec[op] + break + if spec is None: + # Couldn't match — keep a thin row with whatever we have. + rows.append( + { + "node_name": ",".join(aot_ops or []), + "op_target": "?", + "reducer_name": "?", + "gap": _flatten_gap(row.get("gap")), + } + ) + continue + + new_row: dict[str, Any] = { + "node_name": spec.node_name, + "op_target": spec.op_target, + "reducer_name": spec.reducer_name, + "gap": _flatten_gap(row.get("gap")), + } + aot_vals = _to_float_list(row.get("aot_intermediate_output")) + rt_vals = _to_float_list(row.get("runtime_intermediate_output")) + for i, field in enumerate(spec.fields): + new_row[f"aot_{field}"] = aot_vals[i] if i < len(aot_vals) else None + new_row[f"rt_{field}"] = rt_vals[i] if i < len(rt_vals) else None + if not spec.fields: # FULL_TENSOR + new_row["aot_value"] = row.get("aot_intermediate_output") + new_row["rt_value"] = row.get("runtime_intermediate_output") + rows.append(new_row) + return pd.DataFrame(rows) + + +def _flatten_gap(g: Any) -> float | None: + if g is None: + return None + if isinstance(g, list): + if not g: + return None + g = g[0] + if isinstance(g, torch.Tensor): + return float(g) + try: + return float(g) + except (TypeError, ValueError): + return None + + +def _to_float_list(v: Any) -> list[float]: + if v is None: + return [] + if isinstance(v, torch.Tensor): + return v.detach().cpu().tolist() + if isinstance(v, list): + out: list[float] = [] + for x in v: + if isinstance(x, torch.Tensor): + out.append(float(x)) + else: + try: + out.append(float(x)) + except (TypeError, ValueError): + out.append(float("nan")) + return out + return [] + + +def _flat_floats(v: Any) -> list[float]: + """Flatten a tap value (tensor / list / scalar) to a flat list of floats.""" + if isinstance(v, torch.Tensor): + return [ + float(x) for x in v.detach().to(torch.float32).cpu().reshape(-1).tolist() + ] + if isinstance(v, (list, tuple)): + out: list[float] = [] + for x in v: + out.extend(_flat_floats(x)) + return out + try: + return [float(v)] + except (TypeError, ValueError): + return [] + + +def _sqnr_db(aot_vals: list[float], rt_vals: list[float]) -> float: + """Signal-to-quantization-noise ratio in dB. Higher is better. + + Thin wrapper around `torch.ao.ns.fx.utils.compute_sqnr` (the canonical + implementation already used by `backends/test/harness/error_statistics.py`). + """ + from torch.ao.ns.fx.utils import compute_sqnr + + n = min(len(aot_vals), len(rt_vals)) + if n == 0: + return float("nan") + aot_t = torch.tensor(aot_vals[:n], dtype=torch.float32) + rt_t = torch.tensor(rt_vals[:n], dtype=torch.float32) + return float(compute_sqnr(rt_t, aot_t)) + + +def compare_aot_runtime_dataframe( + specs: Sequence[TapSpec], + aot_flat: Sequence[Any], + rt_flat: Sequence[Any], +) -> pd.DataFrame: + """ + Build a side-by-side AOT-vs-runtime DataFrame from the flat outputs of + the *tapped* ExportedProgram (eager) and the post-strip runtime program. + + Both `aot_flat[spec.output_index]` and `rt_flat[spec.output_index]` already + contain the *reduced* tap value, since `tap.Tensor`'s eager impl applies + the named reducer (see `custom_ops_lib.py`). + + Output columns per spec: + - For non-FULL_TENSOR reducers: one `aot_` and `rt_` per + reducer field (e.g. `aot_min`, `rt_min`, ...). + - For FULL_TENSOR: `sqnr_db` (signal-to-noise of aot vs rt over the + whole tensor, in dB) + """ + rows: list[dict[str, Any]] = [] + for spec in specs: + aot_vals = _flat_floats(aot_flat[spec.output_index]) + rt_vals = _flat_floats(rt_flat[spec.output_index]) + + row: dict[str, Any] = { + "node_name": spec.node_name, + "module_path": spec.module_path, + "op_target": spec.op_target, + "reducer_name": spec.reducer_name, + "output_index": spec.output_index, + } + + if spec.reducer_name == "FULL_TENSOR": + row["sqnr_db"] = _sqnr_db(aot_vals, rt_vals) + row["aot_numel"] = len(aot_vals) + row["rt_numel"] = len(rt_vals) + else: + fields = ( + list(spec.fields) + if spec.fields + else [f"v{i}" for i in range(max(len(aot_vals), len(rt_vals)))] + ) + for i, f in enumerate(fields): + row[f"aot_{f}"] = aot_vals[i] if i < len(aot_vals) else float("nan") + row[f"rt_{f}"] = rt_vals[i] if i < len(rt_vals) else float("nan") + + rows.append(row) + return pd.DataFrame(rows) diff --git a/devtools/intermediate_output_tap/_reducers.py b/devtools/intermediate_output_tap/_reducers.py new file mode 100644 index 00000000000..481dfb2f55e --- /dev/null +++ b/devtools/intermediate_output_tap/_reducers.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Stat reducers used by `tap_intermediate_outputs`. + +A `StatReducer` is a small specification consumed by `strip_taps_` (after +`to_backend`) to materialise a portable reducer subgraph in place of the +`executorch_devtools::tap.Tensor` placeholder. + +`emit(graph, src_node) -> fx.Node` builds the reducer subgraph in `graph` +just before the placeholder, using the source tensor `src_node` as input, +and returns the final node whose output replaces the placeholder's output. + +`eager(tensor) -> tensor` is the pure-torch equivalent that callers can use +to reproduce, in eager mode, what the runtime will compute. `tap.Tensor`'s +own dispatch impl uses this to produce the *reduced* value at AOT time, so +that `ep.module()(*inputs)` returns the same flat outputs as the runtime. + +The emit functions cast to fp32 first for cross-backend numerical stability +and produce a fixed-shape output (0-D or 1-D) regardless of the source +tensor's rank, so callers don't need to track per-tap shapes. + +We ship two built-ins: + +* `FULL_TENSOR` — identity. The whole source tensor is surfaced. +* `STATS` — a comprehensive bundle of debugging-friendly scalars: + min, max, mean, abs_max, abs_mean, std, rms, l1_norm, l2_norm, + nan_count, inf_count, zero_count, p99_abs. +""" + +from __future__ import annotations + +import operator +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +from executorch.exir.dialects._ops import ops as exir_ops + + +if TYPE_CHECKING: + import torch.fx as fx + + +# --- Reducer dataclass --------------------------------------------------- + +EmitFn = Callable[["fx.Graph", "fx.Node"], "fx.Node"] +EagerFn = Callable[[torch.Tensor], torch.Tensor] + + +@dataclass(frozen=True) +class StatReducer: + """ + A reducer specification. + + `emit` is invoked by `strip_taps_` to materialise the reducer subgraph + in the post-lowering FX graph. `eager` is the equivalent pure-torch + implementation, used by callers that want to reproduce what the runtime + will compute (e.g. AOT-vs-runtime comparisons without a debugger). + + `name` is what the user types and what's stored on each TapSpec. + `fields` enumerates the columns of the 1-D output tensor (empty for + FULL_TENSOR which preserves a tensor of values). + """ + + name: str + fields: tuple[str, ...] + emit: EmitFn + eager: EagerFn + + +# --- Helpers ------------------------------------------------------------- + + +def _cast_fp32(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + """Insert a fp32 cast (no-op semantically if already fp32).""" + return graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(x,), + kwargs={"dtype": torch.float32}, + ) + + +def _scalar_node(graph: "fx.Graph", op, x: "fx.Node") -> "fx.Node": + """Call a full-reduction op (amin/amax/mean/sum) producing a 0-d tensor.""" + return graph.call_function(op, args=(x,)) + + +def _stack(graph: "fx.Graph", scalars: list["fx.Node"]) -> "fx.Node": + """Stack a list of 0-d tensors into a 1-D tensor.""" + return graph.call_function( + exir_ops.edge.aten.stack.default, + args=(scalars,), + kwargs={"dim": 0}, + ) + + +def _abs(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.abs.default, args=(x,)) + + +def _square(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.pow.Tensor_Scalar, args=(x, 2.0)) + + +def _sqrt(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + return graph.call_function(exir_ops.edge.aten.sqrt.default, args=(x,)) + + +def _full_sum(graph: "fx.Graph", x: "fx.Node") -> "fx.Node": + """Full-tensor sum via aten.sum.dim_IntList(dim=[]) — portable + has out variant.""" + return graph.call_function(exir_ops.edge.aten.sum.dim_IntList, args=(x, [])) + + +def _bool_to_fp32_count(graph: "fx.Graph", mask: "fx.Node") -> "fx.Node": + """Sum of a bool mask cast to fp32 → a 0-d fp32 count.""" + casted = graph.call_function( + exir_ops.edge.aten._to_copy.default, + args=(mask,), + kwargs={"dtype": torch.float32}, + ) + return _full_sum(graph, casted) + + +# --- FULL_TENSOR --------------------------------------------------------- + + +def _emit_full_tensor(_graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + return src + + +def _eager_full_tensor(t: torch.Tensor) -> torch.Tensor: + return t.detach() + + +FULL_TENSOR: StatReducer = StatReducer( + name="FULL_TENSOR", + fields=(), + emit=_emit_full_tensor, + eager=_eager_full_tensor, +) + + +# --- STATS --------------------------------------------------------------- + + +_STATS_FIELDS: tuple[str, ...] = ( + "min", + "max", + "mean", + "abs_max", + "abs_mean", + "std", + "rms", + "l1_norm", + "l2_norm", + "nan_count", + "inf_count", + "zero_count", + "p99_abs", +) + + +def _emit_stats(graph: "fx.Graph", src: "fx.Node") -> "fx.Node": + f = _cast_fp32(graph, src) + abs_f = _abs(graph, f) + sq_f = _square(graph, f) + + mn = _scalar_node(graph, exir_ops.edge.aten.amin.default, f) + mx = _scalar_node(graph, exir_ops.edge.aten.amax.default, f) + me = _scalar_node(graph, exir_ops.edge.aten.mean.default, f) + abs_max = _scalar_node(graph, exir_ops.edge.aten.amax.default, abs_f) + abs_mean = _scalar_node(graph, exir_ops.edge.aten.mean.default, abs_f) + + sum_sq = _full_sum(graph, sq_f) + mean_sq = _scalar_node(graph, exir_ops.edge.aten.mean.default, sq_f) + rms = _sqrt(graph, mean_sq) + + # std = sqrt( E[x^2] - E[x]^2 ); avoids aten.var which lacks an out variant. + me_sq_scalar = graph.call_function( + exir_ops.edge.aten.pow.Tensor_Scalar, args=(me, 2.0) + ) + var = graph.call_function( + exir_ops.edge.aten.sub.Tensor, args=(mean_sq, me_sq_scalar) + ) + # Variance can be slightly negative due to fp roundoff; clamp at 0 via abs. + var = graph.call_function(exir_ops.edge.aten.abs.default, args=(var,)) + std = _sqrt(graph, var) + + l1 = _full_sum(graph, abs_f) + l2 = _sqrt(graph, sum_sq) + + nan_mask = graph.call_function(exir_ops.edge.aten.isnan.default, args=(f,)) + nan_count = _bool_to_fp32_count(graph, nan_mask) + + inf_mask = graph.call_function(exir_ops.edge.aten.isinf.default, args=(f,)) + inf_count = _bool_to_fp32_count(graph, inf_mask) + + zero_mask = graph.call_function(exir_ops.edge.aten.eq.Scalar, args=(f, 0.0)) + zero_count = _bool_to_fp32_count(graph, zero_mask) + + # p99_abs: use topk on flattened |x| to get the k-th largest, where + # k = max(1, ceil(numel * 0.01)). Numel is read from the source's + # FakeTensor at graph-build time. + fake = src.meta.get("val") + numel = int(fake.numel()) if fake is not None else 1 + k = max(1, (numel + 99) // 100) # ceil(numel/100) + abs_flat = graph.call_function( + exir_ops.edge.aten.view_copy.default, args=(abs_f, [-1]) + ) + topk_out = graph.call_function( + exir_ops.edge.aten.topk.default, + args=(abs_flat, k), + kwargs={"dim": -1, "largest": True, "sorted": True}, + ) + topk_values = graph.call_function(operator.getitem, args=(topk_out, 0)) + p99_abs = graph.call_function( + exir_ops.edge.aten.select_copy.int, args=(topk_values, 0, k - 1) + ) + + return _stack( + graph, + [ + mn, + mx, + me, + abs_max, + abs_mean, + std, + rms, + l1, + l2, + nan_count, + inf_count, + zero_count, + p99_abs, + ], + ) + + +def _eager_stats(t: torch.Tensor) -> torch.Tensor: + f = t.detach().to(torch.float32) + abs_f = f.abs() + sq = f.pow(2.0) + + # std via E[x^2] - E[x]^2 (population variance) to match the emit subgraph. + if f.numel() > 0: + var = (sq.mean() - f.mean().pow(2)).abs() + std = var.sqrt() + else: + std = torch.tensor(0.0) + + sum_sq = sq.sum() + rms = sq.mean().sqrt() + l1 = abs_f.sum() + l2 = sum_sq.sqrt() + + nan_count = torch.isnan(f).to(torch.float32).sum() + inf_count = torch.isinf(f).to(torch.float32).sum() + zero_count = (f == 0).to(torch.float32).sum() + + numel = f.numel() + k = max(1, (numel + 99) // 100) + if numel > 0: + topk_vals = torch.topk(abs_f.reshape(-1), k=k, largest=True, sorted=True).values + p99_abs = topk_vals[k - 1] + else: + p99_abs = torch.tensor(float("nan")) + + return torch.stack( + [ + f.amin(), + f.amax(), + f.mean(), + abs_f.amax(), + abs_f.mean(), + std, + rms, + l1, + l2, + nan_count, + inf_count, + zero_count, + p99_abs, + ], + dim=0, + ) + + +STATS: StatReducer = StatReducer( + name="STATS", + fields=_STATS_FIELDS, + emit=_emit_stats, + eager=_eager_stats, +) + + +# --- Registry ------------------------------------------------------------- + +_BUILTIN_REDUCERS: dict[str, StatReducer] = {r.name: r for r in (FULL_TENSOR, STATS)} + + +def get_reducer(name_or_reducer: str | StatReducer) -> StatReducer: + """Look up a built-in by name, or return a user-supplied StatReducer as-is.""" + if isinstance(name_or_reducer, StatReducer): + return name_or_reducer + if name_or_reducer not in _BUILTIN_REDUCERS: + raise ValueError( + f"Unknown reducer {name_or_reducer!r}; " + f"available: {sorted(_BUILTIN_REDUCERS)}" + ) + return _BUILTIN_REDUCERS[name_or_reducer] diff --git a/devtools/intermediate_output_tap/_selectors.py b/devtools/intermediate_output_tap/_selectors.py new file mode 100644 index 00000000000..004c057df13 --- /dev/null +++ b/devtools/intermediate_output_tap/_selectors.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +Predicates for selecting which FX nodes to tap. + +A `NodeSelector` is just `Callable[[fx.Node], bool]`. The provided builders +let you compose them by op type, by `nn_module_stack` path, by arbitrary meta +tag, and via boolean combinators. + +Examples: + selector = select_any( + select_by_op_type("aten.linear.default", "aten.matmul.default"), + select_by_module_path("layers.*.attention"), + ) + selector = select_all(selector, select_not(select_by_op_type("aten.view.default"))) +""" + +from __future__ import annotations + +import fnmatch +from collections.abc import Callable +from typing import Any + +import torch.fx as fx + + +NodeSelector = Callable[[fx.Node], bool] + + +def select_all_call_function( + exclude: tuple[str, ...] = ("getitem",), +) -> NodeSelector: + """Match every `call_function` node whose target name is not in `exclude`.""" + excluded = set(exclude) + + def predicate(n: fx.Node) -> bool: + if n.op != "call_function": + return False + target_name = getattr(n.target, "__name__", str(n.target)) + # `getitem` shows up as the builtin name; also normalise common aten suffixes. + return target_name not in excluded and str(n.target) not in excluded + + return predicate + + +def select_by_op_type(*op_targets: str) -> NodeSelector: + """ + Match nodes whose `str(node.target)` ends with any of `op_targets`. + + The "ends with" rule lets the user write either the short name + ("aten.linear.default") or a fully-qualified name and have it match. + """ + if not op_targets: + raise ValueError("select_by_op_type requires at least one op target") + suffixes = tuple(op_targets) + + def predicate(n: fx.Node) -> bool: + if n.op != "call_function": + return False + target_str = str(n.target) + return any(target_str.endswith(s) or target_str == s for s in suffixes) + + return predicate + + +def select_by_module_path(pattern: str) -> NodeSelector: + """ + Match nodes whose `nn_module_stack` (the chain of nn.Module attribute names + walked to reach this op during tracing) contains a path matching `pattern`. + `pattern` is a shell-glob (fnmatch) — e.g. "layers.*", "layers.0.attention", + "*.attention.*". + """ + + def predicate(n: fx.Node) -> bool: + stack = n.meta.get("nn_module_stack") + if not stack: + return False + # nn_module_stack is an OrderedDict: id -> (qualified_path, module_type) + for entry in stack.values(): + path = entry[0] if isinstance(entry, tuple) else entry + if fnmatch.fnmatchcase(path, pattern): + return True + return False + + return predicate + + +# Sentinel: matches when the meta key exists at all, regardless of value. +_ANY_VALUE: object = object() + + +def select_by_meta_tag(key: str, value: Any = _ANY_VALUE) -> NodeSelector: + """ + Match nodes that carry `node.meta[key]`. If `value` is provided, also + requires `node.meta[key] == value`. + """ + + def predicate(n: fx.Node) -> bool: + if key not in n.meta: + return False + if value is _ANY_VALUE: + return True + return n.meta[key] == value + + return predicate + + +def select_any(*selectors: NodeSelector) -> NodeSelector: + """Match if ANY of `selectors` matches.""" + if not selectors: + return lambda _n: False + sels = tuple(selectors) + return lambda n: any(s(n) for s in sels) + + +def select_all(*selectors: NodeSelector) -> NodeSelector: + """Match if ALL of `selectors` match.""" + if not selectors: + return lambda _n: True + sels = tuple(selectors) + return lambda n: all(s(n) for s in sels) + + +def select_not(selector: NodeSelector) -> NodeSelector: + """Match if `selector` does NOT match.""" + return lambda n: not selector(n) diff --git a/devtools/intermediate_output_tap/_spec.py b/devtools/intermediate_output_tap/_spec.py new file mode 100644 index 00000000000..db035b7c17e --- /dev/null +++ b/devtools/intermediate_output_tap/_spec.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +TapSpec records one tap inserted by `tap_intermediate_outputs(...)`. + +A list of TapSpecs is returned to the user from the AOT pass; the user passes +that same list to `Inspector.calculate_numeric_gap_from_taps(...)` at runtime +to demux the flat output tuple back into per-op intermediate values. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TapSpec: + """ + Metadata about a single tap. + + Attributes: + node_name: The FX node name of the *source* node (the tapped op) at the + time the AOT pass ran. Useful for debugging / pretty-printing. + op_target: `str(node.target)` of the source node, e.g. + "aten.linear.default". + debug_handle: `node.meta["debug_handle"]` of the source node, or None + if the source had no debug handle. Set at AOT-pass time. NOT used + by the Inspector integration directly — the serializer regenerates + debug_handles, so Inspector aligns by `reducer_node_name` instead. + output_index: 0-based index into the runtime program's flat output + tuple where this tap's value lands. Computed at AOT time and stable + through `to_edge` / `to_backend` / `to_executorch` because we only + ever *append* to the output node and `OutputSpec`. + reducer_name: Name of the StatReducer used (e.g. "STATS"). + fields: Names of the per-element fields in the reducer's output tensor + (e.g. ("min", "max", "abs_max")). Empty tuple for FULL_TENSOR. + stack_trace: `node.meta["stack_trace"]` of the source node if present, + for human-readable error messages. + reducer_node_name: The FX node name of the post-strip reducer terminal + node — i.e., the node whose value is surfaced as the runtime tap + output. Populated by `strip_taps_` when `tap_specs` is passed. + FX node names survive ETRecord serialization roundtrip, so this + is the stable bridge `Inspector.calculate_numeric_gap_from_taps` + uses to find the post-roundtrip handle for alignment. + """ + + node_name: str + op_target: str + debug_handle: int | None + output_index: int + reducer_name: str + fields: tuple[str, ...] + stack_trace: str | None = None + reducer_node_name: str | None = None + module_path: str | None = None diff --git a/devtools/intermediate_output_tap/_strip_pass.py b/devtools/intermediate_output_tap/_strip_pass.py new file mode 100644 index 00000000000..1b1d58040b4 --- /dev/null +++ b/devtools/intermediate_output_tap/_strip_pass.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Post-`to_backend` pass: replace each `executorch_devtools::tap.Tensor` node +with either an identity edge (FULL_TENSOR) or a portable reducer subgraph +(STATS, or any user-supplied StatReducer). + +Pattern stolen from `remove_graph_break_` in +`executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. + +This pass MUST run *after* `to_edge_transform_and_lower(...)` and *before* +`to_executorch()`. Running it before partitioning would defeat the whole +mechanism (the reducer ops would be eligible for delegation). + +When called with the `tap_specs` from `tap_intermediate_outputs`, this pass +also populates `TapSpec.reducer_node_name` for each spec — the FX node name +of the post-strip reducer terminal. This is the bridge +`Inspector.calculate_numeric_gap_from_taps` uses to recover the +post-ETRecord-roundtrip `debug_handle` for alignment. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import replace as _dataclass_replace + +import torch.fx as fx +from executorch.devtools.intermediate_output_tap._reducers import get_reducer +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from executorch.devtools.intermediate_output_tap._tap_pass import find_tap_nodes + + +def strip_taps_( + edge_manager, + tap_specs: Sequence[TapSpec] | None = None, +) -> list[TapSpec] | None: + """ + Replace every `tap.Tensor(src, reducer_name, debug_handle)` node in every + method of `edge_manager` with the materialised reducer subgraph, in place. + + For FULL_TENSOR the placeholder is collapsed (the source node's value + flows directly to whatever consumed the placeholder). + + Args: + edge_manager: An EdgeProgramManager (post-`to_edge_transform_and_lower`). + tap_specs: Optional. If provided, the pass returns a NEW list of + TapSpecs with `reducer_node_name` populated for each spec — the + FX name of the post-strip reducer terminal node. This list must + be passed to `Inspector.calculate_numeric_gap_from_taps` for + alignment to work. + + Returns: + Updated tap_specs list if `tap_specs` was provided, else None. + """ + # Walk in graph order; tap nodes appear in the same order they were + # created by `tap_intermediate_outputs`, which is the same order as + # `tap_specs`. Track each tap's replacement node so we can update the + # corresponding spec. + replacement_names: list[str | None] = [] + for method_name in edge_manager.methods: + ep = edge_manager.exported_program(method_name) + gm = ep.graph_module + for replacement_node in _strip_taps_in_graph_module(gm): + replacement_names.append( + replacement_node.name if replacement_node is not None else None + ) + + if tap_specs is None: + return None + + if len(tap_specs) != len(replacement_names): + raise RuntimeError( + f"strip_taps_: tap_specs length ({len(tap_specs)}) does not match " + f"the number of tap nodes found in the edge_manager " + f"({len(replacement_names)}). The strip pass cannot align specs " + f"to reducer nodes. Did you call strip_taps_ on a different " + f"edge_manager than the one produced from the tapped EP?" + ) + + return [ + _dataclass_replace(spec, reducer_node_name=name) + for spec, name in zip(tap_specs, replacement_names) + ] + + +def _strip_taps_in_graph_module(gm: fx.GraphModule) -> list[fx.Node | None]: + """ + Strip taps in a single GraphModule. Returns the list of replacement nodes + in tap-creation order (same as graph order). For FULL_TENSOR taps the + "replacement" is the source node itself (since the tap collapses to + identity). + """ + graph = gm.graph + tap_nodes = find_tap_nodes(gm) + if not tap_nodes: + return [] + + output_node = graph.output_node() + replacements: list[fx.Node | None] = [] + + # Compute next available debug_handle so each reducer terminal gets a + # unique one (necessary so Inspector can look it up by node name and find + # a non-None handle in the post-roundtrip graph). + existing_handles = [ + n.meta.get("debug_handle") + for n in graph.nodes + if isinstance(n.meta.get("debug_handle"), int) + ] + next_handle = (max(existing_handles) + 1) if existing_handles else 1 + + for tap in tap_nodes: + # tap.args = (src_node, reducer_name, debug_handle) + src, reducer_name, dh = tap.args[0], tap.args[1], tap.args[2] + reducer = get_reducer(str(reducer_name)) + + if reducer.name == "FULL_TENSOR": + # Identity: re-route all consumers to the source. The "reducer + # terminal" is the source itself. + tap.replace_all_uses_with(src) + replacements.append(src if isinstance(src, fx.Node) else None) + continue + + # Build the reducer subgraph (reads from src). + with graph.inserting_before(tap): + replacement = reducer.emit(graph, src) + # Always assign a debug_handle to the reducer terminal so Inspector + # can find it post-roundtrip. Prefer the source's pre-tap handle if + # available (carries semantic meaning); otherwise use next_handle. + if dh: + replacement.meta["debug_handle"] = dh + else: + replacement.meta["debug_handle"] = next_handle + next_handle += 1 + replacement.meta["is_tap"] = True + replacement.meta["source_node"] = src.name if isinstance(src, fx.Node) else None + + # `tap` may have ended up in the data path during to_edge's re-trace + # (because CompositeExplicitAutograd preserves the op as an identity + # node, and re-traced consumers point at it instead of `src`). So: + # - the OUTPUT-node use becomes the reducer (the value we want + # surfaced as a tap). + # - every OTHER use is rewritten back to `src` (identity passthrough), + # restoring the original data path. + for use_node in list(tap.users.keys()): + if use_node is output_node: + new_outs = tuple( + replacement if a is tap else a for a in output_node.args[0] + ) + output_node.args = (new_outs,) + else: + use_node.replace_input_with(tap, src) + replacements.append(replacement) + + graph.eliminate_dead_code() + gm.recompile() + return replacements diff --git a/devtools/intermediate_output_tap/_tap_pass.py b/devtools/intermediate_output_tap/_tap_pass.py new file mode 100644 index 00000000000..f94c30d4c9a --- /dev/null +++ b/devtools/intermediate_output_tap/_tap_pass.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +AOT pass: insert `tap.Tensor` placeholders after selected nodes and surface +them as additional USER_OUTPUTs of the ExportedProgram. + +Pattern stolen from `executorch/exir/passes/weights_to_outputs_pass.py`: +- find existing output node +- build new output args (existing + new tap nodes) +- create new output node, replace_all_uses_with, erase old +- append OutputSpec(USER_OUTPUT) entries to gs.output_specs +- eliminate_dead_code() + recompile() +""" + +from __future__ import annotations + +import copy +import warnings +from collections.abc import Callable + +import torch +import torch.fx as fx +from executorch.devtools.intermediate_output_tap import ( # noqa: F401 registers tap.Tensor + custom_ops_lib, +) +from executorch.devtools.intermediate_output_tap._reducers import ( + get_reducer, + StatReducer, + STATS, +) +from executorch.devtools.intermediate_output_tap._selectors import ( + NodeSelector, + select_all_call_function, +) +from executorch.devtools.intermediate_output_tap._spec import TapSpec +from torch.export import ExportedProgram +from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument + + +# Don't ever tap our own tap nodes if a user runs the pass twice. +# `tap.Tensor` is already an OpOverload (not a packet) since "Tensor" is the +# overload name — same convention as torch.ops.executorch_utils.graph_break.Tensor. +_TAP_TARGET = torch.ops.executorch_devtools.tap.Tensor + + +def _is_tap_node(n: fx.Node) -> bool: + return n.op == "call_function" and n.target is _TAP_TARGET + + +def tap_intermediate_outputs( # noqa: C901 + ep: ExportedProgram, + selector: NodeSelector | None = None, + reducer: str | StatReducer = STATS, + *, + tap_name_prefix: str = "tap_", + skip_if_no_debug_handle: bool = False, + max_taps: int | None = None, + inplace: bool = False, + error_on_empty: bool = True, +) -> tuple[ExportedProgram, list[TapSpec]]: + """ + Rewrite `ep` so each node matching `selector` has its output appended to + the program outputs (wrapped in a `tap.Tensor` placeholder that survives + partitioning). Returns the new ExportedProgram and a list of TapSpecs. + + The returned EP is safe to feed to + `to_edge_transform_and_lower(...).to_executorch()` *after* calling + `strip_taps_(edge_manager)` to replace the placeholders with their + reducer subgraphs (or identities, for FULL_TENSOR). + + Args: + ep: The ExportedProgram to tap. + selector: A predicate over fx.Node. Defaults to + `select_all_call_function()`. Tap nodes themselves are always + excluded so re-running the pass is idempotent. + reducer: Either a built-in reducer name ("STATS", "FULL_TENSOR") + or a custom StatReducer instance. + tap_name_prefix: Prefix for the tap nodes' names. Helps when + grepping the dumped graph. + skip_if_no_debug_handle: If True, only tap nodes that already + carry `node.meta["debug_handle"]`. Recommended for Inspector + integration since handle-less taps cannot be aligned with + AOT outputs. + max_taps: Optional cap on number of taps. Helps avoid OOM for + very large models. + inplace: If False (default), deep-copy `ep` before mutating. + error_on_empty: If True (default), raise `ValueError` when no nodes + match the selector. Set to False to only emit a `UserWarning` + and return `(ep, [])` — handy when iterating on selector + patterns. + """ + if selector is None: + selector = select_all_call_function() + reducer_obj = get_reducer(reducer) + + if not inplace: + ep = copy.deepcopy(ep) + + gs = ep.graph_signature + gm = ep.graph_module + graph = gm.graph + output_node = graph.output_node() + existing_outputs = list(output_node.args[0]) + + # Snapshot before we start mutating the graph. + candidate_nodes = [n for n in graph.nodes if not _is_tap_node(n)] + + specs: list[TapSpec] = [] + new_tap_nodes: list[fx.Node] = [] + + for node in candidate_nodes: + if node.op != "call_function" or not selector(node): + continue + debug_handle = node.meta.get("debug_handle") + if skip_if_no_debug_handle and debug_handle is None: + continue + if max_taps is not None and len(specs) >= max_taps: + break + + # tap.Tensor's int arg cannot be None; sentinel 0 means "no handle". + dh_arg = int(debug_handle) if isinstance(debug_handle, int) else 0 + + with graph.inserting_after(node): + tap_node = graph.call_function( + _TAP_TARGET, + args=(node, reducer_obj.name, dh_arg), + ) + # Don't override the auto-assigned name — FX guarantees uniqueness. + # Stash the prefixed-source-name in meta for human-readable logs. + tap_node.meta["tap_label"] = f"{tap_name_prefix}{node.name}" + # Preserve provenance for Inspector's `propagate_back_debug_handle` + # and for users that pretty-print the graph. + if debug_handle is not None: + tap_node.meta["debug_handle"] = debug_handle + if "from_node" in node.meta: + tap_node.meta["from_node"] = node.meta["from_node"] + if "stack_trace" in node.meta: + tap_node.meta["stack_trace"] = node.meta["stack_trace"] + if "nn_module_stack" in node.meta: + tap_node.meta["nn_module_stack"] = node.meta["nn_module_stack"] + tap_node.meta["is_tap"] = True + tap_node.meta["source_node"] = node.name + + new_tap_nodes.append(tap_node) + # Leaf module FQN from nn_module_stack (e.g., "layers.0.attention.wqs.0"). + module_path: str | None = None + stack = node.meta.get("nn_module_stack") + if stack: + try: + last_entry = list(stack.values())[-1] + module_path = ( + last_entry[0] if isinstance(last_entry, tuple) else str(last_entry) + ) + except Exception: + module_path = None + specs.append( + TapSpec( + node_name=node.name, + op_target=str(node.target), + debug_handle=debug_handle if isinstance(debug_handle, int) else None, + output_index=len(existing_outputs) + len(specs), + reducer_name=reducer_obj.name, + fields=reducer_obj.fields, + stack_trace=node.meta.get("stack_trace"), + module_path=module_path, + ) + ) + + if not new_tap_nodes: + msg = ( + "tap_intermediate_outputs: selector matched 0 nodes. " + "Double-check your selector predicates, " + "or pass `error_on_empty=False` to suppress this error." + ) + if error_on_empty: + raise ValueError(msg) + warnings.warn(msg, UserWarning, stacklevel=2) + return ep, [] + + # Splice new outputs into the graph (mirror weights_to_outputs_pass). + new_output_args = tuple(existing_outputs + new_tap_nodes) + with graph.inserting_before(output_node): + new_output = graph.output(new_output_args) + output_node.replace_all_uses_with(new_output) + graph.erase_node(output_node) + + # Append OutputSpec entries so the EP's signature matches the graph. + for tap_node in new_tap_nodes: + gs.output_specs.append( + OutputSpec( + kind=OutputKind.USER_OUTPUT, + arg=TensorArgument(name=tap_node.name), + target=None, + ) + ) + + # Update each ModuleCallSignature's out_spec so `to_edge`'s re-trace can + # unflatten the new flat output list. The "" (root) entry holds the + # user-facing forward output structure; we wrap it in a tuple alongside + # the new tap leaves and re-derive the spec. + _extend_module_call_graph_outputs(ep, new_tap_nodes) + + graph.eliminate_dead_code() + gm.recompile() + return ep, specs + + +def _extend_module_call_graph_outputs( + ep: ExportedProgram, + new_tap_nodes: list[fx.Node], +) -> None: + """ + Append `len(new_tap_nodes)` extra leaves to the root module-call entry's + `out_spec` so the pytree unflatten step in `run_decompositions` works. + Also extends the entry's `outputs: list[ArgumentSpec]`. + + NOTE: We append TensorArgument(name="") for each new tap output. Empty + names are *skipped* by `_verify_exported_program_module_call_graph` (its + check is `if arg.name and arg.name not in nodes`). We can't use the + pre-trace tap node names because `to_edge`'s re-trace renames nodes via + `from_node` chains, and our tap nodes' provenance wouldn't update them + correctly — leading to "Output X does not exist in the graph" errors. + The verifier's name check is metadata-only; the actual pytree unflatten + only needs `out_spec` to have the correct number of leaves. + """ + import torch.utils._pytree as pytree + from torch.export.exported_program import TensorArgument as _TensorArgument + + n_new = len(new_tap_nodes) + if n_new == 0: + return + + for entry in ep._module_call_graph: + if entry.fqn != "": + continue + sig = entry.signature + if sig is None: + continue + old_spec = sig.out_spec + # Build a dummy structure matching the old spec, then wrap with N new + # leaves and re-derive the spec. This handles arbitrary pytree shapes. + old_dummy = pytree.tree_unflatten([0] * old_spec.num_leaves, old_spec) + if isinstance(old_dummy, tuple): + new_dummy = (*old_dummy, *([0] * n_new)) + else: + new_dummy = (old_dummy, *([0] * n_new)) + sig.out_spec = pytree.tree_structure(new_dummy) + for _ in range(n_new): + sig.outputs.append(_TensorArgument(name="")) + break + + +def find_tap_nodes(gm: fx.GraphModule) -> list[fx.Node]: + """Helper: enumerate tap.Tensor nodes in a GraphModule (any dialect).""" + out: list[fx.Node] = [] + for n in gm.graph.nodes: + if n.op != "call_function": + continue + # Match across dialects: + # pre-edge: torch.ops.executorch_devtools.tap.Tensor — str ends with name + # post-edge: : schema = ... + # so substring-match the qualified name. + if "executorch_devtools.tap.Tensor" in str(n.target) or n.target is _TAP_TARGET: + out.append(n) + return out + + +# Re-export the predicate so callers can identify tap nodes without importing +# torch.ops directly. +is_tap_node: Callable[[fx.Node], bool] = _is_tap_node diff --git a/devtools/intermediate_output_tap/custom_ops_lib.py b/devtools/intermediate_output_tap/custom_ops_lib.py new file mode 100644 index 00000000000..79436915c27 --- /dev/null +++ b/devtools/intermediate_output_tap/custom_ops_lib.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Custom op registration for the intermediate-output tap mechanism. + +The op `executorch_devtools::tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor` +is a placeholder whose sole job is to be an unknown-to-every-partitioner FX +node that "uses" a tapped tensor `x`. Because `x` now has a user outside any +partition, every ExecuTorch partitioner must surface `x` as a partition output +(this is the canonical contract enforced in +`executorch/exir/lowered_backend_module.py`). + +After `to_edge_transform_and_lower(...)` the tap.Tensor node still exists in +the parent graph; `strip_taps_` (see `_strip_pass.py`) replaces it with either +an identity edge (FULL_TENSOR) or a small reducer subgraph of portable aten +ops. + +The op's eager impl computes the named reducer's eager equivalent (e.g. +`min/max/mean/abs_max/...` for STATS). Two reasons for this: + +1. **Re-trace safety.** `to_edge_transform_and_lower` re-traces the graph. + If `tap.Tensor` simply returned `x` literally, the re-traced FX graph + would treat the tap output as identical to `x` and re-route downstream + consumers (which would otherwise be reading `x`) through the tap node, + pulling it into a delegate's input list. Returning a *different* tensor + (different shape for non-FULL_TENSOR; `x.detach()` for FULL_TENSOR) + keeps consumers wired to `x` directly so the tap stays a host-only stub + with the FX `output` node as its sole consumer. +2. **AOT/runtime parity.** Calling `ep_t.module()(*inputs)` then returns the + same reduced values the runtime emits post-strip, removing the need for + callers to reapply the reducer themselves. + +The dispatch key MUST be `CompositeExplicitAutograd` (not +`CompositeImplicitAutograd`) so the op survives tracing/decomposition; +otherwise it would inline at export time and disappear before partitioning. +This mirrors the pattern in +`executorch/examples/apple/coreml/llama/export_static_llm_coreml.py`. + +`reducer_name` and `debug_handle` are stored as op arguments (not just +node.meta) so they survive any meta-stripping pass between `to_edge` and +`strip_taps_`. +""" + +from __future__ import annotations + +from torch.library import impl, Library + +# Library namespace verified collision-free across fbsource as of Nov 2025. +lib: Library = Library("executorch_devtools", "DEF") + +lib.define("tap.Tensor(Tensor x, str reducer_name, int debug_handle) -> Tensor") + + +@impl(lib, "tap.Tensor", "CompositeExplicitAutograd") +def tap_tensor_impl(x, reducer_name, debug_handle): # noqa: ARG001 + # Defer the import to break a module-import cycle (`_reducers` → torch → + # custom_ops_lib registration). + from executorch.devtools.intermediate_output_tap._reducers import get_reducer + + return get_reducer(reducer_name).eager(x) diff --git a/devtools/intermediate_output_tap/tests/TARGETS b/devtools/intermediate_output_tap/tests/TARGETS new file mode 100644 index 00000000000..79344627b17 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/TARGETS @@ -0,0 +1,74 @@ +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("@fbsource//tools/target_determinator/macros:ci.bzl", "ci") + +oncall("executorch") + +python_unittest( + name = "test_selectors", + srcs = ["test_selectors.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:selectors", + ], +) + +python_unittest( + name = "test_reducers", + srcs = ["test_reducers.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + ], +) + +python_unittest( + name = "test_tap_pass", + srcs = ["test_tap_pass.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + "//executorch/devtools/intermediate_output_tap:selectors", + "//executorch/devtools/intermediate_output_tap:tap_pass", + ], +) + +python_unittest( + name = "test_strip_pass", + srcs = ["test_strip_pass.py"], + deps = [ + "//caffe2:torch", + "//executorch/devtools/intermediate_output_tap:reducers", + "//executorch/devtools/intermediate_output_tap:selectors", + "//executorch/devtools/intermediate_output_tap:strip_pass", + "//executorch/devtools/intermediate_output_tap:tap_pass", + "//executorch/exir:lib", + ], +) + +python_unittest( + name = "test_xnnpack_e2e", + srcs = ["test_xnnpack_e2e.py"], + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/devtools/intermediate_output_tap:lib", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ], +) + +python_unittest( + name = "test_inspector_integration", + srcs = ["test_inspector_integration.py"], + labels = ci.labels( + ci.buckconfig("executorch.event_tracer_enabled", "true"), + ), + deps = [ + "//caffe2:torch", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + "//executorch/devtools:lib", + "//executorch/devtools/intermediate_output_tap:lib", + "//executorch/exir:lib", + "//executorch/runtime:runtime", + ], +) diff --git a/devtools/intermediate_output_tap/tests/test_coreml_e2e.py b/devtools/intermediate_output_tap/tests/test_coreml_e2e.py new file mode 100644 index 00000000000..36daeacef07 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_coreml_e2e.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Numerical-debugging tutorial for the ExecuTorch CoreML backend. + +This script walks through end-to-end use of the intermediate-output tap +infrastructure on the static-attention Llama from +`examples/apple/coreml/llama/`. It is a smoke test (random weights, tiny +ModelArgs — no checkpoint download required) and produces two tables: + +1. A delegation summary showing how many subgraphs ExecuTorch handed off + to CoreML and which operators ran on which side. + +2. An AOT-vs-runtime comparison of the tapped intermediate values, so you + can see numerical drift between eager-PyTorch and the CoreML runtime + at hand-picked points in the model. + +Run with: + python swift_play/test_inspector_coreml.py +""" + +import os +import tempfile + +import coremltools as ct +import pandas as pd +import torch +import torch.utils._pytree as pytree +from executorch.backends.apple.coreml.compiler import CoreMLBackend +from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.devtools.backend_debug import get_delegation_info +from executorch.devtools.intermediate_output_tap import ( + compare_aot_runtime_dataframe, + FULL_TENSOR, + select_all, + select_any, + select_by_module_path, + select_by_op_type, + STATS, + strip_taps_, + tap_intermediate_outputs, +) +import types + +from executorch.examples.apple.coreml.llama.export_static_llm_coreml import ( + _create_example_inputs, + _transform_eager_model, + remove_graph_break_, +) +from executorch.examples.models.llama.llama_transformer import construct_transformer +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.exir import to_edge_transform_and_lower +from executorch.runtime import Runtime, Verification +from torch.export import export + + +def _build_model() -> tuple[torch.nn.Module, ModelArgs]: + """Build a tiny static-attention Llama with random weights.""" + args = ModelArgs( + dim=64, + n_layers=2, + n_heads=4, + n_kv_heads=2, + vocab_size=128, + hidden_dim=128, + max_seq_len=64, + max_context_len=64, + generate_full_logits=True, + ) + args.attention_type = "static_mha" + args.attention_kwargs = {"decompose_sdpa_in_mha": True} + + model = construct_transformer(args) + transform_args = types.SimpleNamespace( + target_split_size=None, + max_splits=8, + embedding_quantize="", + linear_quantize="c4w", + no_graph_breaks=False, + ) + model = _transform_eager_model(model, transform_args, torch.float16) + return model, args + + +def main() -> None: + # ------------------------------------------------------------------ + # Step 1: Build and quantize the model. + # ------------------------------------------------------------------ + model, model_args = _build_model() + + # ------------------------------------------------------------------ + # Step 2: Create example inputs and export. + # ------------------------------------------------------------------ + input_len = 8 + example_inputs, cache_len = _create_example_inputs( + model_args, + input_len=input_len, + max_context_len=model_args.max_context_len, + float_dtype=torch.float16, + ) + print(f"input_len={input_len} cache_len={cache_len}") + + with torch.no_grad(): + _ = model(*example_inputs) # eager sanity check + + print("Exporting...") + ep = export(model, example_inputs) + + # ------------------------------------------------------------------ + # Step 3: Pick which intermediate values to tap. + # + # We use two reducers in two passes: + # + # * FULL_TENSOR for `layers.1.attention.wvs.0` — surfaces the raw + # activation tensor; the comparison DataFrame computes SQNR over + # all elements. + # + # * STATS for everything else (`output`, all wqs/wks linears, layer 0's + # wvs, and layer 1's RMSNorm output mul) — gives a rich set of + # debugging scalars (min/max/mean/std/rms/l1/l2/abs_max/abs_mean/ + # nan_count/inf_count/zero_count/p99_abs). + # ------------------------------------------------------------------ + # Patterns use `*` between `layers..` and the inner module so they match + # both the bare path (`layers..attention...`) and the wrapped path + # (`layers..block.attention...`) that BlockWithGraphBreak introduces + # at the partition boundaries. + selector_full_tensor = select_any( + # Token-embedding output (one big tensor, before any transformer block). + select_by_op_type("aten.embedding.default"), + # First wvs linear in layer 1 — captures full activation post-Q/K/V. + select_all( + select_by_op_type("aten.linear.default"), + select_by_module_path("layers.1.*attention.wvs.*"), + ), + ) + selector_stats = select_any( + select_all( + select_by_op_type("aten.linear.default"), + select_any( + select_by_module_path("output"), + select_by_module_path("*.attention.wqs.*"), + select_by_module_path("*.attention.wks.*"), + select_by_module_path("layers.0.*attention.wvs.*"), + ), + ), + select_all( + select_by_op_type("aten.mul.Tensor"), + select_any( + select_by_module_path("layers.1.*attention_norm"), + select_by_module_path("layers.1.*attention_norm.*"), + select_by_module_path("layers.1.*ffn_norm"), + select_by_module_path("layers.1.*ffn_norm.*"), + ), + ), + ) + + ep_t, specs_full = tap_intermediate_outputs( + ep, selector=selector_full_tensor, reducer=FULL_TENSOR + ) + ep_t, specs_stats = tap_intermediate_outputs( + ep_t, selector=selector_stats, reducer=STATS + ) + specs = list(specs_full) + list(specs_stats) + print( + f"Inserted {len(specs)} tap(s) " + f"({len(specs_full)} FULL_TENSOR + {len(specs_stats)} STATS)." + ) + + # ------------------------------------------------------------------ + # Step 4: Capture the AOT-side reference values. + # + # `tap.Tensor`'s eager impl applies the reducer, so the flat outputs of + # the tapped EP already contain reduced values at the same positions + # the runtime will use. We pytree-flatten because the static-llama + # forward returns nested (logits, (k_caches, v_caches)). + # ------------------------------------------------------------------ + aot_out = ep_t.module()(*example_inputs) + aot_flat, _ = pytree.tree_flatten(aot_out) + + # ------------------------------------------------------------------ + # Step 5: Lower to CoreML, strip the taps, and show what got delegated. + # ------------------------------------------------------------------ + coreml_partitioner = CoreMLPartitioner( + compile_specs=CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18, + compute_precision=ct.precision.FLOAT16, + compute_unit=ct.ComputeUnit.CPU_AND_NE, + ), + ) + print("Lowering to CoreML...") + edge = to_edge_transform_and_lower(ep_t, partitioner=[coreml_partitioner]) + # Drop the `executorch_utils::graph_break.Tensor` placeholders that + # `_transform_eager_model` inserted to force partition boundaries — they + # have no out-variant kernel, so they must not survive into the runtime + # program. + remove_graph_break_(edge) + specs = strip_taps_(edge, tap_specs=specs) + + delegation_info = get_delegation_info(edge.exported_program().graph_module) + print( + f"\n=== Delegation summary " + f"(num_delegated_subgraphs={delegation_info.num_delegated_subgraphs}) ===" + ) + print(delegation_info.get_summary()) + with pd.option_context( + "display.max_columns", None, + "display.width", 240, + "display.max_colwidth", 60, + ): + print( + delegation_info.get_operator_delegation_dataframe().to_string(index=False) + ) + + # ------------------------------------------------------------------ + # Step 6: Save the .pte and run it through the ExecuTorch runtime. + # ------------------------------------------------------------------ + et_program = edge.to_executorch() + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + print(f"\nSaved PTE: {pte_path} ({os.path.getsize(pte_path)} bytes)") + + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + # Runtime takes a flat tensor list — flatten the (tokens, options_dict) + # pytree the same way torch.export did. + flat_inputs, _ = pytree.tree_flatten(example_inputs) + rt_flat = list(method.execute(flat_inputs)) + + # ------------------------------------------------------------------ + # Step 7: Compare AOT vs runtime. + # ------------------------------------------------------------------ + df = compare_aot_runtime_dataframe(specs, aot_flat, rt_flat) + with pd.option_context( + "display.max_columns", None, + "display.width", 280, + "display.max_colwidth", 30, + "display.float_format", "{:.4g}".format, + ): + print(f"\n{len(specs)} tap(s) — AOT vs CoreML runtime:") + print(df.to_string(index=False)) + + +if __name__ == "__main__": + main() diff --git a/devtools/intermediate_output_tap/tests/test_inspector_integration.py b/devtools/intermediate_output_tap/tests/test_inspector_integration.py new file mode 100644 index 00000000000..4247cd34360 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_inspector_integration.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Integration test: run the full pipeline (export -> tap -> lower with XNNPACK +-> strip -> generate_etrecord -> to_executorch -> runtime) and feed the flat +runtime outputs + (post-strip) TapSpecs to +Inspector.calculate_numeric_gap_from_taps. Verify the returned DataFrame has +rows aligned by debug_handle. + +KEY DESIGN POINTS: +1. ETRecord generation MUST happen AFTER `strip_taps_` so the snapshot of the + edge program contains no `tap.Tensor` nodes (which the EXIR serializer + can't handle). +2. `strip_taps_(edge, tap_specs=specs)` returns updated specs whose + `reducer_node_name` is set to the post-strip reducer terminal node name. + Inspector uses that name to look up the post-roundtrip `debug_handle` — + FX node names survive ETRecord serialization, debug_handle values do not. +""" + +import os +import sys +import tempfile +import unittest + +import torch +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.devtools import generate_etrecord, Inspector +from executorch.devtools.intermediate_output_tap import ( + format_tap_dataframe, + select_by_op_type, + STATS, + strip_taps_, + tap_intermediate_outputs, +) +from executorch.exir import to_edge_transform_and_lower +from executorch.runtime import Runtime, Verification +from torch.export import export + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 4) + + def forward(self, x): + return self.l2(self.l1(x).relu()) + + +@unittest.skipIf( + sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows" +) +class InspectorIntegrationTest(unittest.TestCase): + def test_calculate_numeric_gap_from_taps(self): + model = _MLP() + example_inputs = (torch.randn(2, 8),) + + ep = export(model, example_inputs, strict=True) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=STATS, + ) + # Do NOT pass generate_etrecord=True — we'd snapshot the EP while it + # still has tap.Tensor nodes (unserializable). + edge = to_edge_transform_and_lower( + ep_t, + partitioner=[XnnpackPartitioner()], + ) + # strip_taps_ with tap_specs returns updated specs whose + # reducer_node_name points at the post-strip reducer terminal node. + specs = strip_taps_(edge, tap_specs=specs) + et_program = edge.to_executorch() + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + + # ETRecord generated AFTER strip — the edge program is now + # serializable. Don't pass exported_program: Inspector falls back + # to the edge dialect program for AOT capture. + etrecord_path = os.path.join(temp_dir, "etrecord.bin") + generate_etrecord( + etrecord_path, + edge_dialect_program=edge, + executorch_program=et_program, + ) + + rt = Runtime.get() + program = rt.load_program( + pte_path, + verification=Verification.Minimal, + enable_etdump=True, + debug_buffer_size=1024 * 1024, + ) + method = program.load_method("forward") + flat_outputs = method.execute(list(example_inputs)) + + etdump_path = os.path.join(temp_dir, "etdump.etdp") + debug_buffer_path = os.path.join(temp_dir, "debug_buffer.bin") + program.write_etdump_result_to_file(etdump_path, debug_buffer_path) + if not os.path.exists(etdump_path): + self.skipTest( + "Event tracer not enabled. Run with " + "--config executorch.event_tracer_enabled=true" + ) + + inspector = Inspector( + etdump_path=etdump_path, + etrecord=etrecord_path, + debug_buffer_path=debug_buffer_path, + ) + inspector._etrecord._representative_inputs = list(example_inputs) + df = inspector.calculate_numeric_gap_from_taps( + flat_runtime_outputs=flat_outputs, + tap_specs=specs, + distance="MSE", + ) + # Print friendly per-tap view to stdout (visible via --print-passing-details). + friendly = format_tap_dataframe(df, specs) + import pandas as _pd + + with _pd.option_context( + "display.max_columns", + None, + "display.width", + 240, + "display.max_colwidth", + 30, + "display.float_format", + "{:.4g}".format, + ): + print("\n=== Inspector.calculate_numeric_gap_from_taps (friendly) ===") + print(friendly.to_string()) + + self.assertGreater(len(df), 0, "expected at least one tap row in DataFrame") + for col in ("aot_ops", "runtime_ops", "gap"): + self.assertIn(col, df.columns) + for _, row in df.iterrows(): + self.assertIsNotNone(row["aot_ops"]) + self.assertIsNotNone(row["runtime_ops"]) + gap = row["gap"] + if isinstance(gap, list): + gap = gap[0] if gap else 0.0 + self.assertGreaterEqual(float(gap), 0.0) diff --git a/devtools/intermediate_output_tap/tests/test_reducers.py b/devtools/intermediate_output_tap/tests/test_reducers.py new file mode 100644 index 00000000000..9970da265d9 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_reducers.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._reducers import ( + FULL_TENSOR, + get_reducer, + StatReducer, + STATS, +) + + +class ReducersTest(unittest.TestCase): + def test_get_reducer_by_name(self): + self.assertIs(get_reducer("FULL_TENSOR"), FULL_TENSOR) + self.assertIs(get_reducer("STATS"), STATS) + + def test_get_reducer_passthrough(self): + custom = StatReducer( + name="X", + fields=("a",), + emit=lambda g, n: n, + eager=lambda t: t, + ) + self.assertIs(get_reducer(custom), custom) + + def test_get_reducer_unknown_raises(self): + with self.assertRaises(ValueError): + get_reducer("DOES_NOT_EXIST") + + def test_reducer_field_counts(self): + self.assertEqual(FULL_TENSOR.fields, ()) + self.assertEqual( + STATS.fields, + ( + "min", + "max", + "mean", + "abs_max", + "abs_mean", + "std", + "rms", + "l1_norm", + "l2_norm", + "nan_count", + "inf_count", + "zero_count", + "p99_abs", + ), + ) + + def test_reducer_names_unique(self): + names = {r.name for r in (FULL_TENSOR, STATS)} + self.assertEqual(len(names), 2) + + def test_full_tensor_eager_is_identity(self): + t = torch.randn(2, 3, 4) + out = FULL_TENSOR.eager(t) + self.assertEqual(out.shape, t.shape) + torch.testing.assert_close(out, t.detach()) + + def test_stats_eager_correctness(self): + torch.manual_seed(0) + t = torch.randn(64) + out = STATS.eager(t) + self.assertEqual(out.shape, (len(STATS.fields),)) + + f = t.to(torch.float32) + expected = { + "min": float(f.amin()), + "max": float(f.amax()), + "mean": float(f.mean()), + "abs_max": float(f.abs().amax()), + "abs_mean": float(f.abs().mean()), + "rms": float(f.pow(2).mean().sqrt()), + "l1_norm": float(f.abs().sum()), + "l2_norm": float(f.pow(2).sum().sqrt()), + "nan_count": 0.0, + "inf_count": 0.0, + "zero_count": float((f == 0).to(torch.float32).sum()), + } + for i, field in enumerate(STATS.fields): + if field in expected: + torch.testing.assert_close( + float(out[i]), expected[field], rtol=1e-4, atol=1e-5 + ) + # std uses E[x^2] - E[x]^2 (population variance); compare to that. + pop_var = float((f.pow(2).mean() - f.mean().pow(2)).abs()) + torch.testing.assert_close( + float(out[STATS.fields.index("std")]) ** 2, + pop_var, + rtol=1e-4, + atol=1e-5, + ) + + def test_stats_p99_abs_matches_topk(self): + torch.manual_seed(0) + t = torch.randn(1000) + out = STATS.eager(t) + numel = t.numel() + k = max(1, (numel + 99) // 100) + expected = float( + torch.topk(t.abs().reshape(-1), k=k, largest=True, sorted=True).values[ + k - 1 + ] + ) + torch.testing.assert_close( + float(out[STATS.fields.index("p99_abs")]), + expected, + rtol=1e-4, + atol=1e-5, + ) + + def test_stats_counts_nan_and_inf(self): + t = torch.tensor( + [1.0, float("nan"), 2.0, float("inf"), 0.0, -float("inf"), 0.0] + ) + out = STATS.eager(t) + i_nan = STATS.fields.index("nan_count") + i_inf = STATS.fields.index("inf_count") + i_zero = STATS.fields.index("zero_count") + self.assertEqual(float(out[i_nan]), 1.0) + self.assertEqual(float(out[i_inf]), 2.0) + self.assertEqual(float(out[i_zero]), 2.0) diff --git a/devtools/intermediate_output_tap/tests/test_selectors.py b/devtools/intermediate_output_tap/tests/test_selectors.py new file mode 100644 index 00000000000..eafbb290916 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_selectors.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._selectors import ( + select_all, + select_all_call_function, + select_any, + select_by_meta_tag, + select_by_module_path, + select_by_op_type, + select_not, +) +from torch.export import export + + +class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x).relu() + + +class _Outer(torch.nn.Module): + def __init__(self): + super().__init__() + self.inner = _Inner() + self.head = torch.nn.Linear(4, 2) + + def forward(self, x): + return self.head(self.inner(x)) + + +def _exported_graph(): + ep = export(_Outer(), (torch.randn(2, 4),), strict=True) + return ep.graph_module.graph + + +class SelectorsTest(unittest.TestCase): + def setUp(self): + self.graph = _exported_graph() + self.call_nodes = [n for n in self.graph.nodes if n.op == "call_function"] + + def test_select_all_call_function_excludes_getitem(self): + sel = select_all_call_function() + for n in self.call_nodes: + if "getitem" in str(n.target): + self.assertFalse(sel(n)) + else: + self.assertTrue(sel(n)) + + def test_select_by_op_type_matches_suffix(self): + sel = select_by_op_type("aten.linear.default", "aten.relu.default") + matched = [n for n in self.call_nodes if sel(n)] + # Two linears + one relu in the model. + self.assertGreaterEqual(len(matched), 2) + for n in matched: + self.assertTrue( + str(n.target).endswith("aten.linear.default") + or str(n.target).endswith("aten.relu.default") + ) + + def test_select_by_op_type_requires_target(self): + with self.assertRaises(ValueError): + select_by_op_type() + + def test_select_by_module_path(self): + sel = select_by_module_path("inner.*") + matched = [n for n in self.call_nodes if sel(n)] + # inner contains a linear and a relu. + self.assertGreater(len(matched), 0) + for n in matched: + stack = n.meta.get("nn_module_stack") or {} + paths = [v[0] if isinstance(v, tuple) else v for v in stack.values()] + self.assertTrue(any(p.startswith("inner") for p in paths)) + + def test_select_by_meta_tag_presence(self): + for n in self.call_nodes[:1]: + n.meta["debug_me"] = "yes" + sel = select_by_meta_tag("debug_me") + self.assertTrue(sel(self.call_nodes[0])) + self.assertFalse(sel(self.call_nodes[1])) + + def test_select_by_meta_tag_value(self): + self.call_nodes[0].meta["color"] = "blue" + self.call_nodes[1].meta["color"] = "red" + sel = select_by_meta_tag("color", "blue") + self.assertTrue(sel(self.call_nodes[0])) + self.assertFalse(sel(self.call_nodes[1])) + + def test_select_combinators(self): + a = select_by_op_type("aten.linear.default") + b = select_by_op_type("aten.relu.default") + any_sel = select_any(a, b) + all_sel = select_all(a, b) + not_sel = select_not(a) + + for n in self.call_nodes: + if a(n) or b(n): + self.assertTrue(any_sel(n)) + self.assertEqual(all_sel(n), a(n) and b(n)) + self.assertEqual(not_sel(n), not a(n)) + + def test_select_any_empty(self): + for n in self.call_nodes: + self.assertFalse(select_any()(n)) + + def test_select_all_empty(self): + for n in self.call_nodes: + self.assertTrue(select_all()(n)) diff --git a/devtools/intermediate_output_tap/tests/test_strip_pass.py b/devtools/intermediate_output_tap/tests/test_strip_pass.py new file mode 100644 index 00000000000..d1ac8773a06 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_strip_pass.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._reducers import ( + DEFAULT_STATS, + FULL_TENSOR, + MIN_MAX_MEAN, +) +from executorch.devtools.intermediate_output_tap._selectors import select_by_op_type +from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_ +from executorch.devtools.intermediate_output_tap._tap_pass import ( + find_tap_nodes, + tap_intermediate_outputs, +) +from executorch.exir import to_edge +from torch.export import export + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 8) + self.l2 = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.l2(self.l1(x).relu()) + + +def _tapped_edge(reducer): + ep = export(_MLP(), (torch.randn(2, 8),), strict=True) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=reducer, + ) + return to_edge(ep_t), specs + + +class StripPassTest(unittest.TestCase): + def test_strip_removes_all_tap_nodes_full_tensor(self): + edge, _ = _tapped_edge(FULL_TENSOR) + # Pre-strip: tap nodes present. + for method_name in edge.methods: + ep = edge.exported_program(method_name) + self.assertGreater(len(find_tap_nodes(ep.graph_module)), 0) + + strip_taps_(edge) + + # Post-strip: no tap nodes. + for method_name in edge.methods: + ep = edge.exported_program(method_name) + self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0) + + def test_strip_full_tensor_routes_source_to_output(self): + edge, specs = _tapped_edge(FULL_TENSOR) + strip_taps_(edge) + # Output node should still have all the user outputs + tap outputs. + for method_name in edge.methods: + ep = edge.exported_program(method_name) + outs = list(ep.graph_module.graph.output_node().args[0]) + # Original outputs + 2 linears tapped. + self.assertGreaterEqual(len(outs), len(specs)) + + def test_strip_min_max_mean_emits_subgraph(self): + edge, specs = _tapped_edge(MIN_MAX_MEAN) + strip_taps_(edge) + for method_name in edge.methods: + ep = edge.exported_program(method_name) + self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0) + # Some reduction op (amin/amax/mean) should now be in the graph. + # Substring match because EdgeOpOverload's str() looks like + # ": schema = ..." (no clean + # endswith). + targets = {str(n.target) for n in ep.graph_module.graph.nodes} + self.assertTrue( + any( + "aten.amin" in t or "aten.amax" in t or "aten.mean" in t + for t in targets + ), + f"expected reducer ops in graph, got {targets}", + ) + + def test_strip_default_stats_preserves_debug_handle(self): + edge, specs = _tapped_edge(DEFAULT_STATS) + # Take a known debug_handle from one of the tap specs. + known_handles = {s.debug_handle for s in specs if s.debug_handle is not None} + if not known_handles: + self.skipTest("Test model produced no debug_handle on tap sources") + + strip_taps_(edge) + + post_handles: set = set() + for method_name in edge.methods: + ep = edge.exported_program(method_name) + for n in ep.graph_module.graph.nodes: + if n.meta.get("is_tap"): + post_handles.add(n.meta.get("debug_handle")) + # At least one tapped debug handle should still be present. + self.assertTrue(known_handles & post_handles) + + def test_strip_idempotent(self): + edge, _ = _tapped_edge(FULL_TENSOR) + strip_taps_(edge) + # Second call should be a no-op. + strip_taps_(edge) + for method_name in edge.methods: + ep = edge.exported_program(method_name) + self.assertEqual(len(find_tap_nodes(ep.graph_module)), 0) diff --git a/devtools/intermediate_output_tap/tests/test_tap_pass.py b/devtools/intermediate_output_tap/tests/test_tap_pass.py new file mode 100644 index 00000000000..d42d1d90bde --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_tap_pass.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import copy +import unittest + +import torch +from executorch.devtools.intermediate_output_tap._reducers import ( + DEFAULT_STATS, + FULL_TENSOR, +) +from executorch.devtools.intermediate_output_tap._selectors import select_by_op_type +from executorch.devtools.intermediate_output_tap._tap_pass import ( + is_tap_node, + tap_intermediate_outputs, +) +from torch.export import export +from torch.export.exported_program import OutputKind + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 8) + self.l3 = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.l3(self.l2(self.l1(x).relu()).relu()) + + +def _export(): + return export(_MLP(), (torch.randn(2, 8),), strict=True) + + +class TapPassTest(unittest.TestCase): + def test_inserts_tap_per_selected_node(self): + ep = _export() + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + # MLP has 3 linears. + self.assertEqual(len(specs), 3) + tap_nodes = [n for n in ep_t.graph_module.graph.nodes if is_tap_node(n)] + self.assertEqual(len(tap_nodes), 3) + + def test_appends_user_outputs(self): + ep = _export() + original_user_outs = sum( + 1 + for s in ep.graph_signature.output_specs + if s.kind == OutputKind.USER_OUTPUT + ) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + new_user_outs = sum( + 1 + for s in ep_t.graph_signature.output_specs + if s.kind == OutputKind.USER_OUTPUT + ) + self.assertEqual(new_user_outs, original_user_outs + len(specs)) + + def test_output_indices_contiguous_after_user_outputs(self): + ep = _export() + original_user_outs = sum( + 1 + for s in ep.graph_signature.output_specs + if s.kind == OutputKind.USER_OUTPUT + ) + _, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + for i, spec in enumerate(specs): + self.assertEqual(spec.output_index, original_user_outs + i) + + def test_default_reducer_is_default_stats(self): + ep = _export() + _, specs = tap_intermediate_outputs( + ep, selector=select_by_op_type("aten.linear.default") + ) + for s in specs: + self.assertEqual(s.reducer_name, DEFAULT_STATS.name) + self.assertEqual(s.fields, DEFAULT_STATS.fields) + + def test_inplace_false_does_not_mutate_original(self): + ep = _export() + before_outs = len(list(ep.graph_module.graph.output_node().args[0])) + before_specs = len(ep.graph_signature.output_specs) + _ = tap_intermediate_outputs( + ep, selector=select_by_op_type("aten.linear.default"), reducer=FULL_TENSOR + ) + after_outs = len(list(ep.graph_module.graph.output_node().args[0])) + after_specs = len(ep.graph_signature.output_specs) + self.assertEqual(before_outs, after_outs) + self.assertEqual(before_specs, after_specs) + + def test_max_taps(self): + ep = _export() + _, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + max_taps=2, + ) + self.assertEqual(len(specs), 2) + + def test_idempotent_does_not_tap_taps(self): + ep = _export() + ep_once, specs1 = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + # Running again should not add NEW taps for our existing tap nodes. + ep_twice, specs2 = tap_intermediate_outputs( + ep_once, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + ) + # Same number of linears matched; tap.Tensor itself is excluded. + self.assertEqual(len(specs2), len(specs1)) + + def test_no_match_returns_empty_specs(self): + ep = _export() + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.does.not.exist"), + reducer=FULL_TENSOR, + ) + self.assertEqual(specs, []) + # Original graph signature is unchanged. + self.assertEqual( + len(ep_t.graph_signature.output_specs), + len(ep.graph_signature.output_specs), + ) + + def test_skip_if_no_debug_handle(self): + ep = _export() + # Strip all debug handles to simulate a graph without them. + ep_clean = copy.deepcopy(ep) + for n in ep_clean.graph_module.graph.nodes: + n.meta.pop("debug_handle", None) + _, specs = tap_intermediate_outputs( + ep_clean, + selector=select_by_op_type("aten.linear.default"), + reducer=FULL_TENSOR, + skip_if_no_debug_handle=True, + ) + self.assertEqual(specs, []) diff --git a/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py new file mode 100644 index 00000000000..54e3d48e142 --- /dev/null +++ b/devtools/intermediate_output_tap/tests/test_xnnpack_e2e.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +End-to-end test: prove that intermediate values surfaced as USER_OUTPUT taps +flow through XNNPACK delegation and out the runtime *with no XNNPACK-side +support*. This is the central correctness claim of the design. +""" + +import os +import sys +import tempfile +import unittest + +import torch +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.devtools.intermediate_output_tap import ( + FULL_TENSOR, + select_by_op_type, + STATS, + strip_taps_, + tap_intermediate_outputs, +) +from executorch.exir import to_edge_transform_and_lower +from executorch.runtime import Runtime, Verification +from torch.export import export + + +class _MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(8, 16) + self.l2 = torch.nn.Linear(16, 4) + + def forward(self, x): + return self.l2(self.l1(x).relu()) + + +@unittest.skipIf( + sys.platform.startswith("win"), "ExecuTorch runtime not available on Windows" +) +class XnnpackEndToEndTest(unittest.TestCase): + def _run_pipeline(self, reducer): + model = _MLP() + example_inputs = (torch.randn(2, 8),) + + ep = export(model, example_inputs, strict=True) + ep_t, specs = tap_intermediate_outputs( + ep, + selector=select_by_op_type("aten.linear.default"), + reducer=reducer, + ) + edge = to_edge_transform_and_lower(ep_t, partitioner=[XnnpackPartitioner()]) + strip_taps_(edge) + et_program = edge.to_executorch() + + with tempfile.TemporaryDirectory() as temp_dir: + pte_path = os.path.join(temp_dir, "model.pte") + et_program.save(pte_path) + + rt = Runtime.get() + program = rt.load_program(pte_path, verification=Verification.Minimal) + method = program.load_method("forward") + flat_outputs = method.execute(list(example_inputs)) + + return specs, flat_outputs, model, example_inputs + + def test_full_tensor_taps_match_eager(self): + specs, flat, _, _ = self._run_pipeline(FULL_TENSOR) + self.assertEqual(len(specs), 2) # two linears + # FULL_TENSOR preserves the source tensor's shape. + for spec in specs: + tap_value = flat[spec.output_index] + self.assertIsInstance(tap_value, torch.Tensor) + self.assertGreater(tap_value.numel(), 0) + + def test_stats_returns_thirteen_floats(self): + specs, flat, _, _ = self._run_pipeline(STATS) + self.assertEqual(len(specs), 2) + for spec in specs: + tap_value = flat[spec.output_index] + self.assertIsInstance(tap_value, torch.Tensor) + self.assertEqual(tap_value.numel(), len(STATS.fields)) + vals = tap_value.tolist() + field_idx = {f: i for i, f in enumerate(STATS.fields)} + mn = vals[field_idx["min"]] + mx = vals[field_idx["max"]] + abs_max = vals[field_idx["abs_max"]] + l2 = vals[field_idx["l2_norm"]] + l1 = vals[field_idx["l1_norm"]] + self.assertLessEqual(mn, mx) + self.assertGreaterEqual(abs_max, max(abs(mn), abs(mx)) - 1e-3) + self.assertGreaterEqual(l1, 0.0) + self.assertGreaterEqual(l2, 0.0) + # No NaN/Inf in random fp32 — should be exactly zero. + self.assertEqual(vals[field_idx["nan_count"]], 0.0) + self.assertEqual(vals[field_idx["inf_count"]], 0.0) + + def test_user_outputs_still_correct(self): + """Tap outputs must not corrupt the original user outputs.""" + specs, flat, model, example_inputs = self._run_pipeline(FULL_TENSOR) + + eager_out = model(*example_inputs) + # User output is at index 0 (one user output for our MLP). + user_out = flat[0] + torch.testing.assert_close(user_out, eager_out, atol=1e-3, rtol=1e-3) + # Verify tap indices are non-overlapping with user-output index 0. + for spec in specs: + self.assertGreaterEqual(spec.output_index, 1)