1 | 1 | """Utilities for KV smashing.""" |
2 | 2 |
3 | 3 | from collections.abc import Iterable |
4 | | -from functools import reduce |
5 | | -from typing import Any |
6 | 4 |
7 | 5 | import torch |
| 6 | +from transformers import PreTrainedModel, PreTrainedTokenizerBase |
8 | 7 | from transformers.cache_utils import DynamicCache |
9 | | -from transformers.tokenization_utils_base import BatchEncoding |
10 | 8 |
11 | | -TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache] |
12 | | -LegacyCache = Any |
13 | 9 |
| 10 | +@torch.no_grad() |
| 11 | +def prefill_cache_v5( |
| 12 | + model: PreTrainedModel, |
| 13 | + tokenizer: PreTrainedTokenizerBase, |
| 14 | + text: str, |
| 15 | + device: torch.device, |
| 16 | +) -> tuple[dict, DynamicCache]: |
| 17 | +    """Tokenizes `text` and prefills a DynamicCache in a single forward pass (transformers v5).""" |
| 18 | + toks = tokenizer(text, return_tensors="pt") |
| 19 | + toks = {k: v.to(device) for k, v in toks.items()} |
14 | 20 |
15 | | -def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache: |
16 | | - """Concatenates two LegacyCache Ks and Vs along the time axis.""" |
17 | | - legacy_merged = tuple( |
18 | | - (torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2)) |
19 | | - for i in range(len(a)) |
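 | | +    # Pass a fresh DynamicCache; the forward pass populates it with this segment's KV states. |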
| 21 | + dc = DynamicCache() |
| 22 | + out = model( |
| 23 | + input_ids=toks["input_ids"], |
| 24 | + attention_mask=toks["attention_mask"], |
| 25 | + past_key_values=dc, |
| 26 | + use_cache=True, |
20 | 27 | ) |
21 | | - return legacy_merged |
22 | | - |
23 | | - |
24 | | -def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache: |
25 | | - """Merges two DynamicCache Ks and Vs along the time axis.""" |
26 | | - legacies = [c.to_legacy_cache() for c in caches] # type: ignore |
27 | | - assert len(legacies) >= 1 |
28 | | - rv = DynamicCache.from_legacy_cache(reduce(legacy_cache_smash, legacies)) # type: ignore |
29 | | - return rv # type: ignore |
30 | | - |
31 | | - |
32 | | -def tokens_to_legacy_cache( |
33 | | - model, device: str, tokens_or_cache: BatchEncoding | DynamicCache |
34 | | -) -> Iterable[LegacyCache]: |
35 | | - """Prefills and returns Ks and Vs as a LegacyCache.""" |
36 | | - if type(tokens_or_cache) is DynamicCache: |
37 | | - return tokens_or_cache.to_legacy_cache() # type: ignore |
38 | | - else: |
39 | | - tokens = tokens_or_cache |
40 | | - dc = DynamicCache() |
41 | | - with torch.no_grad(): |
42 | | - dc = model( |
43 | | - tokens["input_ids"].to(device), # type: ignore |
44 | | - attention_mask=tokens["attention_mask"].to(device), # type: ignore |
45 | | - past_key_values=dc, |
46 | | - ).past_key_values |
47 | | - return dc.to_legacy_cache() |
| 28 | + dc = out.past_key_values |
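 | | +    # Trim the final KV entry so the cache stays shorter than the prompt and |
 | | +    # generate() still has at least one token left to forward. |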
| 29 | + dc.crop(-1) |
| 30 | + return toks, dc # v5 returns DynamicCache (not legacy tuple) |
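 | | +
 | | +
 | | +# Usage sketch (illustrative; assumes `model`, `tokenizer`, and `device` are |
 | | +# already loaded): after prefill, the cache is one entry shorter than the |
 | | +# prompt because of the crop(-1) above. |
 | | +# |
 | | +#     toks, dc = prefill_cache_v5(model, tokenizer, "hello world", device) |
 | | +#     assert dc.get_seq_length() == toks["input_ids"].shape[1] - 1 |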
| 31 | + |
| 32 | + |
| 33 | +def merge_dynamic_caches_v5(caches: Iterable[DynamicCache]) -> DynamicCache: |
| 34 | + """Merge multiple v5 DynamicCache objects by concatenating KV states along the time axis.""" |
| 35 | + caches = list(caches) |
| 36 | + assert len(caches) >= 1 |
| 37 | + |
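 | | +    # Sliding-window layers keep only the most recent window of KV states, so |
 | | +    # naive time-axis concatenation would corrupt them; guard against that here. |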
| 38 | + for c in caches: |
| 39 | + if any( |
| 40 | + getattr(layer, "is_sliding", False) for layer in getattr(c, "layers", []) |
| 41 | + ): |
 | 42 | +            raise ValueError("Merging caches with sliding-window attention layers is not supported.") |
| 43 | + |
| 44 | + merged = DynamicCache() |
| 45 | + |
 | 46 | +    # Reuse Cache.update() to append each segment's KV to the merged cache, layer by layer. |
 | 47 | +    # DynamicLayer.update() does: self.keys = cat([self.keys, key_states], dim=-2). |
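 | | +    # E.g., merging two segments of lengths t1 and t2 yields per-layer keys and |
 | | +    # values of shape [batch, num_kv_heads, t1 + t2, head_dim]. |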
| 48 | + for c in caches: |
| 49 | + for layer_idx, layer in enumerate(c.layers): |
| 50 | + if layer.keys is None or layer.values is None: |
| 51 | + continue |
| 52 | + merged.update(layer.keys, layer.values, layer_idx=layer_idx) |
| 53 | + |
| 54 | + return merged |
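 | | +
 | | +
 | | +# Quick sanity check (sketch; `dc_a` and `dc_b` are hypothetical prefilled caches): |
 | | +# |
 | | +#     m = merge_dynamic_caches_v5([dc_a, dc_b]) |
 | | +#     assert m.get_seq_length() == dc_a.get_seq_length() + dc_b.get_seq_length() |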
| 55 | + |
| 56 | + |
| 57 | +def merge_v5( |
| 58 | + model: PreTrainedModel, |
| 59 | + tokenizer: PreTrainedTokenizerBase, |
| 60 | + strs: list[str], |
| 61 | + device: torch.device, |
 | 62 | +) -> tuple[torch.Tensor, torch.Tensor, DynamicCache]: |
 | 63 | +    """Prefills each string separately and merges token ids, attention masks, and DynamicCaches (transformers>=5.0.0).""" |
| 64 | + strs_toks, strs_dcs = [], [] |
| 65 | + for s in strs: |
| 66 | + toks, dc = prefill_cache_v5(model, tokenizer, s, device) |
| 67 | + strs_toks.append(toks) |
| 68 | + strs_dcs.append(dc) |
| 69 | + |
| 70 | + merged_toks = torch.cat([t["input_ids"] for t in strs_toks], dim=1) |
| 71 | + merged_masks = torch.cat([t["attention_mask"] for t in strs_toks], dim=1) |
| 72 | + |
| 73 | + merged_dc = merge_dynamic_caches_v5(strs_dcs) |
| 74 | + |
| 75 | + return merged_toks, merged_masks, merged_dc |
| 76 | + |
| 77 | + |
| 78 | +if __name__ == "__main__": |
| 79 | + from mellea.backends.huggingface import LocalHFBackend |
| 80 | + from mellea.backends.model_ids import IBM_GRANITE_3_3_8B |
| 81 | + |
| 82 | + backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B.hf_model_name) |
| 83 | + model, tokenizer, device = backend._model, backend._tokenizer, backend._device |
| 84 | + model: PreTrainedModel = model |
| 85 | + |
| 86 | + docs = [ |
| 87 | + "Nathan Fulton is expert in large language models, formal verification, and reinforcement learning. He holds a Ph.D. from Carnegie Mellon University's Computer Science Department and has worked at Amazon Web Services and IBM Research. He currently works at IBM Research - Cambridge.", |
| 88 | + "IBM Research has a headquarters at 1101 Kitchawan Rd in Yorktown Heights and a Cambridge office at 314 Main Street in Cambridge, MA.", |
| 89 | + "What is the address of Nathan's place of work?", |
| 90 | + ] |
| 91 | + |
| 92 | + merged_tokens, merged_masks, merged_cache = merge_v5( |
 | 93 | +        model, tokenizer, docs, device=device |
| 94 | + ) |
| 95 | + input_ids = merged_tokens.to(device) |
| 96 | + result = model.generate( |
 | 97 | +        input_ids=input_ids, |
 | | +        attention_mask=merged_masks.to(device), |
| 98 | + use_cache=True, |
| 99 | + return_dict_in_generate=True, |
| 100 | + past_key_values=merged_cache, |
| 101 | + max_new_tokens=512, |
| 102 | + ) |
| 103 | + result = tokenizer.decode( |
| 104 | + result.sequences[0, input_ids.shape[1] :], skip_special_tokens=True |
| 105 | + ) |
| 106 | + print(result) |