diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index cd38c5653..b40f2969a 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -499,16 +499,11 @@ def _initialize_node_ids( ) else: train_nodes, val_nodes, test_nodes = splits - self._num_train = ( - train_nodes.numel() # ty: ignore[unresolved-attribute] - ) - self._num_val = val_nodes.numel() # ty: ignore[unresolved-attribute] - self._num_test = test_nodes.numel() # ty: ignore[unresolved-attribute] + self._num_train = train_nodes.numel() + self._num_val = val_nodes.numel() + self._num_test = test_nodes.numel() self._node_ids = _append_non_split_node_ids( - train_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - val_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - test_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - node_ids_on_machine, + train_nodes, val_nodes, test_nodes, node_ids_on_machine ) else: logger.info( @@ -642,8 +637,8 @@ def _initialize_node_features( # if it is not an edge type, since it must be one of the two. assert not isinstance(node_type, EdgeType) self._node_feature_info[node_type] = FeatureInfo( - dim=node_features_per_node_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - dtype=node_features_per_node_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + dim=node_features_per_node_type.size(1), + dtype=node_features_per_node_type.dtype, ) logger.info( f"Initialized node features for heterogeneous graph to dataset with node types: {node_features.keys()}" @@ -725,8 +720,8 @@ def _initialize_edge_features( for edge_type, edge_features_per_edge_type in edge_features.items(): assert isinstance(edge_type, EdgeType) self._edge_feature_info[edge_type] = FeatureInfo( - dim=edge_features_per_edge_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - dtype=edge_features_per_edge_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + dim=edge_features_per_edge_type.size(1), + dtype=edge_features_per_edge_type.dtype, ) logger.info( f"Initialized edge features for heterogeneous graph to dataset with edge types: {edge_features.keys()}" diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index 402e381c1..83369d8c2 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -537,7 +537,7 @@ async def _sample_from_nodes( seed_types = list(nodes_to_sample.keys()) ppr_results = await asyncio.gather( *[ - self._compute_ppr_scores(nodes_to_sample[seed_type], seed_type) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + self._compute_ppr_scores(nodes_to_sample[seed_type], seed_type) for seed_type in seed_types ] ) @@ -556,20 +556,20 @@ async def _sample_from_nodes( for ntype, flat_ids in ntype_to_flat_ids.items(): ppr_edge_type: EdgeType = (seed_type, "ppr", ntype) - valid_counts = ntype_to_valid_counts[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + valid_counts = ntype_to_valid_counts[ntype] ppr_edge_type_to_flat_weights[ppr_edge_type] = ( - ntype_to_flat_weights[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + ntype_to_flat_weights[ntype] ) # Skip empty pairs; induce_next handles deduplication across # seed types so a neighbor reachable from multiple seed types # gets one consistent local index in node_dict[ntype]. - if flat_ids.numel() > 0: # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + if flat_ids.numel() > 0: nbr_dict[ppr_edge_type] = [ src_dict[seed_type], flat_ids, valid_counts, - ] # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. + ] # induce_next processes all PPR edge types in nbr_dict in one # pass, assigning local indices to neighbors not yet registered and diff --git a/gigl/nn/graph_transformer.py b/gigl/nn/graph_transformer.py index f9d8b345a..d826bedce 100644 --- a/gigl/nn/graph_transformer.py +++ b/gigl/nn/graph_transformer.py @@ -239,6 +239,18 @@ class GraphTransformerEncoderLayer(nn.Module): activation: Activation function for the feed-forward network. Supported values: "gelu" (default), "relu", "silu", "tanh", "geglu", "swiglu", "reglu". + relation_attention_mode: Optional relation-aware augmentation strategy + for attention scores. ``"none"`` preserves the default shared + self-attention path. ``"edge_type_bilinear"`` adds a learned + per-edge-type bilinear term for sampled directed graph edges. This + changes attention weights, not value/message content. + relation_value_mode: Optional relation-aware value augmentation strategy. + ``"sparse_residual_gate"`` adds a zero-initialized sparse residual + message path from relation-indexed source values to target queries. + This changes relation-specific message content without replacing + the main SDPA attention implementation. + num_relations: Number of relation channels expected in + ``pairwise_relation_indices`` when a relation-aware mode is enabled. Raises: ValueError: If model_dim is not divisible by num_heads. @@ -252,16 +264,41 @@ def __init__( dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, activation: str = "gelu", + relation_attention_mode: Literal["none", "edge_type_bilinear"] = "none", + relation_value_mode: Literal["none", "sparse_residual_gate"] = "none", + num_relations: int = 0, ) -> None: super().__init__() if model_dim % num_heads != 0: raise ValueError( f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads})" ) + if relation_attention_mode not in {"none", "edge_type_bilinear"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_bilinear'}, " + f"got '{relation_attention_mode}'" + ) + if relation_value_mode not in {"none", "sparse_residual_gate"}: + raise ValueError( + "relation_value_mode must be one of " + "{'none', 'sparse_residual_gate'}, " + f"got '{relation_value_mode}'" + ) + if ( + relation_attention_mode == "edge_type_bilinear" + or relation_value_mode == "sparse_residual_gate" + ) and num_relations <= 0: + raise ValueError( + "Relation-aware attention/value modes require num_relations > 0." + ) self._num_heads = num_heads self._head_dim = model_dim // num_heads self._attention_dropout_rate = attention_dropout_rate + self._relation_attention_mode = relation_attention_mode + self._relation_value_mode = relation_value_mode + self._num_relations = num_relations self._attention_norm = nn.LayerNorm(model_dim) self._query_projection = nn.Linear(model_dim, model_dim) @@ -269,6 +306,23 @@ def __init__( self._value_projection = nn.Linear(model_dim, model_dim) self._output_projection = nn.Linear(model_dim, model_dim) self._dropout = nn.Dropout(dropout_rate) + self._relation_attention_matrices: Optional[nn.Parameter] = None + if relation_attention_mode == "edge_type_bilinear": + # Relation-specific bilinear logit term: + # score(target, source, relation) += q_target^T W_relation k_source + # Zero init keeps startup behavior identical to shared attention. + self._relation_attention_matrices = nn.Parameter( + torch.zeros(num_relations, num_heads, self._head_dim, self._head_dim) + ) + self._relation_value_gates: Optional[nn.Parameter] = None + if relation_value_mode == "sparse_residual_gate": + # Lightweight relation-specific value/message path: + # message(target) += gate_relation * value_source + # This is a sparse residual on relation edges because PyTorch SDPA + # only accepts one shared value tensor, not per-edge transformed values. + self._relation_value_gates = nn.Parameter( + torch.zeros(num_relations, num_heads, self._head_dim) + ) self._ffn_norm = nn.LayerNorm(model_dim) self._ffn = FeedForwardNetwork( @@ -287,6 +341,10 @@ def reset_parameters(self) -> None: nn.init.xavier_uniform_(projection.weight) if projection.bias is not None: nn.init.zeros_(projection.bias) + if self._relation_attention_matrices is not None: + nn.init.zeros_(self._relation_attention_matrices) + if self._relation_value_gates is not None: + nn.init.zeros_(self._relation_value_gates) self._ffn_norm.reset_parameters() self._ffn.reset_parameters() @@ -294,6 +352,7 @@ def forward( self, x: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_indices: Optional[Tensor] = None, valid_mask: Optional[Tensor] = None, ) -> Tensor: """Forward pass. @@ -303,6 +362,9 @@ def forward( attn_bias: Optional attention bias of shape ``(batch, num_heads, seq, seq)`` or broadcastable. Added as an additive mask to attention scores. + pairwise_relation_indices: Optional long tensor of shape + ``(num_relation_edges, 4)`` with sparse + ``(batch_idx, query_pos, key_pos, relation_idx)`` coordinates. valid_mask: Optional boolean tensor of shape ``(batch, seq)`` used to zero out padded token states after each residual block. @@ -330,15 +392,24 @@ def forward( batch_size, seq_len, self._num_heads, self._head_dim ).transpose(1, 2) - attention_output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attn_bias, - dropout_p=self._attention_dropout_rate if self.training else 0.0, - is_causal=False, + attention_output = self._run_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + pairwise_relation_indices=pairwise_relation_indices, ) + if self._relation_value_mode == "sparse_residual_gate": + # Relation bilinear attention decides how strongly to attend along + # relation edges. This residual separately lets relation type + # change the content passed from source values to target queries. + attention_output = self._add_relation_value_residual( + attention_output=attention_output, + value=value, + pairwise_relation_indices=pairwise_relation_indices, + ) + # Reshape back to (batch, seq, model_dim) attention_output = attention_output.transpose(1, 2).reshape( batch_size, seq_len, model_dim @@ -360,14 +431,291 @@ def forward( return x + def _run_attention( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + pairwise_relation_indices: Optional[Tensor], + ) -> Tensor: + if self._relation_attention_mode == "edge_type_bilinear": + return self._run_relation_aware_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + pairwise_relation_indices=pairwise_relation_indices, + ) + + # Keep the main path on PyTorch SDPA. Depending on device/dtype/mask, + # PyTorch may dispatch this to FlashAttention or another SDPA backend. + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + + def _run_relation_aware_attention( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + pairwise_relation_indices: Optional[Tensor], + ) -> Tensor: + # The relation-aware logit path still uses SDPA. We only add an + # additive per-relation bias before attention; value vectors remain + # shared unless relation_value_mode adds a sparse residual afterward. + attn_bias = self._add_relation_attention_bias( + attn_bias=attn_bias, + query=query, + key=key, + pairwise_relation_indices=pairwise_relation_indices, + ) + + return F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=self._attention_dropout_rate if self.training else 0.0, + is_causal=False, + ) + + def _add_relation_value_residual( + self, + attention_output: Tensor, + value: Tensor, + pairwise_relation_indices: Optional[Tensor], + ) -> Tensor: + if pairwise_relation_indices is None: + raise ValueError( + "pairwise_relation_indices is required when " + "relation_value_mode='sparse_residual_gate'." + ) + if self._relation_value_gates is None: + raise ValueError("Relation value gates are not initialized.") + if pairwise_relation_indices.numel() == 0: + return attention_output + if ( + pairwise_relation_indices.dim() != 2 + or pairwise_relation_indices.size(-1) != 4 + ): + raise ValueError( + "pairwise_relation_indices must have shape (num_relation_edges, 4)." + ) + + pairwise_relation_indices = pairwise_relation_indices.to( + device=value.device, + dtype=torch.long, + ) + batch_indices = pairwise_relation_indices[:, 0] + query_indices = pairwise_relation_indices[:, 1] + key_indices = pairwise_relation_indices[:, 2] + relation_indices = pairwise_relation_indices[:, 3] + if ( + relation_indices.min().item() < 0 + or relation_indices.max().item() >= self._num_relations + ): + raise ValueError( + "pairwise_relation_indices contains relation ids outside " + f"[0, {self._num_relations})." + ) + + batch_size, _, seq_len, _ = value.shape + value_by_position = value.transpose(1, 2) + selected_values = value_by_position[batch_indices, key_indices] + selected_gates = self._relation_value_gates.to(dtype=value.dtype)[ + relation_indices + ] + messages = selected_values * selected_gates + + residual_by_position = value.new_zeros( + (batch_size, seq_len, self._num_heads, self._head_dim) + ) + residual_by_position.index_put_( + (batch_indices, query_indices), + messages, + accumulate=True, + ) + + # Multiple relation edges can target the same token. Average the sparse + # residual by target degree so this auxiliary path does not grow just + # because a sampled sequence has more relation edges. + counts = value.new_zeros((batch_size, seq_len)) + counts.index_put_( + (batch_indices, query_indices), + torch.ones( + pairwise_relation_indices.size(0), + dtype=value.dtype, + device=value.device, + ), + accumulate=True, + ) + residual_by_position = residual_by_position / counts.clamp_min(1).view( + batch_size, + seq_len, + 1, + 1, + ) + + return attention_output + residual_by_position.transpose(1, 2).to( + dtype=attention_output.dtype + ) + + def _build_relation_attention_bias( + self, + query: Tensor, + key: Tensor, + pairwise_relation_indices: Optional[Tensor], + ) -> Optional[Tensor]: + if pairwise_relation_indices is not None and pairwise_relation_indices.numel() == 0: + return None + + batch_size, _, seq_len, _ = query.shape + empty_bias = query.new_zeros( + (batch_size, self._num_heads, seq_len, seq_len) + ) + return self._add_relation_attention_bias( + attn_bias=empty_bias, + query=query, + key=key, + pairwise_relation_indices=pairwise_relation_indices, + ) + + def _add_relation_attention_bias( + self, + attn_bias: Optional[Tensor], + query: Tensor, + key: Tensor, + pairwise_relation_indices: Optional[Tensor], + ) -> Optional[Tensor]: + if pairwise_relation_indices is None: + raise ValueError( + "pairwise_relation_indices is required when " + "relation_attention_mode='edge_type_bilinear'." + ) + if self._relation_attention_matrices is None: + raise ValueError("Relation attention matrices are not initialized.") + if pairwise_relation_indices.numel() == 0: + return attn_bias + if ( + pairwise_relation_indices.dim() != 2 + or pairwise_relation_indices.size(-1) != 4 + ): + raise ValueError( + "pairwise_relation_indices must have shape (num_relation_edges, 4)." + ) + + pairwise_relation_indices = pairwise_relation_indices.to( + device=query.device, + dtype=torch.long, + ) + batch_indices = pairwise_relation_indices[:, 0] + query_indices = pairwise_relation_indices[:, 1] + key_indices = pairwise_relation_indices[:, 2] + relation_indices = pairwise_relation_indices[:, 3] + if ( + relation_indices.min().item() < 0 + or relation_indices.max().item() >= self._num_relations + ): + raise ValueError( + "pairwise_relation_indices contains relation ids outside " + f"[0, {self._num_relations})." + ) + + batch_size, _, seq_len, _ = query.shape + if attn_bias is None: + attn_bias = query.new_zeros( + (batch_size, self._num_heads, seq_len, seq_len) + ) + elif attn_bias.shape[1:] != (self._num_heads, seq_len, seq_len): + attn_bias = attn_bias.expand( + batch_size, + self._num_heads, + seq_len, + seq_len, + ).clone() + else: + # The same base attention bias is reused across encoder layers, so + # relation-aware logits must be added to a per-layer copy. + attn_bias = attn_bias.clone() + + attn_bias_by_position = attn_bias.permute(0, 2, 3, 1) + query_by_position = query.transpose(1, 2) + key_by_position = key.transpose(1, 2) + relation_matrices = self._relation_attention_matrices.to(dtype=query.dtype) + + # Relation indices are emitted grouped by relation in the transform path. + # Sort only if callers provide an unsorted tensor, avoiding repeated + # full boolean masks over all relation edges. + if relation_indices.numel() > 1 and not torch.all( + relation_indices[1:] >= relation_indices[:-1] + ).item(): + relation_sort_perm = torch.argsort(relation_indices) + relation_indices = relation_indices[relation_sort_perm] + batch_indices = batch_indices[relation_sort_perm] + query_indices = query_indices[relation_sort_perm] + key_indices = key_indices[relation_sort_perm] + + unique_relation_indices, relation_counts = torch.unique_consecutive( + relation_indices, + return_counts=True, + ) + relation_start = 0 + for relation_idx_tensor, relation_count_tensor in zip( + unique_relation_indices, + relation_counts, + ): + relation_idx = int(relation_idx_tensor.item()) + relation_end = relation_start + int(relation_count_tensor.item()) + relation_batch_indices = batch_indices[relation_start:relation_end] + relation_query_indices = query_indices[relation_start:relation_end] + relation_key_indices = key_indices[relation_start:relation_end] + + selected_query = query_by_position[ + relation_batch_indices, + relation_query_indices, + ] + # This term changes the attention score for relation edge + # source -> target, but the SDPA value content is still value_source. + transformed_query = torch.einsum( + "nhd,hde->nhe", + selected_query, + relation_matrices[relation_idx], + ) + selected_key = key_by_position[ + relation_batch_indices, + relation_key_indices, + ] + relation_scores = (transformed_query * selected_key).sum(dim=-1) + attn_bias_by_position.index_put_( + ( + relation_batch_indices, + relation_query_indices, + relation_key_indices, + ), + (relation_scores / math.sqrt(self._head_dim)).to( + dtype=attn_bias.dtype + ), + accumulate=True, + ) + relation_start = relation_end + + return attn_bias + class GraphTransformerEncoder(nn.Module): """Graph Transformer encoder for heterogeneous graphs. Converts heterogeneous graph data into fixed-length sequences via ``heterodata_to_graph_transformer_input``, processes through pre-norm - transformer encoder layers, and produces per-node embeddings via - attention-weighted neighbor readout (from RelGT's LocalModule). + transformer encoder layers, and produces per-node embeddings via a + configurable readout over the anchor token and its sequence context. Conforms to the same forward interface as ``HGT`` and ``SimpleHGN``, making it a drop-in encoder for ``LinkPredictionGNN``. @@ -376,9 +724,8 @@ class GraphTransformerEncoder(nn.Module): node_type_to_feat_dim_map: Dictionary mapping node types to their input feature dimensions. edge_type_to_feat_dim_map: Dictionary mapping edge types to their - feature dimensions. Accepted for interface conformance with - ``HGT``/``SimpleHGN``; edge features are not used by the - graph transformer. + feature dimensions. Used by optional relation-aware and sparse + edge-attribute attention-bias paths. hid_dim: Hidden dimension for transformer layers. All node types are projected to this dimension before processing. out_dim: Output embedding dimension. @@ -435,6 +782,10 @@ class GraphTransformerEncoder(nn.Module): feature_embedding_layer_dict: Optional ModuleDict mapping node types to feature embedding layers. If provided, these are applied to node features before node projection. (default: None) + readout_mode: Readout applied after the transformer encoder stack. + ``"anchor_neighbor_attention"`` preserves the current RelGT-style + anchor-plus-neighbor attention pooling. ``"anchor_only"`` returns + the normalized anchor token directly. pe_integration_mode: How to fuse positional encodings into the model input. ``"concat"`` preserves the current behavior by concatenating node-level PE to token features. ``"add"`` uses node-level additive @@ -450,6 +801,18 @@ class GraphTransformerEncoder(nn.Module): uses 4.0 for standard activations and 8/3 (~2.67) for XGLU variants, following the convention that XGLU's gating doubles the effective parameters, so a smaller ratio maintains similar parameter count. + relation_attention_mode: Optional relation-aware augmentation for + attention scores. ``"none"`` preserves the current transformer path. + ``"edge_type_bilinear"`` adds a learned per-edge-type bilinear score + term for sampled directed edges. + relation_value_mode: Optional relation-aware value augmentation. + ``"none"`` preserves the current transformer path. + ``"sparse_residual_gate"`` adds a zero-initialized sparse residual + value path on sampled directed relation edges. + edge_attr_attention_bias_mode: Optional edge-attribute logit-bias path. + ``"none"`` preserves the current behavior. ``"sparse_linear"`` adds + a zero-initialized per-edge-type linear projection from sampled + edge attributes to per-head attention logits. Notes: This encoder uses ``nn.LazyLinear`` for node-level PE fusion. If you wrap @@ -496,9 +859,15 @@ def __init__( anchor_based_input_embedding_dict: Optional[nn.ModuleDict] = None, pairwise_attention_bias_attr_names: Optional[list[str]] = None, feature_embedding_layer_dict: Optional[nn.ModuleDict] = None, + readout_mode: Literal["anchor_neighbor_attention", "anchor_only"] = ( + "anchor_neighbor_attention" + ), pe_integration_mode: Literal["concat", "add"] = "concat", activation: str = "gelu", feedforward_ratio: Optional[float] = None, + relation_attention_mode: Literal["none", "edge_type_bilinear"] = "none", + relation_value_mode: Literal["none", "sparse_residual_gate"] = "none", + edge_attr_attention_bias_mode: Literal["none", "sparse_linear"] = "none", **kwargs: object, ) -> None: super().__init__() @@ -540,6 +909,30 @@ def __init__( "sequence_construction_method='ppr' because khop sequences do not " "enforce a stable token order." ) + if readout_mode not in {"anchor_neighbor_attention", "anchor_only"}: + raise ValueError( + "readout_mode must be one of " + "{'anchor_neighbor_attention', 'anchor_only'}, " + f"got '{readout_mode}'" + ) + if relation_attention_mode not in {"none", "edge_type_bilinear"}: + raise ValueError( + "relation_attention_mode must be one of " + "{'none', 'edge_type_bilinear'}, " + f"got '{relation_attention_mode}'" + ) + if relation_value_mode not in {"none", "sparse_residual_gate"}: + raise ValueError( + "relation_value_mode must be one of " + "{'none', 'sparse_residual_gate'}, " + f"got '{relation_value_mode}'" + ) + if edge_attr_attention_bias_mode not in {"none", "sparse_linear"}: + raise ValueError( + "edge_attr_attention_bias_mode must be one of " + "{'none', 'sparse_linear'}, " + f"got '{edge_attr_attention_bias_mode}'" + ) anchor_bias_attr_names = anchor_based_attention_bias_attr_names or [] anchor_input_attr_names = anchor_based_input_attr_names or [] pairwise_bias_attr_names = pairwise_attention_bias_attr_names or [] @@ -569,8 +962,27 @@ def __init__( self._anchor_based_input_embedding_dict = anchor_based_input_embedding_dict self._pairwise_attention_bias_attr_names = pairwise_attention_bias_attr_names self._feature_embedding_layer_dict = feature_embedding_layer_dict + self._readout_mode = readout_mode self._pe_integration_mode = pe_integration_mode self._num_heads = num_heads + self._relation_attention_mode = relation_attention_mode + self._relation_value_mode = relation_value_mode + self._edge_attr_attention_bias_mode = edge_attr_attention_bias_mode + self._edge_type_to_feat_dim_map = { + edge_type: edge_type_to_feat_dim_map[edge_type] + for edge_type in sorted(edge_type_to_feat_dim_map.keys()) + } + self._relation_attention_edge_types = ( + list(self._edge_type_to_feat_dim_map.keys()) + if relation_attention_mode == "edge_type_bilinear" + or relation_value_mode == "sparse_residual_gate" + else [] + ) + self._edge_attr_attention_bias_edge_types = ( + list(self._edge_type_to_feat_dim_map.keys()) + if edge_attr_attention_bias_mode == "sparse_linear" + else [] + ) anchor_input_embedding_attr_names = ( set(anchor_based_input_embedding_dict.keys()) if anchor_based_input_embedding_dict is not None @@ -638,6 +1050,8 @@ def __init__( num_heads, bias=False, ) + # Start structural logit bias neutral; training can turn it on if useful. + nn.init.zeros_(self._anchor_pe_attention_bias_projection.weight) self._pairwise_pe_attention_bias_projection: Optional[nn.Linear] = None if self._pairwise_attention_bias_attr_names: @@ -646,6 +1060,26 @@ def __init__( num_heads, bias=False, ) + nn.init.zeros_(self._pairwise_pe_attention_bias_projection.weight) + self._pairwise_nonmissing_attention_bias = nn.Parameter( + torch.zeros(num_heads) + ) + else: + self.register_parameter("_pairwise_nonmissing_attention_bias", None) + + self._edge_attr_attention_bias_projection_dict = nn.ModuleDict() + if self._edge_attr_attention_bias_mode == "sparse_linear": + for relation_idx, edge_type in enumerate( + self._edge_attr_attention_bias_edge_types + ): + edge_attr_dim = int(self._edge_type_to_feat_dim_map[edge_type]) + if edge_attr_dim <= 0: + continue + projection = nn.Linear(edge_attr_dim, num_heads, bias=False) + nn.init.zeros_(projection.weight) + self._edge_attr_attention_bias_projection_dict[str(relation_idx)] = ( + projection + ) # Transformer encoder layers # Default feedforward ratio: 4.0 for standard activations, 8/3 for XGLU @@ -664,6 +1098,9 @@ def __init__( dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, activation=activation, + relation_attention_mode=relation_attention_mode, + relation_value_mode=relation_value_mode, + num_relations=len(self._relation_attention_edge_types), ) for _ in range(num_layers) ] @@ -671,7 +1108,8 @@ def __init__( self._final_norm = nn.LayerNorm(hid_dim) - # Readout attention: projects concatenated (anchor, neighbor) to score + # Always instantiate the neighbor readout head so checkpoints can move + # between readout modes without changing parameter shapes. self._readout_attention = nn.Linear(2 * hid_dim, 1) # Output projection: hid_dim -> out_dim @@ -801,6 +1239,20 @@ def forward( anchor_based_attention_bias_attr_names=self._anchor_based_attention_bias_attr_names, anchor_based_input_attr_names=self._anchor_based_input_attr_names, pairwise_attention_bias_attr_names=self._pairwise_attention_bias_attr_names, + relation_edge_types=( + self._relation_attention_edge_types + if self._relation_attention_mode == "edge_type_bilinear" + or self._relation_value_mode == "sparse_residual_gate" + else None + ), + edge_attr_edge_type_to_feat_dim_map=( + { + edge_type: self._edge_type_to_feat_dim_map[edge_type] + for edge_type in self._edge_attr_attention_bias_edge_types + } + if self._edge_attr_attention_bias_mode == "sparse_linear" + else None + ), ) # Free memory after sequences are built @@ -837,6 +1289,9 @@ def forward( sequences=sequences, valid_mask=valid_mask, attn_bias=attn_bias, + pairwise_relation_indices=sequence_auxiliary_data.get( + "pairwise_relation_indices" + ), ) embeddings = self._output_projection(embeddings) @@ -966,6 +1421,9 @@ def _build_attention_bias( attention_bias_data: Dictionary containing optional PE tensors: - "anchor_bias": (batch, seq, num_anchor_attrs) or None - "pairwise_bias": (batch, seq, seq, num_pairwise_attrs) or None + - "pairwise_nonmissing_indices": (num_pairs, 3) or None + - "pairwise_edge_attr_indices": dict[int, (num_edges, 3)] or None + - "pairwise_edge_attr_values": dict[int, (num_edges, edge_dim)] or None Returns: Combined attention bias tensor of shape (batch_size, num_heads, seq_len, seq_len) @@ -985,14 +1443,14 @@ def _build_attention_bias( device = sequences.device negative_inf = torch.finfo(dtype).min - # Step 1: Initialize with padding mask bias - # Shape: (batch, 1, 1, seq) - broadcasts to mask invalid keys for all queries/heads + # Step 1: Initialize with padding mask bias. + # Shape: (batch, 1, 1, seq) broadcasts to mask invalid keys. attn_bias = torch.zeros( (batch_size, 1, 1, seq_len), dtype=dtype, device=device, ) - attn_bias = attn_bias.masked_fill( + attn_bias.masked_fill_( ~valid_mask.unsqueeze(1).unsqueeze(2), # (batch, 1, 1, seq) negative_inf, ) @@ -1029,6 +1487,120 @@ def _build_attention_bias( ) # (batch, num_heads, seq, seq) attn_bias = attn_bias + pairwise_bias + pairwise_nonmissing_indices = attention_bias_data.get( + "pairwise_nonmissing_indices" + ) + if pairwise_nonmissing_indices is not None: + if self._pairwise_nonmissing_attention_bias is None: + raise ValueError( + "Pairwise nonmissing attention bias is not initialized." + ) + if pairwise_nonmissing_indices.numel() > 0: + if attn_bias.shape[1:] != (self._num_heads, seq_len, seq_len): + attn_bias = attn_bias.expand( + batch_size, + self._num_heads, + seq_len, + seq_len, + ).clone() + pairwise_nonmissing_indices = pairwise_nonmissing_indices.to( + device=device, + dtype=torch.long, + ) + attn_bias_by_position = attn_bias.permute(0, 2, 3, 1) + nonmissing_bias = self._pairwise_nonmissing_attention_bias.to( + dtype=attn_bias.dtype + ).view(1, -1) + attn_bias_by_position.index_put_( + ( + pairwise_nonmissing_indices[:, 0], + pairwise_nonmissing_indices[:, 1], + pairwise_nonmissing_indices[:, 2], + ), + nonmissing_bias.expand(pairwise_nonmissing_indices.size(0), -1), + accumulate=True, + ) + + pairwise_edge_attr_indices = attention_bias_data.get( + "pairwise_edge_attr_indices" + ) + pairwise_edge_attr_values = attention_bias_data.get("pairwise_edge_attr_values") + if pairwise_edge_attr_indices is not None or pairwise_edge_attr_values is not None: + if self._edge_attr_attention_bias_mode != "sparse_linear": + raise ValueError( + "Sparse edge-attribute attention-bias payloads require " + "edge_attr_attention_bias_mode='sparse_linear'." + ) + if pairwise_edge_attr_indices is None or pairwise_edge_attr_values is None: + raise ValueError( + "pairwise_edge_attr_indices and pairwise_edge_attr_values " + "must be provided together." + ) + if set(pairwise_edge_attr_indices.keys()) != set( + pairwise_edge_attr_values.keys() + ): + raise ValueError( + "pairwise_edge_attr_indices and pairwise_edge_attr_values " + "must have identical relation-index keys." + ) + + attn_bias_by_position: Optional[Tensor] = None + for relation_idx in sorted(pairwise_edge_attr_indices.keys()): + edge_attr_indices = pairwise_edge_attr_indices[relation_idx] + edge_attr_values = pairwise_edge_attr_values[relation_idx] + if edge_attr_indices.numel() == 0: + continue + if edge_attr_indices.dim() != 2 or edge_attr_indices.size(-1) != 3: + raise ValueError( + "pairwise_edge_attr_indices entries must have shape " + "(num_edges, 3)." + ) + if ( + edge_attr_values.dim() != 2 + or edge_attr_values.size(0) != edge_attr_indices.size(0) + ): + raise ValueError( + "pairwise_edge_attr_values entries must have shape " + "(num_edges, edge_attr_dim) with the same num_edges " + "as their index tensor." + ) + + projection_key = str(relation_idx) + if projection_key not in self._edge_attr_attention_bias_projection_dict: + raise ValueError( + "No edge-attribute attention-bias projection is " + f"initialized for relation index {relation_idx}." + ) + if attn_bias_by_position is None: + if attn_bias.shape[1:] != (self._num_heads, seq_len, seq_len): + attn_bias = attn_bias.expand( + batch_size, + self._num_heads, + seq_len, + seq_len, + ).clone() + attn_bias_by_position = attn_bias.permute(0, 2, 3, 1) + + edge_attr_projection = ( + self._edge_attr_attention_bias_projection_dict[projection_key] + ) + edge_attr_indices = edge_attr_indices.to( + device=device, + dtype=torch.long, + ) + edge_attr_bias = edge_attr_projection( + edge_attr_values.to(device=device, dtype=dtype) + ) + attn_bias_by_position.index_put_( + ( + edge_attr_indices[:, 0], + edge_attr_indices[:, 1], + edge_attr_indices[:, 2], + ), + edge_attr_bias.to(dtype=attn_bias.dtype), + accumulate=True, + ) + return attn_bias def _encode_and_readout( @@ -1036,14 +1608,17 @@ def _encode_and_readout( sequences: Tensor, valid_mask: Tensor, attn_bias: Optional[Tensor] = None, + pairwise_relation_indices: Optional[Tensor] = None, ) -> Tensor: - """Process sequences through transformer layers and attention readout. + """Process sequences through transformer layers and configured readout. Args: sequences: Input tensor of shape ``(batch_size, max_seq_len, hid_dim)``. valid_mask: Boolean mask of shape ``(batch_size, max_seq_len)``. attn_bias: Optional additive attention bias broadcastable to ``(batch_size, num_heads, seq, seq)``. + pairwise_relation_indices: Optional sparse relation coordinates shaped + ``(num_relation_edges, 4)``. Returns: Output embeddings of shape ``(batch_size, hid_dim)``. @@ -1051,12 +1626,21 @@ def _encode_and_readout( x = sequences * valid_mask.unsqueeze(-1).to(sequences.dtype) for encoder_layer in self._encoder_layers: - x = encoder_layer(x, attn_bias=attn_bias, valid_mask=valid_mask) + x = encoder_layer( + x, + attn_bias=attn_bias, + pairwise_relation_indices=pairwise_relation_indices, + valid_mask=valid_mask, + ) x = self._final_norm(x) x = x * valid_mask.unsqueeze(-1).to(x.dtype) - # Readout: anchor (position 0) + attention-weighted neighbor aggregation + if self._readout_mode == "anchor_only": + return x[:, 0, :] + + # RelGT-style readout: anchor (position 0) + attention-weighted + # neighbor aggregation. anchor = x[:, 0, :].unsqueeze(1) # (batch, 1, hid_dim) neighbors = x[:, 1:, :] # (batch, seq-1, hid_dim) neighbor_valid_mask = valid_mask[:, 1:] diff --git a/gigl/transforms/graph_transformer.py b/gigl/transforms/graph_transformer.py index 602f95bde..06fadf016 100644 --- a/gigl/transforms/graph_transformer.py +++ b/gigl/transforms/graph_transformer.py @@ -57,7 +57,7 @@ >>> # attention_bias_data['anchor_bias']: (batch_size, max_seq_len, 1) """ -from typing import Literal, Optional, TypedDict +from typing import Literal, NamedTuple, Optional, TypedDict import torch from torch import Tensor @@ -65,18 +65,34 @@ from torch_geometric.typing import NodeType from torch_geometric.utils import to_torch_sparse_tensor +from gigl.src.common.types.graph_data import EdgeType as GiGLEdgeType + TokenInputData = dict[str, Tensor] class SequenceAuxiliaryData(TypedDict): anchor_bias: Optional[Tensor] pairwise_bias: Optional[Tensor] + pairwise_nonmissing_indices: Optional[Tensor] + pairwise_relation_indices: Optional[Tensor] + pairwise_edge_attr_indices: Optional[dict[int, Tensor]] + pairwise_edge_attr_values: Optional[dict[int, Tensor]] token_input: Optional[TokenInputData] PPR_WEIGHT_FEATURE_NAME = "ppr_weight" +class _TokenOccurrenceIndex(NamedTuple): + batch_indices: Tensor + positions: Tensor + node_indices: Tensor + sorted_node_indices: Tensor + node_sort_perm: Tensor + sorted_batch_node_keys: Tensor + batch_node_sort_perm: Tensor + + def heterodata_to_graph_transformer_input( data: HeteroData, batch_size: int, @@ -90,6 +106,8 @@ def heterodata_to_graph_transformer_input( anchor_based_attention_bias_attr_names: Optional[list[str]] = None, anchor_based_input_attr_names: Optional[list[str]] = None, pairwise_attention_bias_attr_names: Optional[list[str]] = None, + relation_edge_types: Optional[list[GiGLEdgeType]] = None, + edge_attr_edge_type_to_feat_dim_map: Optional[dict[GiGLEdgeType, int]] = None, ) -> tuple[Tensor, Tensor, SequenceAuxiliaryData]: """ Transform a HeteroData object to Graph Transformer sequence input. @@ -131,6 +149,16 @@ def heterodata_to_graph_transformer_input( pairwise_attention_bias_attr_names: List of pairwise feature names used as attention bias. These must correspond to sparse graph-level attributes on ``data``. Example: ['pairwise_distance']. + relation_edge_types: Optional ordered edge types used to materialize sparse + relation coordinates. Each output relation index corresponds to one + edge type in this list. Directed edges are stored as + ``(batch_idx, query_pos=dst_token, key_pos=src_token, relation_idx)``. + edge_attr_edge_type_to_feat_dim_map: Optional ordered-by-sorted-key edge + feature dimensions used to materialize sparse edge-attribute + attention-bias payloads. Only edge types with positive feature dim + contribute. Directed edges are stored as + ``(batch_idx, query_pos=dst_token, key_pos=src_token)`` under the + same relation index as the sorted edge-type order. Returns: (sequences, valid_mask, attention_bias_data), where: @@ -143,6 +171,17 @@ def heterodata_to_graph_transformer_input( ``"anchor_bias"`` shaped ``(batch, seq, num_anchor_attrs)`` or None ``"pairwise_bias"`` shaped ``(batch, seq, seq, num_pairwise_attrs)`` or None + ``"pairwise_nonmissing_indices"`` shaped ``(num_pairs, 3)`` or None, + storing ``(batch_idx, row_pos, col_pos)`` coordinates for + nonmissing pairwise entries + ``"pairwise_relation_indices"`` shaped + ``(num_relation_edges, 4)`` or None, storing + ``(batch_idx, query_pos, key_pos, relation_idx)`` coordinates + ``"pairwise_edge_attr_indices"`` as a dict mapping relation index + to ``(num_edges, 3)`` sparse ``(batch_idx, query_pos, key_pos)`` + coordinates, or None + ``"pairwise_edge_attr_values"`` as a dict mapping relation index + to ``(num_edges, edge_attr_dim)`` edge-attribute values, or None ``"token_input"`` as a dict mapping attribute name to a ``(batch, seq, 1)`` tensor, or None @@ -204,8 +243,12 @@ def heterodata_to_graph_transformer_input( device = data[anchor_node_type].x.device - # Convert to homogeneous for easier neighborhood extraction - homo_data = data.to_homogeneous() + # Convert to homogeneous for easier neighborhood extraction. In khop mode + # edge attributes stay on the original hetero edge stores because different + # relations may have different feature dimensions. + homo_data = data.to_homogeneous( + edge_attrs=[] if sequence_construction_method == "khop" else None + ) homo_x = homo_data.x # (total_nodes, feature_dim) num_nodes = homo_data.num_nodes @@ -306,11 +349,50 @@ def heterodata_to_graph_transformer_input( device=device, ) - pairwise_feature_sequences = _lookup_pairwise_relative_features( + needs_token_occurrence_index = bool(relation_edge_types) or bool( + edge_attr_edge_type_to_feat_dim_map + ) + token_occurrences = ( + _build_token_occurrence_index( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + num_nodes=num_nodes, + device=device, + ) + if needs_token_occurrence_index + else None + ) + + pairwise_feature_sequences, pairwise_nonmissing_indices = ( + _lookup_pairwise_relative_features( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None, + attr_names=pairwise_bias_attr_names, + device=device, + ) + ) + pairwise_relation_indices = _lookup_pairwise_relation_indices( + data=data, node_index_sequences=node_index_sequences, valid_mask=valid_mask, - csr_matrices=pairwise_pe_matrices if pairwise_pe_matrices else None, + relation_edge_types=relation_edge_types, + node_type_offsets=node_type_offsets, + num_nodes=num_nodes, device=device, + token_occurrences=token_occurrences, + ) + pairwise_edge_attr_indices, pairwise_edge_attr_values = ( + _lookup_pairwise_edge_attr_payloads( + data=data, + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + edge_attr_edge_type_to_feat_dim_map=edge_attr_edge_type_to_feat_dim_map, + node_type_offsets=node_type_offsets, + num_nodes=num_nodes, + device=device, + token_occurrences=token_occurrences, + ) ) anchor_bias_features = _compose_anchor_feature_tensor( @@ -332,6 +414,10 @@ def heterodata_to_graph_transformer_input( { "anchor_bias": anchor_bias_features, "pairwise_bias": pairwise_feature_sequences, + "pairwise_nonmissing_indices": pairwise_nonmissing_indices, + "pairwise_relation_indices": pairwise_relation_indices, + "pairwise_edge_attr_indices": pairwise_edge_attr_indices, + "pairwise_edge_attr_values": pairwise_edge_attr_values, "token_input": token_input_features, }, ) @@ -798,8 +884,9 @@ def _lookup_pairwise_relative_features( node_index_sequences: Tensor, valid_mask: Tensor, csr_matrices: Optional[list[Tensor]], + attr_names: Optional[list[str]], device: torch.device, -) -> Optional[Tensor]: +) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Look up pairwise sparse values for each valid token pair in the sequence. @@ -815,13 +902,20 @@ def _lookup_pairwise_relative_features( node_index_sequences: (batch_size, max_seq_len) node indices for each sequence position valid_mask: (batch_size, max_seq_len) bool tensor indicating valid positions csr_matrices: List of sparse CSR matrices, each (num_nodes, num_nodes) + attr_names: Optional names for the pairwise attributes. Used only to + produce clearer error messages when multiple attrs disagree on + sparse support. device: Device for output tensor Returns: features: (batch_size, max_seq_len, max_seq_len, num_attrs) tensor where features[b, i, j, k] = csr_matrices[k][node_index_sequences[b, i], node_index_sequences[b, j]] for valid (i, j) pairs, 0.0 for padding positions. - Returns None if csr_matrices is empty. + nonmissing_indices: (num_nonmissing_pairs, 3) long tensor containing + ``(batch_idx, row_pos, col_pos)`` coordinates for valid diagonal + self pairs and valid sparse entries. Missing non-self pairs and + padding are omitted. + Returns (None, None) if csr_matrices is empty. Example: # batch_size=2, max_seq_len=3, num_attrs=1 (e.g., random_walk_se) @@ -844,7 +938,7 @@ def _lookup_pairwise_relative_features( # (pad) [0.0, 0.0, 0.0] """ if not csr_matrices: - return None + return None, None batch_size, max_seq_len = node_index_sequences.shape num_attrs = len(csr_matrices) @@ -853,26 +947,440 @@ def _lookup_pairwise_relative_features( dtype=torch.float, device=device, ) + ( + valid_batch_indices, + valid_row_positions, + valid_col_positions, + valid_row_indices, + valid_col_indices, + ) = _build_flat_valid_pair_layout( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + device=device, + ) + if valid_batch_indices.numel() == 0: + return features, torch.zeros((0, 3), dtype=torch.long, device=device) - pair_valid_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1) - if not pair_valid_mask.any(): - return features - - row_indices = node_index_sequences.unsqueeze(2).expand(-1, -1, max_seq_len) - col_indices = node_index_sequences.unsqueeze(1).expand(-1, max_seq_len, -1) - - valid_row_indices = row_indices[pair_valid_mask] - valid_col_indices = col_indices[pair_valid_mask] + self_pair_mask = valid_row_positions == valid_col_positions + first_attr_name = attr_names[0] if attr_names else "attr_0" + nonmissing_support: Optional[Tensor] = None for attr_idx, pe_matrix in enumerate(csr_matrices): - pe_values = _lookup_csr_values( + pe_values, found_mask = _lookup_csr_values_and_found( csr_matrix=pe_matrix, row_indices=valid_row_indices, col_indices=valid_col_indices, ) - features[..., attr_idx][pair_valid_mask] = pe_values + features[ + valid_batch_indices, + valid_row_positions, + valid_col_positions, + attr_idx, + ] = pe_values + attr_nonmissing_support = found_mask | self_pair_mask + if attr_idx == 0: + nonmissing_support = attr_nonmissing_support + continue + if nonmissing_support is None or not torch.equal( + nonmissing_support, + attr_nonmissing_support, + ): + attr_name = attr_names[attr_idx] if attr_names else f"attr_{attr_idx}" + raise ValueError( + "Pairwise attention bias attributes must share identical " + "nonmissing support after treating valid diagonal self pairs " + f"as nonmissing, but '{first_attr_name}' and '{attr_name}' " + "differ." + ) - return features + assert nonmissing_support is not None + pairwise_nonmissing_indices = torch.stack( + [ + valid_batch_indices[nonmissing_support], + valid_row_positions[nonmissing_support], + valid_col_positions[nonmissing_support], + ], + dim=1, + ) + return features, pairwise_nonmissing_indices + + +def _lookup_pairwise_relation_indices( + data: HeteroData, + node_index_sequences: Tensor, + valid_mask: Tensor, + relation_edge_types: Optional[list[GiGLEdgeType]], + node_type_offsets: dict[NodeType, int], + num_nodes: int, + device: torch.device, + token_occurrences: Optional[_TokenOccurrenceIndex] = None, +) -> Optional[Tensor]: + """Build sparse relation coordinates for valid token pairs. + + For a directed edge ``source -> target``, attention uses + ``query=target`` and ``key=source`` so relation-aware attention follows + message-passing orientation. + """ + if not relation_edge_types: + return None + + if token_occurrences is None: + token_occurrences = _build_token_occurrence_index( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + num_nodes=num_nodes, + device=device, + ) + if token_occurrences.batch_indices.numel() == 0: + return torch.zeros((0, 4), dtype=torch.long, device=device) + + relation_index_parts: list[Tensor] = [] + for relation_idx, edge_type in enumerate(relation_edge_types): + edge_type_tuple = edge_type.tuple_repr() + if edge_type_tuple not in data.edge_types: + continue + + edge_index = data[edge_type_tuple].edge_index.to( + device=device, dtype=torch.long + ) + if edge_index.numel() == 0: + continue + + src_offset = int(node_type_offsets[edge_type.src_node_type]) + dst_offset = int(node_type_offsets[edge_type.dst_node_type]) + source_indices = edge_index[0] + src_offset + target_indices = edge_index[1] + dst_offset + ( + relation_batch_indices, + relation_query_positions, + relation_key_positions, + _, + ) = _match_directed_edges_to_token_pairs( + source_indices=source_indices, + target_indices=target_indices, + token_occurrences=token_occurrences, + num_nodes=num_nodes, + device=device, + ) + if relation_batch_indices.numel() == 0: + continue + + relation_indices = torch.stack( + [ + relation_batch_indices, + relation_query_positions, + relation_key_positions, + torch.full( + (relation_batch_indices.size(0),), + relation_idx, + dtype=torch.long, + device=device, + ), + ], + dim=1, + ) + relation_index_parts.append(torch.unique(relation_indices, dim=0)) + + if not relation_index_parts: + return torch.zeros((0, 4), dtype=torch.long, device=device) + return torch.cat(relation_index_parts, dim=0) + + +def _lookup_pairwise_edge_attr_payloads( + data: HeteroData, + node_index_sequences: Tensor, + valid_mask: Tensor, + edge_attr_edge_type_to_feat_dim_map: Optional[dict[GiGLEdgeType, int]], + node_type_offsets: dict[NodeType, int], + num_nodes: int, + device: torch.device, + token_occurrences: Optional[_TokenOccurrenceIndex] = None, +) -> tuple[Optional[dict[int, Tensor]], Optional[dict[int, Tensor]]]: + """Build sparse edge-attribute payloads for valid token pairs. + + For a directed edge ``source -> target``, attention uses + ``query=target`` and ``key=source`` so edge-attribute bias follows the same + message-passing orientation as GAT. + """ + if not edge_attr_edge_type_to_feat_dim_map: + return None, None + + edge_attr_indices_by_relation: dict[int, Tensor] = {} + edge_attr_values_by_relation: dict[int, Tensor] = {} + if token_occurrences is None: + token_occurrences = _build_token_occurrence_index( + node_index_sequences=node_index_sequences, + valid_mask=valid_mask, + num_nodes=num_nodes, + device=device, + ) + if token_occurrences.batch_indices.numel() == 0: + return edge_attr_indices_by_relation, edge_attr_values_by_relation + + for relation_idx, edge_type in enumerate( + sorted(edge_attr_edge_type_to_feat_dim_map.keys()) + ): + edge_attr_dim = int(edge_attr_edge_type_to_feat_dim_map[edge_type]) + if edge_attr_dim <= 0: + continue + + edge_type_tuple = edge_type.tuple_repr() + if edge_type_tuple not in data.edge_types: + continue + + edge_store = data[edge_type_tuple] + edge_index = edge_store.edge_index.to(device=device, dtype=torch.long) + if edge_index.numel() == 0: + continue + + if not hasattr(edge_store, "edge_attr") or edge_store.edge_attr is None: + raise ValueError( + "edge_attr_attention_bias_mode='sparse_linear' requires " + f"edge_attr for edge type {edge_type_tuple} because its " + f"configured feature dim is {edge_attr_dim}." + ) + edge_attr = edge_store.edge_attr.to(device=device) + if edge_attr.dim() == 1: + edge_attr = edge_attr.view(-1, 1) + if edge_attr.dim() != 2: + raise ValueError( + f"edge_attr for edge type {edge_type_tuple} must be 1D or 2D, " + f"got shape {tuple(edge_attr.shape)}." + ) + if edge_attr.size(0) != edge_index.size(1): + raise ValueError( + f"edge_attr for edge type {edge_type_tuple} has " + f"{edge_attr.size(0)} rows but edge_index has " + f"{edge_index.size(1)} edges." + ) + if edge_attr.size(1) != edge_attr_dim: + raise ValueError( + f"edge_attr for edge type {edge_type_tuple} has dim " + f"{edge_attr.size(1)} but configured dim is {edge_attr_dim}." + ) + + src_offset = int(node_type_offsets[edge_type.src_node_type]) + dst_offset = int(node_type_offsets[edge_type.dst_node_type]) + source_indices = edge_index[0] + src_offset + target_indices = edge_index[1] + dst_offset + ( + edge_batch_indices, + edge_query_positions, + edge_key_positions, + matched_edge_indices, + ) = _match_directed_edges_to_token_pairs( + source_indices=source_indices, + target_indices=target_indices, + token_occurrences=token_occurrences, + num_nodes=num_nodes, + device=device, + ) + if edge_batch_indices.numel() == 0: + continue + + edge_attr_indices_by_relation[relation_idx] = torch.stack( + [ + edge_batch_indices, + edge_query_positions, + edge_key_positions, + ], + dim=1, + ) + edge_attr_values_by_relation[relation_idx] = edge_attr[matched_edge_indices] + + return edge_attr_indices_by_relation, edge_attr_values_by_relation + + +def _build_token_occurrence_index( + node_index_sequences: Tensor, + valid_mask: Tensor, + num_nodes: int, + device: torch.device, +) -> _TokenOccurrenceIndex: + """Index valid sequence tokens for sparse directed-edge to token matching.""" + token_batch_indices, token_positions = torch.nonzero( + valid_mask, + as_tuple=True, + ) + token_batch_indices = token_batch_indices.to(device=device, dtype=torch.long) + token_positions = token_positions.to(device=device, dtype=torch.long) + token_node_indices = node_index_sequences[token_batch_indices, token_positions].to( + device=device, + dtype=torch.long, + ) + + sorted_token_node_indices, node_sort_perm = torch.sort(token_node_indices) + token_batch_node_keys = token_batch_indices * num_nodes + token_node_indices + sorted_token_batch_node_keys, batch_node_sort_perm = torch.sort( + token_batch_node_keys + ) + + return _TokenOccurrenceIndex( + batch_indices=token_batch_indices, + positions=token_positions, + node_indices=token_node_indices, + sorted_node_indices=sorted_token_node_indices, + node_sort_perm=node_sort_perm, + sorted_batch_node_keys=sorted_token_batch_node_keys, + batch_node_sort_perm=batch_node_sort_perm, + ) + + +def _match_directed_edges_to_token_pairs( + source_indices: Tensor, + target_indices: Tensor, + token_occurrences: _TokenOccurrenceIndex, + num_nodes: int, + device: torch.device, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Map ``source -> target`` graph edges onto valid sequence coordinates. + + The returned coordinates follow attention orientation: + ``query_pos=target_token`` and ``key_pos=source_token``. The final tensor + contains source edge-row ids, repeated when an edge is present in multiple + anchor sequences. + """ + empty = torch.zeros((0,), dtype=torch.long, device=device) + if source_indices.numel() == 0 or token_occurrences.batch_indices.numel() == 0: + return empty, empty, empty, empty + + source_indices = source_indices.to(device=device, dtype=torch.long) + target_indices = target_indices.to(device=device, dtype=torch.long) + + target_lower_bounds = torch.searchsorted( + token_occurrences.sorted_node_indices, + target_indices, + right=False, + ) + target_upper_bounds = torch.searchsorted( + token_occurrences.sorted_node_indices, + target_indices, + right=True, + ) + target_match_counts = target_upper_bounds - target_lower_bounds + matched_edge_mask = target_match_counts > 0 + if not matched_edge_mask.any(): + return empty, empty, empty, empty + + matched_edge_indices = torch.nonzero(matched_edge_mask, as_tuple=True)[0] + matched_target_counts = target_match_counts[matched_edge_indices] + total_target_matches = int(matched_target_counts.sum().item()) + repeated_target_edge_indices = torch.repeat_interleave( + matched_edge_indices, + matched_target_counts, + ) + repeated_target_lower_bounds = torch.repeat_interleave( + target_lower_bounds[matched_edge_indices], + matched_target_counts, + ) + target_group_start_offsets = torch.repeat_interleave( + torch.cumsum(matched_target_counts, dim=0) - matched_target_counts, + matched_target_counts, + ) + target_sorted_positions = ( + repeated_target_lower_bounds + + torch.arange(total_target_matches, device=device, dtype=torch.long) + - target_group_start_offsets + ) + target_token_indices = token_occurrences.node_sort_perm[target_sorted_positions] + target_batch_indices = token_occurrences.batch_indices[target_token_indices] + target_query_positions = token_occurrences.positions[target_token_indices] + + source_query_keys = ( + target_batch_indices * num_nodes + source_indices[repeated_target_edge_indices] + ) + source_lower_bounds = torch.searchsorted( + token_occurrences.sorted_batch_node_keys, + source_query_keys, + right=False, + ) + source_upper_bounds = torch.searchsorted( + token_occurrences.sorted_batch_node_keys, + source_query_keys, + right=True, + ) + source_match_counts = source_upper_bounds - source_lower_bounds + matched_target_mask = source_match_counts > 0 + if not matched_target_mask.any(): + return empty, empty, empty, empty + + matched_target_indices = torch.nonzero(matched_target_mask, as_tuple=True)[0] + matched_source_counts = source_match_counts[matched_target_indices] + total_source_matches = int(matched_source_counts.sum().item()) + repeated_target_indices = torch.repeat_interleave( + matched_target_indices, + matched_source_counts, + ) + repeated_source_lower_bounds = torch.repeat_interleave( + source_lower_bounds[matched_target_indices], + matched_source_counts, + ) + source_group_start_offsets = torch.repeat_interleave( + torch.cumsum(matched_source_counts, dim=0) - matched_source_counts, + matched_source_counts, + ) + source_sorted_positions = ( + repeated_source_lower_bounds + + torch.arange(total_source_matches, device=device, dtype=torch.long) + - source_group_start_offsets + ) + source_token_indices = token_occurrences.batch_node_sort_perm[ + source_sorted_positions + ] + + return ( + target_batch_indices[repeated_target_indices], + target_query_positions[repeated_target_indices], + token_occurrences.positions[source_token_indices], + repeated_target_edge_indices[repeated_target_indices], + ) + + +def _build_flat_valid_pair_layout( + node_index_sequences: Tensor, + valid_mask: Tensor, + device: torch.device, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Enumerate valid sequence pairs without building dense pairwise masks.""" + batch_indices_parts: list[Tensor] = [] + row_positions_parts: list[Tensor] = [] + col_positions_parts: list[Tensor] = [] + row_node_indices_parts: list[Tensor] = [] + col_node_indices_parts: list[Tensor] = [] + + for batch_idx in range(valid_mask.size(0)): + valid_positions = torch.nonzero(valid_mask[batch_idx], as_tuple=True)[0] + num_valid = valid_positions.numel() + if num_valid == 0: + continue + + valid_node_indices = node_index_sequences[batch_idx, valid_positions] + pair_count = num_valid * num_valid + + batch_indices_parts.append( + torch.full( + (pair_count,), + batch_idx, + dtype=torch.long, + device=device, + ) + ) + row_positions_parts.append(valid_positions.repeat_interleave(num_valid)) + col_positions_parts.append(valid_positions.repeat(num_valid)) + row_node_indices_parts.append(valid_node_indices.repeat_interleave(num_valid)) + col_node_indices_parts.append(valid_node_indices.repeat(num_valid)) + + if not batch_indices_parts: + empty = torch.zeros((0,), dtype=torch.long, device=device) + return empty, empty, empty, empty, empty + + return ( + torch.cat(batch_indices_parts, dim=0), + torch.cat(row_positions_parts, dim=0), + torch.cat(col_positions_parts, dim=0), + torch.cat(row_node_indices_parts, dim=0), + torch.cat(col_node_indices_parts, dim=0), + ) def _get_k_hop_neighbors_sparse( @@ -964,47 +1472,69 @@ def _lookup_csr_values( Returns: (n,) values from csr_matrix[row, col], or default_value if not present """ + values, _ = _lookup_csr_values_and_found( + csr_matrix=csr_matrix, + row_indices=row_indices, + col_indices=col_indices, + default_value=default_value, + ) + return values + + +def _lookup_csr_values_and_found( + csr_matrix: Tensor, + row_indices: Tensor, + col_indices: Tensor, + default_value: float = 0.0, +) -> tuple[Tensor, Tensor]: + """ + Look up values in a CSR sparse matrix and report which entries were present. + + Returns both the looked-up values and a boolean found-mask so callers can + distinguish missing sparse entries from explicit zero-valued entries. + """ n = row_indices.size(0) device = row_indices.device if n == 0: - return torch.zeros(0, device=device, dtype=torch.float) + return ( + torch.zeros(0, device=device, dtype=torch.float), + torch.zeros(0, device=device, dtype=torch.bool), + ) crow_indices = csr_matrix.crow_indices() col_indices_csr = csr_matrix.col_indices() values_csr = csr_matrix.values() - # Get row start/end pointers - row_starts = crow_indices[row_indices] - row_ends = crow_indices[row_indices + 1] - row_lengths = row_ends - row_starts - max_row_len = row_lengths.max().item() - - if max_row_len == 0: - return torch.full((n,), default_value, device=device, dtype=torch.float) - - # Build offset matrix: (n, max_row_len) - offsets = row_starts.unsqueeze(1) + torch.arange(max_row_len, device=device) - valid_mask = offsets < row_ends.unsqueeze(1) - - # Safe indexing with clamping - nnz = col_indices_csr.size(0) - offsets_clamped = offsets.clamp(max=max(nnz - 1, 0)) - - # Get columns at offsets and find matches - cols_at_offsets = col_indices_csr[offsets_clamped] - col_matches = (cols_at_offsets == col_indices.unsqueeze(1)) & valid_mask + if col_indices_csr.numel() == 0: + return ( + torch.full((n,), default_value, device=device, dtype=torch.float), + torch.zeros((n,), device=device, dtype=torch.bool), + ) - # Find which queries have matches - found = col_matches.any(dim=1) + num_rows, num_cols = csr_matrix.size() + row_counts = crow_indices[1:] - crow_indices[:-1] + csr_row_indices = torch.repeat_interleave( + torch.arange(num_rows, device=device), + row_counts, + ) + # CSR stores entries grouped by row, and sparse graph features are emitted + # with sorted column indices per row, so linearized row-major keys remain + # globally sorted for searchsorted. + csr_keys = csr_row_indices * num_cols + col_indices_csr + query_keys = row_indices * num_cols + col_indices + match_positions = torch.searchsorted(csr_keys, query_keys) + + candidate_mask = match_positions < csr_keys.numel() + found = torch.zeros((n,), device=device, dtype=torch.bool) + if candidate_mask.any(): + valid_match_positions = match_positions[candidate_mask] + found[candidate_mask] = ( + csr_keys[valid_match_positions] == query_keys[candidate_mask] + ) - # Initialize output result = torch.full((n,), default_value, device=device, dtype=torch.float) - if found.any(): - # Get match positions and retrieve values - match_offsets = col_matches.float().argmax(dim=1) - value_indices = row_starts[found] + match_offsets[found] - result[found] = values_csr[value_indices].float() + result[found] = values_csr[match_positions[found]].float() - return result + return result, found diff --git a/tests/unit/nn/graph_transformer_test.py b/tests/unit/nn/graph_transformer_test.py index d0fce10c3..05670939c 100644 --- a/tests/unit/nn/graph_transformer_test.py +++ b/tests/unit/nn/graph_transformer_test.py @@ -1,5 +1,7 @@ """Tests for GraphTransformerEncoder.""" +import sys +import types from typing import cast import torch @@ -8,13 +10,128 @@ from torch import Tensor from torch_geometric.data import HeteroData + +def _install_torchrec_stub() -> None: + if "torchrec" in sys.modules: + return + + torchrec_module = types.ModuleType("torchrec") + distributed_module = types.ModuleType("torchrec.distributed") + distributed_types_module = types.ModuleType("torchrec.distributed.types") + modules_module = types.ModuleType("torchrec.modules") + embedding_configs_module = types.ModuleType("torchrec.modules.embedding_configs") + embedding_modules_module = types.ModuleType("torchrec.modules.embedding_modules") + sparse_module = types.ModuleType("torchrec.sparse") + jagged_tensor_module = types.ModuleType("torchrec.sparse.jagged_tensor") + + class Awaitable: # pragma: no cover - import compatibility shim + pass + + class EmbeddingBagConfig: # pragma: no cover - import compatibility shim + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + + class EmbeddingBagCollection: # pragma: no cover - import compatibility shim + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + + class KeyedJaggedTensor: # pragma: no cover - import compatibility shim + pass + + distributed_types_module.Awaitable = Awaitable + embedding_configs_module.EmbeddingBagConfig = EmbeddingBagConfig + embedding_modules_module.EmbeddingBagCollection = EmbeddingBagCollection + jagged_tensor_module.KeyedJaggedTensor = KeyedJaggedTensor + + torchrec_module.distributed = distributed_module + torchrec_module.modules = modules_module + torchrec_module.sparse = sparse_module + distributed_module.types = distributed_types_module + modules_module.embedding_configs = embedding_configs_module + modules_module.embedding_modules = embedding_modules_module + sparse_module.jagged_tensor = jagged_tensor_module + + sys.modules["torchrec"] = torchrec_module + sys.modules["torchrec.distributed"] = distributed_module + sys.modules["torchrec.distributed.types"] = distributed_types_module + sys.modules["torchrec.modules"] = modules_module + sys.modules["torchrec.modules.embedding_configs"] = embedding_configs_module + sys.modules["torchrec.modules.embedding_modules"] = embedding_modules_module + sys.modules["torchrec.sparse"] = sparse_module + sys.modules["torchrec.sparse.jagged_tensor"] = jagged_tensor_module + + +def _install_graphlearn_torch_stub() -> None: + if "graphlearn_torch.partition" in sys.modules: + return + + graphlearn_torch_module = types.ModuleType("graphlearn_torch") + partition_module = types.ModuleType("graphlearn_torch.partition") + + class PartitionBook: # pragma: no cover - import compatibility shim + pass + + partition_module.PartitionBook = PartitionBook + graphlearn_torch_module.partition = partition_module + + sys.modules["graphlearn_torch"] = graphlearn_torch_module + sys.modules["graphlearn_torch.partition"] = partition_module + + +def _install_tensorflow_metadata_stub() -> None: + if "tensorflow_metadata.proto.v0.schema_pb2" in sys.modules: + return + + tensorflow_metadata_module = types.ModuleType("tensorflow_metadata") + proto_module = types.ModuleType("tensorflow_metadata.proto") + v0_module = types.ModuleType("tensorflow_metadata.proto.v0") + schema_pb2_module = types.ModuleType("tensorflow_metadata.proto.v0.schema_pb2") + + class Feature: # pragma: no cover - import compatibility shim + pass + + class Schema: # pragma: no cover - import compatibility shim + pass + + schema_pb2_module.Feature = Feature + schema_pb2_module.Schema = Schema + tensorflow_metadata_module.proto = proto_module + proto_module.v0 = v0_module + v0_module.schema_pb2 = schema_pb2_module + + sys.modules["tensorflow_metadata"] = tensorflow_metadata_module + sys.modules["tensorflow_metadata.proto"] = proto_module + sys.modules["tensorflow_metadata.proto.v0"] = v0_module + sys.modules["tensorflow_metadata.proto.v0.schema_pb2"] = schema_pb2_module + + +def _install_tensorflow_transform_stub() -> None: + if "tensorflow_transform" in sys.modules: + return + + tensorflow_transform_module = types.ModuleType("tensorflow_transform") + common_types = types.SimpleNamespace(FeatureSpecType=object, TensorType=object) + + tensorflow_transform_module.common_types = common_types + sys.modules["tensorflow_transform"] = tensorflow_transform_module + + +_install_tensorflow_metadata_stub() +_install_tensorflow_transform_stub() +_install_torchrec_stub() +_install_graphlearn_torch_stub() + from gigl.nn.graph_transformer import ( FeedForwardNetwork, GraphTransformerEncoder, GraphTransformerEncoderLayer, ) from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation -from tests.test_assets.test_case import TestCase + +try: + from tests.test_assets.test_case import TestCase +except ModuleNotFoundError: # pragma: no cover - optional test harness deps + TestCase = absltest.TestCase def _create_simple_hetero_data() -> HeteroData: @@ -274,6 +391,18 @@ def _create_user_graph_with_ppr_edges() -> HeteroData: return data +def _pairwise_nonmissing_indices( + coords: list[tuple[int, int, int]], +) -> torch.Tensor: + return torch.tensor(coords, dtype=torch.long) + + +def _pairwise_relation_indices( + coords: list[tuple[int, int, int, int]], +) -> torch.Tensor: + return torch.tensor(coords, dtype=torch.long) + + class TestGraphTransformerEncoderPEModes(TestCase): def setUp(self) -> None: self._node_type = NodeType("user") @@ -388,6 +517,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: assert encoder._anchor_pe_attention_bias_projection is not None assert encoder._pairwise_pe_attention_bias_projection is not None + assert encoder._pairwise_nonmissing_attention_bias is not None with torch.no_grad(): encoder._anchor_pe_attention_bias_projection.weight.copy_( @@ -411,6 +541,7 @@ def test_attention_bias_features_are_projected_per_head(self) -> None: ] ] ), + "pairwise_nonmissing_indices": None, "token_input": None, }, ) @@ -447,6 +578,7 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( [[[1.0, 0.5], [2.0, 0.25], [3.0, 0.125]]] ), "pairwise_bias": None, + "pairwise_nonmissing_indices": None, "token_input": None, }, ) @@ -457,6 +589,439 @@ def test_attention_bias_supports_anchor_relative_attrs_and_ppr_weights( self.assertEqual(attn_bias[0, 0, 0, 2].item(), 4.25) self.assertEqual(attn_bias[0, 1, 0, 2].item(), 8.5) + def test_pairwise_nonmissing_indices_add_head_specific_bias(self) -> None: + encoder = self._create_encoder( + pairwise_attention_bias_attr_names=["pairwise_distance"], + ) + + assert encoder._pairwise_pe_attention_bias_projection is not None + assert encoder._pairwise_nonmissing_attention_bias is not None + + with torch.no_grad(): + encoder._pairwise_pe_attention_bias_projection.weight.zero_() + encoder._pairwise_nonmissing_attention_bias.copy_(torch.tensor([0.5, 1.5])) + + attn_bias = encoder._build_attention_bias( + valid_mask=torch.ones((1, 3), dtype=torch.bool), + sequences=torch.zeros((1, 3, 8), dtype=torch.float), + attention_bias_data={ + "anchor_bias": None, + "pairwise_bias": torch.zeros((1, 3, 3, 1), dtype=torch.float), + "pairwise_nonmissing_indices": _pairwise_nonmissing_indices( + [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (0, 2, 2)] + ), + "token_input": None, + }, + ) + + self.assertEqual(attn_bias.shape, (1, 2, 3, 3)) + self.assertEqual(attn_bias[0, 0, 0, 0].item(), 0.5) + self.assertEqual(attn_bias[0, 1, 0, 0].item(), 1.5) + self.assertEqual(attn_bias[0, 0, 0, 2].item(), 0.0) + self.assertEqual(attn_bias[0, 1, 1, 2].item(), 0.0) + + def test_relation_attention_zero_init_matches_plain_layer(self) -> None: + torch.manual_seed(0) + base_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + ) + relation_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_bilinear", + num_relations=2, + ) + relation_layer.load_state_dict(base_layer.state_dict(), strict=False) + base_layer.eval() + relation_layer.eval() + + x = torch.randn(2, 4, 8) + valid_mask = torch.ones((2, 4), dtype=torch.bool) + + with torch.no_grad(): + assert relation_layer._relation_attention_matrices is not None + self.assertTrue( + torch.equal( + relation_layer._relation_attention_matrices, + torch.zeros_like(relation_layer._relation_attention_matrices), + ) + ) + base_output = base_layer(x, valid_mask=valid_mask) + relation_output = relation_layer( + x, + pairwise_relation_indices=_pairwise_relation_indices( + [(0, 1, 0, 0), (1, 2, 1, 1)] + ), + valid_mask=valid_mask, + ) + + self.assertTrue(torch.allclose(base_output, relation_output, atol=1e-6)) + + def test_relation_value_mode_none_ignores_relation_indices(self) -> None: + torch.manual_seed(0) + layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + ) + layer.eval() + x = torch.randn(1, 3, 8) + + with torch.no_grad(): + base_output = layer(x, valid_mask=torch.ones((1, 3), dtype=torch.bool)) + relation_index_output = layer( + x, + pairwise_relation_indices=_pairwise_relation_indices( + [(0, 1, 0, 0), (0, 2, 1, 0)] + ), + valid_mask=torch.ones((1, 3), dtype=torch.bool), + ) + + self.assertTrue(torch.allclose(base_output, relation_index_output, atol=1e-6)) + + def test_relation_value_zero_init_matches_plain_layer(self) -> None: + torch.manual_seed(0) + base_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + ) + relation_value_layer = GraphTransformerEncoderLayer( + model_dim=8, + num_heads=2, + feedforward_dim=16, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_value_mode="sparse_residual_gate", + num_relations=2, + ) + relation_value_layer.load_state_dict(base_layer.state_dict(), strict=False) + base_layer.eval() + relation_value_layer.eval() + + x = torch.randn(2, 4, 8) + valid_mask = torch.ones((2, 4), dtype=torch.bool) + + with torch.no_grad(): + assert relation_value_layer._relation_value_gates is not None + self.assertTrue( + torch.equal( + relation_value_layer._relation_value_gates, + torch.zeros_like(relation_value_layer._relation_value_gates), + ) + ) + base_output = base_layer(x, valid_mask=valid_mask) + relation_value_output = relation_value_layer( + x, + pairwise_relation_indices=_pairwise_relation_indices( + [(0, 1, 0, 0), (1, 2, 1, 1)] + ), + valid_mask=valid_mask, + ) + + self.assertTrue(torch.allclose(base_output, relation_value_output, atol=1e-6)) + + def test_relation_value_residual_affects_indexed_queries_and_normalizes( + self, + ) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=4, + num_heads=2, + feedforward_dim=8, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_value_mode="sparse_residual_gate", + num_relations=2, + ) + attention_output = torch.zeros((1, 2, 3, 2)) + value = torch.tensor( + [ + [ + [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]], + [[4.0, 40.0], [5.0, 50.0], [6.0, 60.0]], + ] + ] + ) + + with torch.no_grad(): + assert layer._relation_value_gates is not None + layer._relation_value_gates[0] = torch.tensor( + [[1.0, 0.5], [0.25, 0.0]] + ) + layer._relation_value_gates[1] = torch.tensor( + [[2.0, 0.0], [0.0, 3.0]] + ) + actual = layer._add_relation_value_residual( + attention_output=attention_output, + value=value, + pairwise_relation_indices=_pairwise_relation_indices( + [ + (0, 1, 0, 0), + (0, 1, 0, 0), + (0, 1, 2, 1), + (0, 2, 1, 0), + ] + ), + ) + + expected = torch.zeros_like(attention_output) + expected[0, 0, 1] = torch.tensor([8.0 / 3.0, 10.0 / 3.0]) + expected[0, 1, 1] = torch.tensor([2.0 / 3.0, 60.0]) + expected[0, 0, 2] = torch.tensor([2.0, 10.0]) + expected[0, 1, 2] = torch.tensor([1.25, 0.0]) + self.assertTrue(torch.allclose(actual, expected, atol=1e-6)) + + def test_relation_value_residual_rejects_invalid_relation_ids(self) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=4, + num_heads=2, + feedforward_dim=8, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_value_mode="sparse_residual_gate", + num_relations=1, + ) + + with self.assertRaisesRegex(ValueError, "relation ids outside"): + layer._add_relation_value_residual( + attention_output=torch.zeros((1, 2, 2, 2)), + value=torch.zeros((1, 2, 2, 2)), + pairwise_relation_indices=_pairwise_relation_indices([(0, 1, 0, 1)]), + ) + + def test_relation_attention_nonzero_bias_only_indexed_pairs(self) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=2, + num_heads=1, + feedforward_dim=4, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_bilinear", + num_relations=1, + ) + query = torch.tensor([[[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]]) + key = torch.tensor([[[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]]) + + with torch.no_grad(): + assert layer._relation_attention_matrices is not None + layer._relation_attention_matrices[0, 0] = torch.eye(2) + relation_bias = layer._build_relation_attention_bias( + query=query, + key=key, + pairwise_relation_indices=_pairwise_relation_indices([(0, 2, 1, 0)]), + ) + + assert relation_bias is not None + expected = torch.zeros((1, 1, 3, 3)) + expected[0, 0, 2, 1] = 1.0 / torch.sqrt(torch.tensor(2.0)).item() + self.assertTrue(torch.allclose(relation_bias, expected, atol=1e-6)) + + def test_relation_attention_respects_existing_negative_bias(self) -> None: + layer = GraphTransformerEncoderLayer( + model_dim=2, + num_heads=1, + feedforward_dim=4, + dropout_rate=0.0, + attention_dropout_rate=0.0, + relation_attention_mode="edge_type_bilinear", + num_relations=1, + ) + query = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) + key = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]) + negative_inf = torch.finfo(torch.float).min + attn_bias = torch.zeros((1, 1, 2, 2), dtype=torch.float) + attn_bias[0, 0, 0, 0] = negative_inf + + with torch.no_grad(): + assert layer._relation_attention_matrices is not None + layer._relation_attention_matrices[0, 0] = torch.eye(2) + relation_bias = layer._build_relation_attention_bias( + query=query, + key=key, + pairwise_relation_indices=_pairwise_relation_indices([(0, 0, 0, 0)]), + ) + + assert relation_bias is not None + self.assertGreater(relation_bias[0, 0, 0, 0].item(), 0.0) + self.assertEqual((attn_bias + relation_bias)[0, 0, 0, 0].item(), negative_inf) + + def test_relation_attention_supports_ppr_sequence_construction(self) -> None: + data = _create_user_graph_with_ppr_edges() + ppr_edge_type = EdgeType(self._node_type, Relation("ppr"), self._node_type) + + base_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 0}, + sequence_construction_method="ppr", + ) + relation_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 0}, + sequence_construction_method="ppr", + relation_attention_mode="edge_type_bilinear", + ) + relation_encoder.load_state_dict(base_encoder.state_dict(), strict=False) + base_encoder.eval() + relation_encoder.eval() + + with torch.no_grad(): + base_output = base_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + relation_output = relation_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertTrue(torch.allclose(base_output, relation_output, atol=1e-6)) + + def test_edge_attr_attention_bias_zero_init_matches_baseline(self) -> None: + data = HeteroData() + data["user"].x = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + ] + ) + data["user"].batch_size = 1 + data[self._edge_type.tuple_repr()].edge_index = torch.tensor([[0], [1]]) + data[self._edge_type.tuple_repr()].edge_attr = torch.tensor([[3.0]]) + + edge_type_to_feat_dim_map = {self._edge_type: 1} + base_encoder = self._create_encoder( + edge_type_to_feat_dim_map=edge_type_to_feat_dim_map, + ) + edge_attr_encoder = self._create_encoder( + edge_type_to_feat_dim_map=edge_type_to_feat_dim_map, + edge_attr_attention_bias_mode="sparse_linear", + ) + edge_attr_encoder.load_state_dict(base_encoder.state_dict(), strict=False) + base_encoder.eval() + edge_attr_encoder.eval() + + with torch.no_grad(): + edge_attr_projection = ( + edge_attr_encoder._edge_attr_attention_bias_projection_dict["0"] + ) + self.assertTrue( + torch.equal( + edge_attr_projection.weight, + torch.zeros_like(edge_attr_projection.weight), + ) + ) + base_output = base_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + edge_attr_output = edge_attr_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertTrue(torch.allclose(base_output, edge_attr_output, atol=1e-6)) + + def test_edge_attr_attention_bias_adds_only_indexed_pairs_and_accumulates( + self, + ) -> None: + encoder = self._create_encoder( + edge_type_to_feat_dim_map={self._edge_type: 2}, + edge_attr_attention_bias_mode="sparse_linear", + ) + projection = encoder._edge_attr_attention_bias_projection_dict["0"] + with torch.no_grad(): + projection.weight.copy_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + + attn_bias = encoder._build_attention_bias( + valid_mask=torch.tensor([[True, True, True]]), + sequences=torch.zeros((1, 3, 8), dtype=torch.float), + attention_bias_data={ + "anchor_bias": None, + "pairwise_bias": None, + "pairwise_nonmissing_indices": None, + "pairwise_edge_attr_indices": { + 0: torch.tensor( + [ + (0, 1, 0), + (0, 1, 0), + (0, 2, 1), + ], + dtype=torch.long, + ) + }, + "pairwise_edge_attr_values": { + 0: torch.tensor( + [ + [1.0, 1.0], + [2.0, 0.0], + [0.0, 1.0], + ] + ) + }, + "token_input": None, + }, + ) + + expected = torch.zeros((1, 2, 3, 3), dtype=torch.float) + expected[0, :, 1, 0] = torch.tensor([5.0, 13.0]) + expected[0, :, 2, 1] = torch.tensor([2.0, 4.0]) + self.assertTrue(torch.allclose(attn_bias, expected, atol=1e-6)) + + def test_edge_attr_attention_bias_supports_ppr_sequence_construction( + self, + ) -> None: + data = _create_user_graph_with_ppr_edges() + ppr_edge_type = EdgeType(self._node_type, Relation("ppr"), self._node_type) + + base_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 1}, + sequence_construction_method="ppr", + ) + edge_attr_encoder = self._create_encoder( + edge_type_to_feat_dim_map={ppr_edge_type: 1}, + sequence_construction_method="ppr", + edge_attr_attention_bias_mode="sparse_linear", + ) + edge_attr_encoder.load_state_dict(base_encoder.state_dict(), strict=False) + base_encoder.eval() + edge_attr_encoder.eval() + + with torch.no_grad(): + edge_attr_projection = ( + edge_attr_encoder._edge_attr_attention_bias_projection_dict["0"] + ) + self.assertTrue( + torch.equal( + edge_attr_projection.weight, + torch.zeros_like(edge_attr_projection.weight), + ) + ) + base_output = base_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + edge_attr_output = edge_attr_encoder( + data=data, + anchor_node_type=self._node_type, + device=self._device, + ) + + self.assertTrue(torch.allclose(base_output, edge_attr_output, atol=1e-6)) + def test_sinusoidal_sequence_positional_encoding_masks_padding(self) -> None: encoder = self._create_encoder( sequence_construction_method="ppr", diff --git a/tests/unit/transforms/graph_transformer_test.py b/tests/unit/transforms/graph_transformer_test.py index 18551014b..69cab0cd5 100644 --- a/tests/unit/transforms/graph_transformer_test.py +++ b/tests/unit/transforms/graph_transformer_test.py @@ -7,11 +7,16 @@ from absl.testing import absltest from torch_geometric.data import HeteroData +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.transforms.graph_transformer import ( _get_k_hop_neighbors_sparse, heterodata_to_graph_transformer_input, ) -from tests.test_assets.test_case import TestCase + +try: + from tests.test_assets.test_case import TestCase +except ModuleNotFoundError: # pragma: no cover - optional test harness deps + TestCase = absltest.TestCase def create_simple_hetero_data() -> HeteroData: @@ -122,6 +127,27 @@ def create_ppr_sequence_hetero_data() -> HeteroData: return data +def _dense_nonmissing_mask_from_indices( + pairwise_nonmissing_indices: torch.Tensor | None, + batch_size: int, + seq_len: int, + device: torch.device, +) -> torch.Tensor: + dense_mask = torch.zeros( + (batch_size, seq_len, seq_len), + dtype=torch.bool, + device=device, + ) + if pairwise_nonmissing_indices is None or pairwise_nonmissing_indices.numel() == 0: + return dense_mask + dense_mask[ + pairwise_nonmissing_indices[:, 0], + pairwise_nonmissing_indices[:, 1], + pairwise_nonmissing_indices[:, 2], + ] = True + return dense_mask + + class TestGetKHopNeighborsSparse(TestCase): """Tests for _get_k_hop_neighbors_sparse helper function.""" @@ -252,6 +278,10 @@ def test_basic_transform(self): self.assertIsInstance(attention_bias_data, dict) self.assertIn("anchor_bias", attention_bias_data) self.assertIn("pairwise_bias", attention_bias_data) + self.assertIn("pairwise_nonmissing_indices", attention_bias_data) + self.assertIn("pairwise_relation_indices", attention_bias_data) + self.assertIn("pairwise_edge_attr_indices", attention_bias_data) + self.assertIn("pairwise_edge_attr_values", attention_bias_data) def test_attention_mask_validity(self): """Test that attention mask correctly identifies valid positions.""" @@ -305,6 +335,148 @@ def test_anchor_first(self): # First position should be anchor node self.assertTrue(torch.allclose(sequences[0, 0], anchor_feature)) + def test_pairwise_relation_indices_follow_order_direction_and_padding(self): + """Sparse relation indices preserve edge-type labels before homogenization.""" + user = NodeType("user") + likes = EdgeType(user, Relation("likes"), user) + follows = EdgeType(user, Relation("follows"), user) + missing = EdgeType(user, Relation("missing"), user) + + data = HeteroData() + data["user"].x = torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + ] + ) + data["user"].batch_size = 1 + data[likes.tuple_repr()].edge_index = torch.tensor([[0], [1]]) + data[follows.tuple_repr()].edge_index = torch.tensor([[0, 1], [1, 2]]) + + _, valid_mask, auxiliary_data = heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + relation_edge_types=[likes, follows, missing], + ) + + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) + ) + pairwise_relation_indices = auxiliary_data["pairwise_relation_indices"] + self.assertIsNotNone(pairwise_relation_indices) + assert pairwise_relation_indices is not None + self.assertEqual(pairwise_relation_indices.shape[1], 4) + self.assertEqual( + {tuple(coord) for coord in pairwise_relation_indices.tolist()}, + { + (0, 1, 0, 0), # likes: source 0 -> target 1 + (0, 1, 0, 1), # follows: source 0 -> target 1 + (0, 2, 1, 1), # follows: source 1 -> target 2 + }, + ) + self.assertFalse((pairwise_relation_indices[:, 1:3] == 3).any().item()) + self.assertFalse((pairwise_relation_indices[:, 3] == 2).any().item()) + + def test_pairwise_edge_attr_payloads_follow_order_direction_and_padding(self): + """Sparse edge-attr payloads preserve relation labels and GAT direction.""" + user = NodeType("user") + likes = EdgeType(user, Relation("likes"), user) + follows = EdgeType(user, Relation("follows"), user) + missing = EdgeType(user, Relation("missing"), user) + + data = HeteroData() + data["user"].x = torch.tensor( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], + ] + ) + data["user"].batch_size = 1 + data[likes.tuple_repr()].edge_index = torch.tensor([[0], [1]]) + data[likes.tuple_repr()].edge_attr = torch.tensor([[0.25, 1.25]]) + data[follows.tuple_repr()].edge_index = torch.tensor([[0, 1], [1, 2]]) + data[follows.tuple_repr()].edge_attr = torch.tensor([[2.0], [3.0]]) + + edge_attr_dim_map = { + likes: 2, + follows: 1, + missing: 1, + } + _, valid_mask, auxiliary_data = heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + edge_attr_edge_type_to_feat_dim_map=edge_attr_dim_map, + ) + + self.assertTrue( + torch.equal(valid_mask[0], torch.tensor([True, True, True, False])) + ) + edge_attr_indices = auxiliary_data["pairwise_edge_attr_indices"] + edge_attr_values = auxiliary_data["pairwise_edge_attr_values"] + self.assertIsNotNone(edge_attr_indices) + self.assertIsNotNone(edge_attr_values) + assert edge_attr_indices is not None + assert edge_attr_values is not None + + sorted_edge_types = sorted(edge_attr_dim_map.keys()) + likes_idx = sorted_edge_types.index(likes) + follows_idx = sorted_edge_types.index(follows) + missing_idx = sorted_edge_types.index(missing) + + self.assertEqual( + {tuple(coord) for coord in edge_attr_indices[likes_idx].tolist()}, + {(0, 1, 0)}, # likes: source 0 -> target 1 + ) + self.assertTrue( + torch.allclose( + edge_attr_values[likes_idx], + torch.tensor([[0.25, 1.25]]), + ) + ) + self.assertEqual( + {tuple(coord) for coord in edge_attr_indices[follows_idx].tolist()}, + { + (0, 1, 0), # follows: source 0 -> target 1 + (0, 2, 1), # follows: source 1 -> target 2 + }, + ) + self.assertTrue( + torch.allclose(edge_attr_values[follows_idx], torch.tensor([[2.0], [3.0]])) + ) + self.assertNotIn(missing_idx, edge_attr_indices) + self.assertFalse((edge_attr_indices[likes_idx] == 3).any().item()) + self.assertFalse((edge_attr_indices[follows_idx] == 3).any().item()) + + def test_pairwise_edge_attr_payloads_missing_edge_attr_raises(self): + user = NodeType("user") + follows = EdgeType(user, Relation("follows"), user) + + data = HeteroData() + data["user"].x = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + data["user"].batch_size = 1 + data[follows.tuple_repr()].edge_index = torch.tensor([[0], [1]]) + + with self.assertRaisesRegex( + ValueError, + "requires edge_attr for edge type", + ): + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=2, + anchor_node_type="user", + hop_distance=1, + edge_attr_edge_type_to_feat_dim_map={follows: 1}, + ) + def test_different_anchor_types(self): """Test with different anchor node types.""" data = create_simple_hetero_data() @@ -792,6 +964,7 @@ def test_transform_returns_base_sequences_and_anchor_relative_bias(self) -> None assert anchor_bias is not None self.assertEqual(anchor_bias.shape, (1, 4, 1)) self.assertIsNone(attention_bias_data["pairwise_bias"]) + self.assertIsNone(attention_bias_data["pairwise_nonmissing_indices"]) self.assertTrue(valid_mask[0, 0].item()) def test_attention_bias_outputs_include_valid_mask_and_relative_features( @@ -817,17 +990,59 @@ def test_attention_bias_outputs_include_valid_mask_and_relative_features( self.assertEqual(valid_mask.shape, (1, 4)) anchor_bias = attention_bias_data["anchor_bias"] pairwise_bias = attention_bias_data["pairwise_bias"] + pairwise_nonmissing_indices = attention_bias_data["pairwise_nonmissing_indices"] assert anchor_bias is not None assert pairwise_bias is not None + assert pairwise_nonmissing_indices is not None + pairwise_nonmissing_mask = _dense_nonmissing_mask_from_indices( + pairwise_nonmissing_indices=pairwise_nonmissing_indices, + batch_size=1, + seq_len=4, + device=pairwise_bias.device, + ) self.assertEqual(anchor_bias.shape, (1, 4, 1)) self.assertEqual(pairwise_bias.shape, (1, 4, 4, 1)) + self.assertEqual(pairwise_nonmissing_indices.shape[1], 3) self.assertAlmostEqual(anchor_bias[0, 0, 0].item(), 0.0, places=5) self.assertAlmostEqual(anchor_bias[0, 1, 0].item(), 1.0, places=5) self.assertAlmostEqual(anchor_bias[0, 2, 0].item(), 3.0, places=5) self.assertAlmostEqual(pairwise_bias[0, 0, 0, 0].item(), 0.1, places=5) + self.assertTrue(torch.all(pairwise_nonmissing_mask[0, :3, :3])) invalid_pair_mask = ~(valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)) self.assertTrue(torch.all(pairwise_bias[..., 0][invalid_pair_mask] == 0)) + self.assertTrue(torch.all(~pairwise_nonmissing_mask[invalid_pair_mask])) + + def test_pairwise_attention_bias_attr_support_mismatch_raises(self) -> None: + data = _create_hetero_data_with_relative_pe() + pairwise_distance_sparse_mismatch = torch.tensor( + [ + [0.1, 0.0, 0.3, 0.4, 0.5], + [0.6, 0.7, 0.0, 0.9, 1.0], + [1.1, 1.2, 1.3, 1.4, 1.5], + [1.6, 1.7, 1.8, 1.9, 2.0], + [2.1, 2.2, 2.3, 2.4, 2.5], + ] + ) + data.pairwise_distance_sparse_mismatch = ( + pairwise_distance_sparse_mismatch.to_sparse_csr() + ) + + with self.assertRaisesRegex( + ValueError, + "Pairwise attention bias attributes must share identical nonmissing support", + ): + heterodata_to_graph_transformer_input( + data=data, + batch_size=1, + max_seq_len=4, + anchor_node_type="user", + hop_distance=2, + pairwise_attention_bias_attr_names=[ + "pairwise_distance", + "pairwise_distance_sparse_mismatch", + ], + ) if __name__ == "__main__":