20 changes: 15 additions & 5 deletions examples/models/llama/attention.py
@@ -29,6 +29,12 @@ class ForwardOptions(TypedDict, total=False):
# When provided, the attention layer skips its own K/V projection
# and reuses the donor's K/V instead.
shared_kv: Optional[Tuple[torch.Tensor, torch.Tensor]]
# Per-call KV cache override. Used by `MultimodalTransformer` when
# `transformer_block_repeat_config` repeats a TransformerBlock so that each
# *occurrence* of the layer in the schedule writes to its own KV cache
# rather than sharing the layer's `self.kv_cache`. When None or absent the
# attention falls back to `self.kv_cache`.
kv_cache_override: Optional["KVCache"]


class Attention(nn.Module, ABC):
@@ -276,7 +282,7 @@ def __init__(
[0, 1, 2, 3, 4, NA, NA, NA] After cache update we would have
[8, 1, 2, 3, 4, 5, 6, 7]. We kicked out token at pos = 0. However, the
current step still has access to [pos - sliding_window_size, pos] tokens.

To make sure we don't over-attend, i.e. we don't let pos = 5
attend to pos = 1, the mask calculation has to account for the sliding
window size.
@@ -573,8 +579,12 @@ def forward(
q, k, v = self._prepare_qkv(q, x, bsz, seqlen, freqs_cos, freqs_sin)

if self.use_kv_cache:
# Per-call KV cache override (used when a TransformerBlock is invoked
# multiple times via `transformer_block_repeat_config` so each
# occurrence has its own KV cache). Falls back to `self.kv_cache`.
active_kv_cache = kwargs.get("kv_cache_override") or self.kv_cache
assert input_pos is not None
is_ring_buffer = getattr(self.kv_cache, "is_ring_buffer", False)
is_ring_buffer = getattr(active_kv_cache, "is_ring_buffer", False)

if is_ring_buffer:
# Ring buffer models compute their own mask after KV cache
@@ -594,14 +604,14 @@

# Only update KV cache for non-shared layers
if shared_kv is None:
assert self.kv_cache is not None, (
assert active_kv_cache is not None, (
"kv_cache is required when shared_kv is not provided. "
"This layer may be a YOCO shared layer that requires shared_kv from a donor."
)
k, v = self.kv_cache.update(input_pos, k, v)
k, v = active_kv_cache.update(input_pos, k, v)

if is_ring_buffer:
attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
attn_mask = active_kv_cache.create_causal_mask_for_ring_buffer(
input_pos[0].item(), seqlen
)

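For context, here is a minimal sketch of how a caller could drive this per-occurrence override. Apart from the `kv_cache_override` key itself, everything below (the `run_repeat_schedule` name, the `schedule` list, the `extra_caches` mapping, and the block call signature) is an assumption for illustration, not the actual `MultimodalTransformer` code in this PR.

```python
# Illustrative sketch only: apart from `kv_cache_override`, the names below
# (run_repeat_schedule, schedule, extra_caches) are assumptions, not PR code.
from typing import Dict, List, Tuple

import torch


def run_repeat_schedule(
    blocks: List[torch.nn.Module],
    schedule: List[Tuple[int, int]],  # (layer_idx, occurrence_idx) pairs
    extra_caches: Dict[Tuple[int, int], "KVCache"],
    h: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
    input_pos: torch.Tensor,
) -> torch.Tensor:
    """Run every scheduled occurrence of every layer.

    The first occurrence (occurrence_idx == 0) passes no override, so attention
    falls back to its own self.kv_cache; each repeated occurrence is handed a
    dedicated cache so it does not clobber the K/V written by the first pass.
    """
    for layer_idx, occurrence_idx in schedule:  # e.g. [(0, 0), (1, 0), (0, 1), (1, 1)]
        override = extra_caches[(layer_idx, occurrence_idx)] if occurrence_idx > 0 else None
        h = blocks[layer_idx](
            h,
            freqs_cos,
            freqs_sin,
            input_pos=input_pos,
            kv_cache_override=override,  # None -> attention uses self.kv_cache
        )
    return h
```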
55 changes: 54 additions & 1 deletion examples/models/llama/model_args.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import torch.nn.functional as F

@@ -182,6 +182,12 @@ class ModelArgs:
use_ffn_learnable_scales: bool = False
output_soft_cap_temp: Optional[float] = None

# Block repetition: repeat contiguous ranges of transformer layers.
# List of {"start": int, "end": int, "count": int} dicts where start/end
# are layer indices (both inclusive) and count is the total number of passes
# (1 = normal, 2 = run the block twice, etc.). Blocks must not overlap.
transformer_block_repeat_config: Optional[list] = None

def __post_init__(self): # noqa: C901
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
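As an illustration of the new field (the import path and model dimensions below are assumptions, not taken from the PR), a config that runs layers 2 through 3 of a six-layer model twice per forward pass could look like this:

```python
# Hypothetical example; the import path and model sizes are assumptions.
from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=32000,
    # Layers 2..3 (both inclusive) get count=2, i.e. two passes per forward;
    # the remaining layers run once as usual. Ranges must not overlap.
    transformer_block_repeat_config=[{"start": 2, "end": 3, "count": 2}],
)
```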
@@ -224,3 +230,50 @@ def find_multiple(n: int, k: int) -> int:
# Convert string act_fn to enum if needed
if isinstance(self.act_fn, str):
self.act_fn = ActFn.from_string(self.act_fn)

self.validate_block_repeat_config()

def validate_block_repeat_config(self) -> None:
"""Validate transformer_block_repeat_config field.

Called from __post_init__ and should also be called after setting
transformer_block_repeat_config post-construction.
"""
if self.transformer_block_repeat_config is None:
return
for i, block in enumerate(self.transformer_block_repeat_config):
assert (
"start" in block and "end" in block and "count" in block
), f"transformer_block_repeat_config[{i}] must have 'start', 'end', and 'count' keys"
assert 0 <= block["start"] <= block["end"] < self.n_layers, (
f"transformer_block_repeat_config[{i}]: invalid range [{block['start']}, {block['end']}] "
f"for {self.n_layers} layers"
)
assert (
block["count"] >= 1
), f"transformer_block_repeat_config[{i}]: count must be >= 1"
# Check for overlapping blocks (end is inclusive, so next start must be > prev end)
sorted_blocks = sorted(
self.transformer_block_repeat_config, key=lambda b: b["start"]
)
for i in range(1, len(sorted_blocks)):
assert sorted_blocks[i]["start"] > sorted_blocks[i - 1]["end"], (
f"transformer_block_repeat_config: blocks {sorted_blocks[i-1]} and "
f"{sorted_blocks[i]} overlap"
)

@staticmethod
def normalize_block_repeat_config(
config: Optional[List[Dict[str, int]]],
) -> Optional[List[Dict[str, int]]]:
"""Drop entries with `count == 1`; return None if nothing remains.

A block-repeat entry with count=1 visits its layers exactly once --
the same as if the entry were omitted. Stripping these no-ops at
assignment time lets every downstream consumer assume each entry is
a genuine repeat (count > 1). Pure function: does not mutate input.
"""
if not config:
return None
normalized = [b for b in config if b.get("count", 1) > 1]
return normalized if normalized else None
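A short, hedged illustration of how these two helpers behave, using made-up configs (and, for the validation part, the hypothetical `args` from the example above):

```python
# Made-up configs exercising the helpers above.
cfg = [
    {"start": 0, "end": 1, "count": 1},  # single pass: a no-op entry
    {"start": 2, "end": 4, "count": 3},  # genuine repeat: three passes
]

# normalize_block_repeat_config() drops the count == 1 entry...
assert ModelArgs.normalize_block_repeat_config(cfg) == [
    {"start": 2, "end": 4, "count": 3}
]
# ...and collapses to None when nothing remains.
assert ModelArgs.normalize_block_repeat_config(
    [{"start": 0, "end": 1, "count": 1}]
) is None

# validate_block_repeat_config() (also run from __post_init__) rejects
# overlapping ranges; since "end" is inclusive, [0, 2] and [2, 5] share layer 2.
# Setting such a config post-construction and re-validating would raise:
#   args.transformer_block_repeat_config = [
#       {"start": 0, "end": 2, "count": 2},
#       {"start": 2, "end": 5, "count": 2},
#   ]
#   args.validate_block_repeat_config()  # AssertionError: blocks ... overlap
```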