39 changes: 39 additions & 0 deletions docs/kv_smash/hf_example.py
@@ -0,0 +1,39 @@
from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.types import ModelOption
from mellea.stdlib.base import CBlock, LinearContext
from mellea.stdlib.chat import Message

ctx = LinearContext(window_size=100)
ctx.insert(
CBlock(
"Nathan Fulton is a Senior Research Scientist at the MIT-IBM Watson AI Lab, a joint venture between MIT and IBM.",
cache=True,
)
)
ctx.insert(
CBlock(
"The MIT-IBM Watson AI Lab is located at 314 Main St, Cambridge, Massachusetts.",
cache=True,
)
)
ctx.insert(CBlock("The ZIP code for 314 Main St, Cambridge, Massachusetts is 02142"))


msg = Message(
role="user", content="What is the likely ZIP code of Nathan Fulton's work address."
)
backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
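# _generate_from_context_with_kv_cache is a private backend helper; presumably it reuses the
# precomputed KV states of the cached blocks instead of re-encoding the whole context.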
result = backend._generate_from_context_with_kv_cache(
action=msg, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 1000}
)
print(f".{result}.")

msg2 = Message(
role="user",
content="We know that Nathan does not work for a university. What is the likely name of Nathan's employer?",
)
result = backend._generate_from_context_with_kv_cache(
action=msg2, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 1000}
)
print(f".{result}.")
110 changes: 110 additions & 0 deletions docs/kv_smash/kv_with_chat.py
@@ -0,0 +1,110 @@
import torch

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B

backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)

model = backend._model
tokenizer = backend._tokenizer
device = backend._device
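# This example reaches into the backend's private attributes (_model, _tokenizer, _device)
# because it demonstrates low-level KV-cache merging directly against the Hugging Face model.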


KV_CACHE: dict[str, DynamicCache] = dict()
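# Maps a raw string to its precomputed DynamicCache so it can be reused later without re-encoding.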


def cache(s: str, store: bool = True) -> DynamicCache:
    """Run a forward pass over `s` and return its KV cache; optionally store it for reuse."""
    toks = tokenizer(s, return_tensors="pt")
    dc = DynamicCache()
    with torch.no_grad():
        rv = model(
            toks["input_ids"].to(device),
            attention_mask=toks["attention_mask"].to(device),
            past_key_values=dc,
        ).past_key_values
    if store:
        KV_CACHE[s] = rv
    return rv


def merge(toks, dcs):
merged_toks = torch.cat([t["input_ids"] for t in toks], dim=1)
merged_masks = torch.cat([t["attention_mask"] for t in toks], dim=1)
merged_dcs = merge_dynamic_caches(dcs)

return merged_toks, merged_masks, merged_dcs


c_blocks = ["this is a test", "this is another test"]

# Pretend these strings already exist in the cache.
for cb in c_blocks:
cache(cb)


# Apply the chat template to a conversation that contains these strings, but without tokenizing.
messages = [
{"role": "user", "content": c_blocks[0]},
{"role": "user", "content": "Not cached"},
{"role": "user", "content": c_blocks[1]},
{"role": "user", "content": "Also no cash"},
]
templatized_input = tokenizer.apply_chat_template(conversation=messages, tokenize=False)
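# Plan: walk through the templated string, splitting it around each cached block. Prefixes and
# suffixes that are not cached get fresh caches computed on the fly; the cached blocks reuse
# their entries from KV_CACHE. Tokens, attention masks, and caches are then concatenated in order.
# (Depending on the tokenizer, add_special_tokens=False may also be needed for the individual pieces.)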

str_parts = []
tok_parts = []
dc_parts = []

current_suffix = templatized_input
# Intended to hold the interleaved string/cache parts; not used further in this example.
partially_cached_templatized_input: list[str | DynamicCache] = []
for cb in c_blocks:
parts = current_suffix.split(cb)
assert len(parts) == 2
prefix, next_suffix = parts

if prefix != "":
# Add the prefix.
str_parts.append(prefix)
# Add the tokens and attention mask for the prefix.
tok_parts.append(tokenizer(prefix, return_tensors="pt"))
# Add the dynamic cache for the prefix.
dc_parts.append(cache(prefix, store=False))

# Add cb itself.
str_parts.append(cb)
tok_parts.append(tokenizer(cb, return_tensors="pt"))
dc_parts.append(KV_CACHE[cb])

# set the current suffix.
current_suffix = next_suffix

# REMEMBER: add the final suffix.
if current_suffix != "":
str_parts.append(current_suffix)
tok_parts.append(tokenizer(current_suffix, return_tensors="pt"))
dc_parts.append(cache(current_suffix, store=False))

# Merge everything together.
merged_toks = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
merged_masks = torch.cat([toks["attention_mask"] for toks in tok_parts], dim=1)
merged_dcs = merge_dynamic_caches(dc_parts)

# Crop the last cached position so that generate() re-processes it and produces fresh logits
# for the first generated token.
merged_dcs.crop(-1)

# generate and print result.
result = model.generate(
merged_toks.to(device),
attention_mask=merged_masks.to(device),
past_key_values=merged_dcs,
use_cache=True,
return_dict_in_generate=True,
output_scores=True,
)

result_decoded = tokenizer.decode(
result.sequences[0, merged_toks.shape[1] :], skip_special_tokens=True
)
print(result_decoded)
55 changes: 55 additions & 0 deletions docs/kv_smash/kvcache.py
@@ -0,0 +1,55 @@
import torch

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B

backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)

model = backend._model
tokenizer = backend._tokenizer
device = backend._device


def cache(toks) -> DynamicCache:
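    """Run one forward pass over the tokenized input and return the resulting DynamicCache."""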
dc = DynamicCache()
with torch.no_grad():
rv = model(
toks["input_ids"].to(device),
attention_mask=toks["attention_mask"].to(device),
past_key_values=dc,
).past_key_values
return rv


def merge(strs: list[str]):
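    """Tokenize each string separately, build a KV cache for each with its own forward pass,
    then concatenate the token ids, attention masks, and caches as if they were one sequence.

    Note: each cache is computed independently, so the cached states for a later string never
    attended to the earlier strings.
    """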
strs_toks = [tokenizer(x, return_tensors="pt") for x in strs]
strs_dcs = [cache(toks) for toks in strs_toks]

merged_toks = torch.cat([toks["input_ids"] for toks in strs_toks], dim=1)
merged_masks = torch.cat([toks["attention_mask"] for toks in strs_toks], dim=1)
merged_dcs = merge_dynamic_caches(strs_dcs)

return merged_toks, merged_masks, merged_dcs


strs = ["this is a test", "this is another test"]

merged_toks, merged_masks, merged_dcs = merge(strs)
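# Drop the last cached position so generate() re-processes it and produces fresh logits for the
# first new token.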
merged_dcs.crop(-1)

result = model.generate(
merged_toks.to(device),
attention_mask=merged_masks.to(device),
past_key_values=merged_dcs,
use_cache=True,
return_dict_in_generate=True,
output_scores=True,
)

result_decoded = tokenizer.decode(
result.sequences[0, merged_toks.shape[1] :], skip_special_tokens=True
)
print(result_decoded)