diff --git a/modelopt/onnx/quantization/autotune/autotuner_base.py b/modelopt/onnx/quantization/autotune/autotuner_base.py index 6df297e954..85ed8eb7a1 100644 --- a/modelopt/onnx/quantization/autotune/autotuner_base.py +++ b/modelopt/onnx/quantization/autotune/autotuner_base.py @@ -929,6 +929,113 @@ def _is_region_profiled(self, region: Region) -> bool: for p in self.profiled_patterns ) + def _sample_concat_group_mutation( + self, + selected_points: list, + all_points: list, + region: "Region", + ) -> list: + """Probabilistically add or remove a full Concat input group as an atomic unit. + + TRT requires ALL inputs of a Concat to be INT8 for INT8 Concat fusion; + partial quantization has no benefit. This method treats Concat input sets + as atomic groups — randomly choosing to add all inputs of one Concat, or + remove all inputs of one Concat, as a single mutation step. + + Called with adaptive probability (min 5%, scaling with budget) during scheme + generation to inject Concat-aware samples into the search space without forcing + all schemes to have full groups. + + Args: + selected_points: Currently selected NodeInputInsertionPoint list + all_points: Full set of available NodeInputInsertionPoint list + region: The region being profiled + + Returns: + Updated list with one Concat group atomically added or removed + """ + # Identify which local node indices are Concat ops + node_indices = region.get_nodes(sort=True) + concat_local_indices = set() + for local_idx, node_idx in enumerate(node_indices): + node = self.graph.nodes[node_idx] + if node.op == "Concat": + concat_local_indices.add(local_idx) + + if not concat_local_indices: + return selected_points + + # Build groups: concat_node_index -> all available insertion points for that Concat + concat_groups: dict[int, list] = {} + for p in all_points: + if p.node_index in concat_local_indices: + concat_groups.setdefault(p.node_index, []).append(p) + + if not concat_groups: + return selected_points + + # Determine the real arity of each Concat node (number of actual inputs). + # If some inputs were filtered out earlier (e.g. non-float, small tensors), + # the group in all_points is incomplete and can never satisfy TRT's "all + # Concat inputs quantized" requirement — skip such groups entirely. 
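+        # (For example, a 4-input Concat whose non-float input was filtered out exposes
+        # only three insertion points here, so quantizing them can never produce a
+        # fully-INT8 Concat.)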
+        complete_concat_groups: dict[int, list] = {}
+        for concat_idx, group_points in concat_groups.items():
+            global_node_idx = node_indices[concat_idx]
+            concat_node = self.graph.nodes[global_node_idx]
+            real_arity = len(concat_node.inputs)
+            if len(group_points) == real_arity:
+                complete_concat_groups[concat_idx] = group_points
+
+        if not complete_concat_groups:
+            return selected_points
+
+        # Classify Concat groups by whether they are fully present in the current selection
+        selected_keys = {(p.node_index, p.input_index) for p in selected_points}
+        full_groups = []  # Concat groups fully present (can remove)
+        absent_groups = []  # Concat groups not fully selected (can add / complete)
+
+        for concat_idx, group_points in complete_concat_groups.items():
+            group_keys = {(p.node_index, p.input_index) for p in group_points}
+            present = group_keys & selected_keys
+            if len(present) == len(group_keys):
+                full_groups.append(concat_idx)
+            else:  # absent or only partially selected; the "add" action completes the group
+                absent_groups.append(concat_idx)
+
+        # Choose action only from feasible options to avoid no-op mutations
+        actions = []
+        if absent_groups:
+            actions.append("add")
+        if full_groups:
+            actions.append("remove")
+        if not actions:
+            return selected_points
+        action = random.choice(actions)
+
+        if action == "add":
+            target = random.choice(absent_groups)
+            points_to_add = [
+                p for p in complete_concat_groups[target]
+                if (p.node_index, p.input_index) not in selected_keys
+            ]
+            logger.debug(
+                f"Concat group mutation: added {len(points_to_add)} points for Concat node {target}"
+            )
+            # Rebuild in all_points order so scheme identity is independent of mutation history
+            result_keys = selected_keys | {(p.node_index, p.input_index) for p in points_to_add}
+            return [p for p in all_points if (p.node_index, p.input_index) in result_keys]
+
+        elif action == "remove":
+            target = random.choice(full_groups)
+            group_keys = {(p.node_index, p.input_index) for p in complete_concat_groups[target]}
+            result = [p for p in selected_points if (p.node_index, p.input_index) not in group_keys]
+            logger.debug(
+                f"Concat group mutation: removed {len(group_keys)} points for Concat node {target}"
+            )
+            return result
+
+        return selected_points
+
     def _mutate_insertion_points(
         self, base_points, all_points, point_type: str, max_mutations: int
     ) -> list:
@@ -1057,6 +1164,17 @@ def _generate_next_insertion_sample(self) -> InsertionScheme:
             ),
         )
 
+        # Probabilistically apply the Concat-group-aware mutation: atomically add or remove
+        # all inputs of one Concat as a group. The probability adapts to the scheme budget:
+        # - Large budget (>= 100 schemes): floored at 5%, giving at least ~5 expected samples
+        # - Small budget (< 100 schemes): concat_group_min_samples / num_schemes, clamped to [0.05, 0.5]
+        num_schemes = max(len(pattern_schemes.schemes), 1)
+        concat_prob = min(max(self.config.concat_group_min_samples / num_schemes, 0.05), 0.5)
+        if random.random() < concat_prob:
+            scheme.node_inputs = self._sample_concat_group_mutation(
+                scheme.node_inputs, full_insertion_scheme.node_inputs, region
+            )
+
         return scheme
 
     def _copy_graph(self) -> gs.Graph:
diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py
index d3b3de272f..d52d4e3514 100644
--- a/modelopt/onnx/quantization/autotune/common.py
+++ b/modelopt/onnx/quantization/autotune/common.py
@@ -842,6 +842,13 @@ class Config:
     minimum_schemes_to_mutate: int = 10
     maximum_mutations: int = 3
     maximum_generation_attempts: int = 100
+    # Minimum expected number of Concat-group-aware scheme samples per region. Controls the
+    # probability of applying a Concat-group mutation to each generated scheme:
+    #     prob = clamp(concat_group_min_samples / num_schemes, 0.05, 0.5)
+    # This biases the search toward schemes in which all inputs of a Concat are
+    # quantized/dequantized atomically as a group (all-or-nothing) rather than
+    # individually, since TRT only fuses a Concat into INT8 when every input is quantized.
+    concat_group_min_samples: int = 5
 
     # Pattern Cache Settings
     pattern_cache_minimum_distance: int = 4
diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py
index 393071a65d..004acd390f 100644
--- a/modelopt/onnx/quantization/autotune/insertion_points.py
+++ b/modelopt/onnx/quantization/autotune/insertion_points.py
@@ -369,6 +369,41 @@ def skip_invalid_insertion_points(
         producer = node.inputs[0].inputs[0]
         if producer.op in ["Conv", "ConvTranspose"]:
             return True
+    # Conv -> [BN ->] Add -> Relu: skip quantizing the main-path Conv
+    # output feeding Add to preserve TRT Conv+Add+Relu INT8 fusion.
+    # Guards:
+    #   1. The Add output has a single consumer and that consumer is Relu
+    #      (otherwise TRT cannot fuse, and skipping would remove a legitimate
+    #      quantization point).
+    #   2. The Conv feeding Add is a "main-path" Conv (its activation input
+    #      has a single consumer), not a downsample/projection Conv (whose
+    #      activation input fans out to multiple consumers).
+    if node.op == "Add":
+        # Guard 1: Add must feed exactly one Relu
+        add_out = node.outputs[0] if node.outputs else None
+        if add_out is None or len(add_out.outputs) != 1:
+            pass  # Add fans out or has no consumer — skip not applicable
+        elif add_out.outputs[0].op != "Relu":
+            pass  # Add does not feed Relu — fusion impossible
+        elif inp.inputs:
+            producer = inp.inputs[0]
+            # Unwrap optional BN
+            conv_node = None
+            if producer.op in ["Conv", "ConvTranspose"]:
+                conv_node = producer
+            elif producer.op == "BatchNormalization":
+                bn_act = producer.inputs[0] if producer.inputs else None
+                if (
+                    bn_act
+                    and bn_act.inputs
+                    and bn_act.inputs[0].op in ["Conv", "ConvTranspose"]
+                ):
+                    conv_node = bn_act.inputs[0]
+            # Guard 2: main-path Conv (single consumer on activation input)
+            if conv_node is not None and conv_node.inputs:
+                conv_act_input = conv_node.inputs[0]
+                if len(conv_act_input.outputs) == 1:
+                    return True
     # Filter 1: out boolean operations
     if node.op in (
         get_bool_ops()
@@ -472,6 +507,11 @@ def merge_resolved_insertion_points(
     to insert Q/DQ once at the tensor level rather than at each individual node input.
     This reduces the number of Q/DQ nodes in the graph and simplifies the quantization
     scheme.
+    Additionally, when a tensor has Q/DQ at some consumers and the remaining uncovered
+    consumers are all Concat nodes, the insertion is promoted to tensor-level. Concat is
+    a byte-level copy in TRT — quantizing its input has no accuracy cost and enables
+    INT8 Concat fusion when all Concat inputs are INT8.
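+    For example, if a tensor feeds both a Conv that already has a per-input Q/DQ and a
+    Concat that does not, the per-input point is promoted to a single tensor-level Q/DQ
+    on that tensor instead of leaving the Concat input unquantized.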
+ Args: graph: The ONNX graph containing the nodes resolved_insertion_points: Set of resolved insertion points to optimize @@ -486,10 +526,29 @@ def merge_resolved_insertion_points( for tensor_name in {ip.tensor_name for ip in node_ips}: all_users = set(tensor_users_map.get(tensor_name, [])) qdq_users = {ip for ip in node_ips if ip.tensor_name == tensor_name} - if all_users == {ip.node_index for ip in qdq_users}: + covered_nodes = {ip.node_index for ip in qdq_users} + + if all_users == covered_nodes: + # All consumers have Q/DQ — merge to tensor-level results.add( ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) ) + elif covered_nodes and all_users - covered_nodes: + # Some consumers lack Q/DQ — check if all uncovered ones are Concat + uncovered = all_users - covered_nodes + uncovered_all_concat = all( + node_idx < len(graph.nodes) and graph.nodes[node_idx].op == "Concat" + for node_idx in uncovered + ) + if uncovered_all_concat: + # Promote to tensor-level: Concat is byte-copy, safe to quantize + results.add( + ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=None, input_index=None + ) + ) + else: + results.update(qdq_users) else: results.update(qdq_users) return results @@ -497,7 +556,10 @@ def merge_resolved_insertion_points( def get_autotuner_skip_ops(): """Returns set of shape/structural operations that are not quantizable.""" - return set(get_copy_ops()) | { + # Concat is excluded: it can pass INT8 data through in TRT (byte-level copy). + # Blocking Concat prevents tensor-level Q/DQ when a quantizable op's output + # fans out to both a compute op (e.g. Conv) and a Concat, breaking INT8 fusion. + return (set(get_copy_ops()) - {"Concat"}) | { # Additional indexing/scatter/reshape ops "Compress", "Scatter",
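For reference, a standalone sketch of the adaptive mutation probability introduced above (the new concat_group_min_samples field and the _generate_next_insertion_sample change). The helper name and the example scheme counts are illustrative only, not part of the change:

def concat_mutation_probability(concat_group_min_samples: int, num_schemes: int) -> float:
    """clamp(concat_group_min_samples / num_schemes, 0.05, 0.5), matching the code above."""
    num_schemes = max(num_schemes, 1)
    return min(max(concat_group_min_samples / num_schemes, 0.05), 0.5)

# With the default concat_group_min_samples = 5:
#   200 schemes -> 0.05  (5% floor; ~10 Concat-aware schemes expected)
#    20 schemes -> 0.25  (~5 Concat-aware schemes expected)
#     4 schemes -> 0.50  (50% ceiling; roughly every other scheme)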