
Commit 32fcaab

csbobbynrfulton authored and committed
Adds working transformers v5 kv smash code.
1 parent ffb8b6c commit 32fcaab

3 files changed

Lines changed: 4485 additions & 4425 deletions


Lines changed: 96 additions & 37 deletions
@@ -1,47 +1,106 @@
 """Utilities for KV smashing."""
 
 from collections.abc import Iterable
-from functools import reduce
-from typing import Any
 
 import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.cache_utils import DynamicCache
-from transformers.tokenization_utils_base import BatchEncoding
 
-TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache]
-LegacyCache = Any
 
+@torch.no_grad()
+def prefill_cache_v5(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizerBase,
+    text: str,
+    device: torch.device,
+) -> tuple[dict, DynamicCache]:
+    """Prefills the KV cache for transformers v5."""
+    toks = tokenizer(text, return_tensors="pt")
+    toks = {k: v.to(device) for k, v in toks.items()}
 
-def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache:
-    """Concatenates two LegacyCache Ks and Vs along the time axis."""
-    legacy_merged = tuple(
-        (torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2))
-        for i in range(len(a))
+    dc = DynamicCache()
+    out = model(
+        input_ids=toks["input_ids"],
+        attention_mask=toks["attention_mask"],
+        past_key_values=dc,
+        use_cache=True,
     )
-    return legacy_merged
-
-
-def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache:
-    """Merges two DynamicCache Ks and Vs along the time axis."""
-    legacies = [c.to_legacy_cache() for c in caches]  # type: ignore
-    assert len(legacies) >= 1
-    rv = DynamicCache.from_legacy_cache(reduce(legacy_cache_smash, legacies))  # type: ignore
-    return rv  # type: ignore
-
-
-def tokens_to_legacy_cache(
-    model, device: str, tokens_or_cache: BatchEncoding | DynamicCache
-) -> Iterable[LegacyCache]:
-    """Prefills and returns Ks and Vs as a LegacyCache."""
-    if type(tokens_or_cache) is DynamicCache:
-        return tokens_or_cache.to_legacy_cache()  # type: ignore
-    else:
-        tokens = tokens_or_cache
-        dc = DynamicCache()
-        with torch.no_grad():
-            dc = model(
-                tokens["input_ids"].to(device),  # type: ignore
-                attention_mask=tokens["attention_mask"].to(device),  # type: ignore
-                past_key_values=dc,
-            ).past_key_values
-        return dc.to_legacy_cache()
+    dc = out.past_key_values
+    dc.crop(-1)
+    return toks, dc  # v5 returns a DynamicCache (not a legacy tuple)
+
+
+def merge_dynamic_caches_v5(caches: Iterable[DynamicCache]) -> DynamicCache:
+    """Merges multiple v5 DynamicCache objects by concatenating their KV states along the time axis."""
+    caches = list(caches)
+    assert len(caches) >= 1
+
+    for c in caches:
+        if any(
+            getattr(layer, "is_sliding", False) for layer in getattr(c, "layers", [])
+        ):
+            raise ValueError("Sliding-window cache layers cannot be merged; check the issue.")
+
+    merged = DynamicCache()
+
+    # Reuse Cache.update() to append each segment's KV to the merged cache per layer.
+    # DynamicLayer.update() concatenates: self.keys = cat([self.keys, key_states], dim=-2).
+    for c in caches:
+        for layer_idx, layer in enumerate(c.layers):
+            if layer.keys is None or layer.values is None:
+                continue
+            merged.update(layer.keys, layer.values, layer_idx=layer_idx)
+
+    return merged
+
+
+def merge_v5(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizerBase,
+    strs: list[str],
+    device: torch.device,
+):
+    """Merges per-segment DynamicCaches for transformers>=5.0.0."""
+    strs_toks, strs_dcs = [], []
+    for s in strs:
+        toks, dc = prefill_cache_v5(model, tokenizer, s, device)
+        strs_toks.append(toks)
+        strs_dcs.append(dc)
+
+    merged_toks = torch.cat([t["input_ids"] for t in strs_toks], dim=1)
+    merged_masks = torch.cat([t["attention_mask"] for t in strs_toks], dim=1)
+
+    merged_dc = merge_dynamic_caches_v5(strs_dcs)
+
+    return merged_toks, merged_masks, merged_dc
+
+
+if __name__ == "__main__":
+    from mellea.backends.huggingface import LocalHFBackend
+    from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
+
+    backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B.hf_model_name)
+    model, tokenizer, device = backend._model, backend._tokenizer, backend._device
+    model: PreTrainedModel = model
+
+    docs = [
+        "Nathan Fulton is an expert in large language models, formal verification, and reinforcement learning. He holds a Ph.D. from Carnegie Mellon University's Computer Science Department and has worked at Amazon Web Services and IBM Research. He currently works at IBM Research - Cambridge.",
+        "IBM Research has a headquarters at 1101 Kitchawan Rd in Yorktown Heights and a Cambridge office at 314 Main Street in Cambridge, MA.",
+        "What is the address of Nathan's place of work?",
+    ]
+
+    merged_tokens, merged_masks, merged_cache = merge_v5(
+        model, tokenizer, docs, device=backend._device
+    )
+    input_ids = merged_tokens.to(device)
+    result = model.generate(
+        input_ids=input_ids,
+        use_cache=True,
+        return_dict_in_generate=True,
+        past_key_values=merged_cache,
+        max_new_tokens=512,
+    )
+    result = tokenizer.decode(
+        result.sequences[0, input_ids.shape[1] :], skip_special_tokens=True
+    )
+    print(result)
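
A quick sanity check for merge_dynamic_caches_v5 is to merge two synthetic caches and confirm that their lengths add up. The sketch below assumes the v5 DynamicCache API used in the diff above (DynamicCache.update() and get_seq_length()); the make_cache helper and the tensor shapes are illustrative, not part of the commit.

import torch
from transformers.cache_utils import DynamicCache

def make_cache(seq_len: int, n_layers: int = 2, n_heads: int = 4, head_dim: int = 8) -> DynamicCache:
    # Hypothetical helper: fills a fresh cache with random KV states of length seq_len.
    dc = DynamicCache()
    for layer_idx in range(n_layers):
        k = torch.randn(1, n_heads, seq_len, head_dim)  # (batch, heads, time, head_dim)
        v = torch.randn(1, n_heads, seq_len, head_dim)
        dc.update(k, v, layer_idx)
    return dc

a, b = make_cache(5), make_cache(7)
merged = merge_dynamic_caches_v5([a, b])
# KV states are concatenated along the time axis, so segment lengths should add up.
assert merged.get_seq_length() == 12

Note that prefill_cache_v5 crops the final cached position per segment (dc.crop(-1)), presumably so that generate() still has uncached prompt tokens to process when it resumes from the merged cache.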

pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -74,13 +74,14 @@ hf = [
     "outlines-core==0.1.26",
     "outlines",  # intentionally un-versioned, expecting a minor update; the outlines-core version should be enough to specify it
     "peft>=0.18.0",  # aLoRA support was added in Peft 0.18.0
-    "transformers>=4.53.2,<5",
+    "transformers==5.0.0",
     "trl==0.19.1",
     "granite-common[transformers]",
 ]
 
 vllm = [
-    "transformers<4.54.0",
+    "transformers",  # Removing the <4.54.0 pin; need to figure out if this breaks anything. - TODO-nrf
+    # "transformers<4.54.0",
     # see https://github.com/vllm-project/vllm-ascend/issues/2046
     "numpy<2.0.0",  # patching incorrect dependencies in vllm and outlines.
     # see https://github.com/vllm-project/vllm/issues/5587