Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 168 additions & 1 deletion examples/models/llama/attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Type, TypedDict
Expand Down Expand Up @@ -564,6 +564,7 @@
self.head_v_dim = args.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.chunk_size = args.deltanet_chunk_size
self.conv_kernel_size = args.linear_conv_kernel_dim

assert (
Expand Down Expand Up @@ -702,6 +703,165 @@

return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

def _chunked_recurrent_gated_delta_rule(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
) -> torch.Tensor:
    """Chunk-parallel (blockwise) evaluation of the gated delta rule.

    Splits the sequence into fixed-size chunks of ``self.chunk_size``,
    computes all intra-chunk interactions in parallel via a WY
    representation (one unit-triangular solve per chunk), and carries
    the recurrent state across chunks with ``torch.scan``.

    Args:
        query: ``(B, T, H, Dk)``; L2-normalized and scaled internally.
        key: ``(B, T, H, Dk)``; L2-normalized internally.
        value: ``(B, T, H, Dv)``.
        g: per-token log decay gate, ``(B, T, H)``.
        beta: per-token write strength, ``(B, T, H)``.
            (Shapes inferred from the transpose/reshape/broadcast
            pattern below — confirm against the caller.)

    Returns:
        Core attention output of shape ``(B, T, H, Dv)`` in the input
        dtype.

    Side effects:
        Writes the final carried state back into
        ``self.recurrent_state[:batch_size]`` under ``no_grad``.
    """
    initial_dtype = query.dtype
    query = _l2norm(query, dim=-1, eps=1e-6)
    key = _l2norm(key, dim=-1, eps=1e-6)
    # Move to (B, H, T, ...) layout and compute in fp32 for stability.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1.0 / (query.shape[-1] ** 0.5)
    query = query * scale

    chunk_size = self.chunk_size
    assert chunk_size is not None

    # Pad T to the next multiple of chunk_size.
    remainder = sequence_length % chunk_size
    if remainder != 0:
        pad_len = chunk_size - remainder
        query = F.pad(query, (0, 0, 0, pad_len))
        key = F.pad(key, (0, 0, 0, pad_len))
        value = F.pad(value, (0, 0, 0, pad_len))
        # BUGFIX: g and beta are 3-D (B, H, T); a 4-element pad spec
        # (0, 0, 0, pad_len) would pad the *head* dimension (dim -2)
        # instead of time, making the chunk reshape below fail whenever
        # T is not a multiple of chunk_size. Pad only the last dim.
        g = F.pad(g, (0, pad_len))
        beta = F.pad(beta, (0, pad_len))
        padded_len = sequence_length + pad_len
    else:
        padded_len = sequence_length

    num_chunks = padded_len // chunk_size

    # Pre-scale by beta (the per-token write strength of the delta rule).
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    # Reshape into chunks: (B, H, N, C, dim).
    q_c = query.reshape(
        batch_size, num_heads, num_chunks, chunk_size, k_head_dim
    )
    k_c = key.reshape(
        batch_size, num_heads, num_chunks, chunk_size, k_head_dim
    )
    v_c = v_beta.reshape(
        batch_size, num_heads, num_chunks, chunk_size, v_head_dim
    )
    k_beta_c = k_beta.reshape(
        batch_size, num_heads, num_chunks, chunk_size, k_head_dim
    )
    g_c = g.reshape(batch_size, num_heads, num_chunks, chunk_size)

    # Cumulative log-gate within each chunk.
    decay = g_c.cumsum(dim=-1)
    decay_exp = decay.exp().unsqueeze(-1)

    # Relative gate-decay matrix (lower triangular, C×C per chunk):
    # L_mask[i, j] = exp(decay[i] - decay[j]) for j <= i, else 0.
    L_mask = (
        (decay.unsqueeze(-1) - decay.unsqueeze(-2)).tril().exp().tril()
    )

    # --- WY representation (parallel across all chunks) ---
    # M = (β·k @ k^T · L_mask), strictly lower triangular C×C.
    kkt = k_beta_c @ k_c.transpose(-1, -2)
    M = (kkt * L_mask).tril(diagonal=-1)

    # Solve (I + M) X = [v_c | k_beta_c·exp(decay)] for corrected
    # values/keys. unitriangular=True: the unit diagonal is implied, so
    # only the strictly-lower part of I_plus_M is actually read.
    I_plus_M = M + torch.eye(
        chunk_size, device=key.device, dtype=key.dtype
    )
    rhs = torch.cat([v_c, k_beta_c * decay_exp], dim=-1)
    solved = torch.linalg.solve_triangular(
        I_plus_M, rhs, upper=False, unitriangular=True
    )
    v_wy = solved[..., :v_head_dim]
    k_cumdecay = solved[..., v_head_dim:]

    # Intra-chunk causal attention: Q @ K^T · L_mask, strictly-upper
    # entries zeroed so each position only attends to itself and earlier.
    causal_mask = torch.triu(
        torch.ones(
            chunk_size, chunk_size, dtype=torch.bool, device=key.device
        ),
        diagonal=1,
    )
    intra_attn = (
        q_c @ k_c.transpose(-1, -2) * L_mask
    ).masked_fill(causal_mask, 0)

    # --- Inter-chunk state propagation via torch.scan ---
    # Permute the scan dimension (chunk index N) first: (N, B, H, ...).
    # NOTE(review): torch.scan is a recent higher-order op — confirm the
    # minimum supported torch version for this path.
    intra_attn_s = intra_attn.permute(2, 0, 1, 3, 4)
    v_wy_s = v_wy.permute(2, 0, 1, 3, 4)
    k_cumdecay_s = k_cumdecay.permute(2, 0, 1, 3, 4)
    q_c_s = q_c.permute(2, 0, 1, 3, 4)
    k_c_s = k_c.permute(2, 0, 1, 3, 4)
    decay_s = decay.permute(2, 0, 1, 3)

    init_state = self.recurrent_state[:batch_size].to(value.dtype)

    def inter_chunk_fn(carry, x):
        # carry: recurrent state (B, H, Dk, Dv); x: per-chunk slices.
        attn_i = x[0]
        v_wy_i = x[1]
        k_cumdecay_i = x[2]
        q_ci = x[3]
        k_ci = x[4]
        decay_i = x[5]

        # Subtract state contribution from WY-corrected values.
        v_prime = k_cumdecay_i @ carry
        v_new = v_wy_i - v_prime

        # Output = inter (query × state) + intra (causal attn × corrected v).
        o_inter = (q_ci * decay_i.unsqueeze(-1).exp()) @ carry
        o_i = o_inter + attn_i @ v_new

        # State update: gate-decay to the chunk end + rank-C outer product.
        decay_last = decay_i[:, :, chunk_size - 1]
        k_decayed = k_ci * (
            decay_last.unsqueeze(-1).unsqueeze(-1)
            - decay_i.unsqueeze(-1)
        ).exp()
        next_carry = (
            carry * decay_last.unsqueeze(-1).unsqueeze(-1).exp()
            + k_decayed.transpose(-1, -2) @ v_new
        )
        return next_carry, o_i

    xs = (
        intra_attn_s,
        v_wy_s,
        k_cumdecay_s,
        q_c_s,
        k_c_s,
        decay_s,
    )
    final_state, all_outputs = torch.scan(
        inter_chunk_fn, init=init_state, xs=xs
    )

    # (N, B, H, C, V) -> (B, H, T_padded, V)
    all_outputs = all_outputs.permute(1, 2, 0, 3, 4).reshape(
        batch_size, num_heads, padded_len, v_head_dim
    )

    # Drop the padding tail before returning.
    core_attn_out = all_outputs[:, :, :sequence_length, :]

    # Persist the carried state for the next call (inference cache).
    with torch.no_grad():
        self.recurrent_state[:batch_size].copy_(
            final_state.to(self.recurrent_state.dtype)
        )

    return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)

def forward(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -740,7 +900,14 @@

beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta)
if self.chunk_size is not None:
core_attn_out = self._chunked_recurrent_gated_delta_rule(
query, key, value, g, beta
)
else:
core_attn_out = self._recurrent_gated_delta_rule(
query, key, value, g, beta
)

core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
z = z.reshape(-1, self.head_v_dim)
Expand Down
1 change: 1 addition & 0 deletions examples/models/llama/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ModelArgs:
linear_value_head_dim: Optional[int] = None
linear_num_key_heads: Optional[int] = None
linear_num_value_heads: Optional[int] = None
deltanet_chunk_size: Optional[int] = None
# Qwen3.5 RMSNorm uses (1 + weight) scaling.
rms_norm_add_unit_offset: bool = False
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
Expand Down
18 changes: 18 additions & 0 deletions examples/models/llama/tests/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,24 @@ fbcode_target(_kind = python_unittest,
],
)

# Unit test for the gated DeltaNet CoreML export path.
# preload_deps loads the custom-ops AOT library before the test imports run.
fbcode_target(_kind = python_unittest,
    name = "test_gated_deltanet_coreml_export",
    srcs = [
        "test_gated_deltanet_coreml_export.py",
    ],
    preload_deps = [
        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/examples/models/llama:export_library",
        "//executorch/examples/models/llama:llama_transformer",
        "//executorch/examples/models/llama:sdpa",
        "//executorch/examples/models/llama:custom_kv_cache",
        "//executorch/extension/pybindings:portable_lib",
    ],
)

fbcode_target(_kind = python_unittest,
name = "test_export_llama_lib",
srcs = [
Expand Down
Loading
Loading