12 changes: 8 additions & 4 deletions cli/alora/train.py
@@ -19,6 +19,7 @@
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedTokenizerBase,
TrainerCallback,
TrainerControl,
TrainerState,
@@ -47,7 +48,7 @@


def load_dataset_from_json(
json_path: str, tokenizer: AutoTokenizer, invocation_prompt: str
json_path: str, tokenizer: PreTrainedTokenizerBase, invocation_prompt: str
) -> Dataset:
"""Load a JSONL dataset and format it for SFT training.

@@ -159,8 +160,12 @@ def save_model(self, output_dir: str | None = None, _internal_call: bool = False
"""
if self.model is not None:
self.model.save_pretrained(output_dir, safe_serialization=True)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# transformers v5 renamed .tokenizer -> .processing_class
processor = getattr(self, "processing_class", None) or getattr(
self, "tokenizer", None
)
if processor is not None:
processor.save_pretrained(output_dir)


def train_model(
@@ -218,7 +223,6 @@ def train_model(
base_model, padding_side="right", trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens = False

dataset = load_dataset_from_json(dataset_path, tokenizer, invocation_prompt)
dataset = dataset.shuffle(seed=42)
2 changes: 1 addition & 1 deletion docs/docs/advanced/prefix-caching-and-kv-blocks.md
@@ -63,7 +63,7 @@ When a prompt contains a mix of cached and uncached blocks, Mellea:
2. Runs forward passes on uncached blocks.
3. Retrieves stored `DynamicCache` for cached blocks.
4. **Smashes** (concatenates) all KV caches along the time axis using
`merge_dynamic_caches()`.
`merge_dynamic_caches_v5()`.
5. Passes the merged cache plus the combined input IDs to the generation step.

The result is identical to a single full-context forward pass, with the prefill
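To make the "smash" step described in the doc above concrete, here is a minimal sketch of concatenating per-layer key/value tensors along the time axis. It is illustrative only: the real helper is `merge_dynamic_caches_v5()` in `mellea.backends.kv_block_helpers`, and `DynamicCache` internals (`key_cache`/`value_cache`, `update()`) differ between transformers releases, so treat the attribute names as assumptions.

```python
import torch
from transformers.cache_utils import DynamicCache


def merge_caches_sketch(caches: list[DynamicCache]) -> DynamicCache:
    """Concatenate caches along the time axis.

    Assumes a v4-style DynamicCache layout: per-layer tensors of shape
    (batch, num_heads, seq_len, head_dim), so the time axis is dim=2.
    """
    merged = DynamicCache()
    num_layers = len(caches[0].key_cache)
    for layer_idx in range(num_layers):
        keys = torch.cat([c.key_cache[layer_idx] for c in caches], dim=2)
        values = torch.cat([c.value_cache[layer_idx] for c in caches], dim=2)
        merged.update(keys, values, layer_idx)
    return merged
```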
6 changes: 3 additions & 3 deletions docs/kv_smash/kv_with_chat.py
@@ -1,7 +1,7 @@
import torch

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches_v5
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO

backend = LocalHFBackend(model_id=IBM_GRANITE_4_HYBRID_MICRO)
@@ -30,7 +30,7 @@ def cache(s: str, store=True) -> DynamicCache:
def merge(toks, dcs):
merged_toks = torch.cat([t["input_ids"] for t in toks], dim=1)
merged_masks = torch.cat([t["attention_mask"] for t in toks], dim=1)
merged_dcs = merge_dynamic_caches(dcs)
merged_dcs = merge_dynamic_caches_v5(dcs)

return merged_toks, merged_masks, merged_dcs

@@ -89,7 +89,7 @@ def merge(toks, dcs):
# Merge everything together.
merged_toks = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
merged_masks = torch.cat([toks["attention_mask"] for toks in tok_parts], dim=1)
merged_dcs = merge_dynamic_caches(dc_parts)
merged_dcs = merge_dynamic_caches_v5(dc_parts)

# crop the last KV for safety.
merged_dcs.crop(-1)
17 changes: 10 additions & 7 deletions docs/kv_smash/kvcache.py
@@ -4,17 +4,20 @@
# "mellea[hf]",
# ]
# ///
from typing import cast

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from transformers.generation import GenerateDecoderOnlyOutput
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from transformers.generation.utils import GenerateDecoderOnlyOutput
from transformers.modeling_utils import PreTrainedModel

from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches_v5

model_id = "ibm-granite/granite-4.0-tiny-preview"
device = torch.device("mps")
model = AutoModelForCausalLM.from_pretrained(model_id)
model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_id) # type: ignore[assignment]
# model = model.to(device=device) # this part does not pass mypy; possible misconfiguration
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_id)


def cache(toks) -> DynamicCache:
@@ -34,7 +37,7 @@ def merge(strs: list[str]):

merged_toks = torch.cat([toks["input_ids"] for toks in strs_toks], dim=1)
merged_masks = torch.cat([toks["attention_mask"] for toks in strs_toks], dim=1)
merged_dcs = merge_dynamic_caches(strs_dcs)
merged_dcs = merge_dynamic_caches_v5(strs_dcs)

return merged_toks, merged_masks, merged_dcs

@@ -45,7 +48,7 @@ def merge(strs: list[str]):
merged_dcs.crop(-1)

# GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput | GenerateBeamDecoderOnlyOutput | GenerateBeamEncoderDecoderOutput | LongTensor
result = model.generate(
result = model.generate( # type: ignore[operator]
merged_toks.to(model.device),
attention_mask=merged_masks.to(model.device),
past_key_values=merged_dcs,
2 changes: 1 addition & 1 deletion docs/metrics/coverage-current.json
@@ -355,7 +355,7 @@
"TokenizedCacheIterleaving",
"LegacyCache",
"legacy_cache_smash",
"merge_dynamic_caches",
"merge_dynamic_caches_v5",
"tokens_to_legacy_cache"
],
"mellea.backends.huggingface.granite_formatters": [
45 changes: 31 additions & 14 deletions mellea/backends/huggingface.py
@@ -12,19 +12,21 @@
import json
import threading
from collections.abc import Callable, Coroutine, Sequence
from typing import Any, overload
from typing import Any, cast, overload

import llguidance
import llguidance.hf
import llguidance.torch
import torch
import transformers as _transformers_module
from packaging.version import Version
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.streamers import AsyncTextIteratorStreamer
from transformers.generation.utils import GenerateDecoderOnlyOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import set_seed

from ..backends import kv_block_helpers
@@ -67,11 +69,13 @@
)
from .utils import to_chat, to_tool_calls

_TRANSFORMERS_V5 = Version(_transformers_module.__version__) >= Version("5.0")

"""A configuration type for the unhappy path: Tokenizer * Model * torch device string

Huggingface backends can initialize themselves from a model string if the transformers `Auto*` classes can be used. Therefore, a TransformersTorchConfig usually isn't required. However, sometimes a model needs special care to instantiate properly, or a custom device type needs to be used. Instead of trying to do a lot of partial magic, we basically have two modalities: either the constructor can figure out everything from the model_id, or the user has to provide an entire config.
"""
TransformersTorchConfig = tuple[PreTrainedTokenizer, PreTrainedModel, torch.device]
TransformersTorchConfig = tuple[PreTrainedTokenizerBase, PreTrainedModel, torch.device]
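The module docstring above describes the custom-config escape hatch. A minimal sketch of hand-building such a tuple follows; how it is handed to the backend (e.g. via the `custom_config` value unpacked in the `match` statement further down) is an assumption beyond what this hunk shows.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hand-build a TransformersTorchConfig for a model that needs special care.
model_id = "ibm-granite/granite-4.0-tiny-preview"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
device = torch.device("cpu")

config = (tokenizer, model, device)  # (PreTrainedTokenizerBase, PreTrainedModel, torch.device)
```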

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors

@@ -302,8 +306,8 @@ def __init__(
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
self._hf_model_id, device_map=str(self._device), torch_dtype="auto"
)
self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
self._hf_model_id
self._tokenizer: PreTrainedTokenizerBase = (
AutoTokenizer.from_pretrained(self._hf_model_id)
)
case _:
self._tokenizer, self._model, self._device = custom_config
@@ -726,7 +730,7 @@ def _make_merged_kv_cache(
[toks["attention_mask"] for toks in tok_parts], dim=1
)
assert input_ids.shape == attention_mask.shape
merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches(dc_parts)
merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches_v5(dc_parts)
# TODO: also assert that the merged cache is the correct shape given the input_ids and attention_mask shapes.
# rewind merged cache by 1 for safety.
merged_cache.crop(-1) # type: ignore
@@ -973,7 +977,8 @@ async def _generate_from_context_standard(
"", # Empty for no adapters.
self._model.generate, # type: ignore
# Passed as args/kwargs to generate.
input_ids,
inputs=input_ids["input_ids"],
attention_mask=input_ids["attention_mask"],
return_dict_in_generate=True,
use_cache=self._use_caches, # Only create KV cache if caching is enabled
**self._make_backend_specific_and_remove(generate_options),
@@ -1045,6 +1050,10 @@ async def processing(
input_ids: The prompt token IDs used for decoding; required to slice off
the prompt portion from the generated sequences.
"""
input_ids_tensor = (
input_ids if isinstance(input_ids, torch.Tensor) else input_ids["input_ids"]
)

if mot._underlying_value is None:
mot._underlying_value = ""

@@ -1055,8 +1064,12 @@
elif isinstance(chunk, GenerateDecoderOnlyOutput):
# Otherwise, it's a non-streaming request. Decode it here.
mot._meta["hf_output"] = chunk
mot._underlying_value += self._tokenizer.decode(
chunk.sequences[0, input_ids.shape[1] :], skip_special_tokens=True
mot._underlying_value += cast(
str,
self._tokenizer.decode(
chunk.sequences[0, input_ids_tensor.shape[1] :],
skip_special_tokens=True,
),
)

async def post_processing(
Expand Down Expand Up @@ -1110,7 +1123,7 @@ class used during generation, if any.
kv_cache=kv_cache,
merged_token_ids=output_complete,
merged_attention=torch.ones_like(output_complete).to(self._device),
q_end=len(input_ids[0]), # type: ignore
q_end=input_ids["input_ids"].shape[1], # type: ignore
scores=hf_output.scores,
)

Expand Down Expand Up @@ -1140,7 +1153,7 @@ class used during generation, if any.
if isinstance(hf_output, GenerateDecoderOnlyOutput):
try:
if input_ids is not None and hf_output.sequences is not None:
n_prompt = input_ids.shape[1]
n_prompt = input_ids["input_ids"].shape[1]
n_completion = hf_output.sequences[0].shape[0] - n_prompt
except Exception:
pass
@@ -1517,9 +1530,13 @@ def load_adapter(self, adapter_qualified_name: str):
try:
adapter_kwargs = {}

# Peft tries to stringify the device. If it's mps, it gets stringified as "mps:0" which causes
# an error when loading with safetensors.torch.load_file. Force the device as a string "mps" to fix.
if self._device == torch.device("mps"):
# v4: Peft tries to stringify the device. If it's mps, it gets stringified as "mps:0" which
# causes an error when loading with safetensors.torch.load_file. Force the device as a string
# "mps" to fix.
# v5: adapter_kwargs is forwarded to download_kwargs only; device is derived automatically
# from self.device, so passing it here would hit find_adapter_config_file() which no longer
# accepts a 'device' argument.
if not _TRANSFORMERS_V5 and self._device == torch.device("mps"):
adapter_kwargs["device"] = "mps"
self._model.load_adapter(
adapter.path, adapter.qualified_name, adapter_kwargs=adapter_kwargs