From e78de054b04731f9981f501b9ce6f77c1d257b3c Mon Sep 17 00:00:00 2001 From: Ankur Kaul Date: Mon, 15 Jun 2026 03:15:36 +0530 Subject: [PATCH] fix: preserve recurrent/hybrid model state when the full prompt is already cached --- CHANGELOG.md | 2 + llama_cpp/llama.py | 49 +++++--- tests/test_llama.py | 288 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 323 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d6d33dbbe6..b1d5fb8801 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- fix: preserve recurrent/hybrid model state when the full prompt is already cached by @allthatido and @abetlen in #2306 + ## [0.3.31] - feat: update llama.cpp to ggml-org/llama.cpp@f449e0553 diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 4a09b55ee5..b5bffd46b5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -471,6 +471,8 @@ def free_lora_adapter(): self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab) self.n_tokens = 0 + # Restored or truncated state must decode before sampling. + self._requires_eval = True self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) self.scores: npt.NDArray[np.single] = np.ndarray( (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single @@ -647,6 +649,7 @@ def set_seed(self, seed: int): def reset(self): """Reset the model state.""" self.n_tokens = 0 + self._requires_eval = True if self._is_recurrent or self._is_hybrid: mem = llama_cpp.llama_get_memory(self._ctx.ctx) @@ -689,6 +692,7 @@ def eval(self, tokens: Sequence[int]): pass # Update n_tokens self.n_tokens += n_tokens + self._requires_eval = False def _init_sampler( self, @@ -900,41 +904,53 @@ def generate( grammar=grammar, ) + tokens = list(tokens) + # Check for kv cache prefix match if reset and self.n_tokens > 0: longest_prefix = 0 - for a, b in zip(self._input_ids, tokens[:-1]): + for a, b in zip(self._input_ids, tokens): if a == b: longest_prefix += 1 else: break - # Recurrent and hybrid models cannot rewind state; reset if needed - if ( - self._is_recurrent or self._is_hybrid - ) and longest_prefix < self.n_tokens: - longest_prefix = 0 - reset = True + prompt_consumed = longest_prefix == len(tokens) + exact_prompt_cached = self.n_tokens == len(tokens) and prompt_consumed + + # Exact cache hits can sample immediately only when the current + # logits were produced by a live decode, not restored state. + if exact_prompt_cached and not self._requires_eval: + reset = False + tokens = [] + reuse_prefix = 0 if self.verbose: print( - "Llama.generate: recurrent/hybrid model requires full state reset", + "Llama.generate: full prompt already cached, skipping reset", file=sys.stderr, ) - - if longest_prefix > 0: - if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1): + else: + # If there is no suffix to decode, replay one token to refresh + # logits after truncating to a valid prefix. + reuse_prefix = longest_prefix - 1 if prompt_consumed else longest_prefix + + # Prefix hits can reuse memory because the suffix decode refreshes + # logits before sampling. + if reuse_prefix > 0: + if self._ctx.kv_cache_seq_rm(-1, reuse_prefix, -1): reset = False - tokens = tokens[longest_prefix:] - self.n_tokens = longest_prefix + tokens = tokens[reuse_prefix:] + self.n_tokens = reuse_prefix + self._requires_eval = True if self.verbose: print( - f"Llama.generate: {longest_prefix} prefix-match hit, " + f"Llama.generate: {reuse_prefix} prefix-match hit, " f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr, ) elif self.verbose: print( - f"Llama.generate: {longest_prefix} prefix-match found " + f"Llama.generate: {reuse_prefix} prefix-match found " f"but partial kv removal not supported, re-evaluating full prompt", file=sys.stderr, ) @@ -948,7 +964,6 @@ def generate( # grammar.reset() sample_idx = self.n_tokens + len(tokens) - 1 - tokens = list(tokens) # Eval and sample while True: @@ -988,6 +1003,7 @@ def generate( if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]: self.n_tokens = sample_idx self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + self._requires_eval = True break if self.draft_model is not None: @@ -2217,6 +2233,7 @@ def load_state(self, state: LlamaState) -> None: rest[rest > 0] = 0.0 self.input_ids = state.input_ids.copy() self.n_tokens = state.n_tokens + self._requires_eval = True self._seed = state.seed state_size = state.llama_state_size LLamaStateArrayType = ctypes.c_uint8 * state_size diff --git a/tests/test_llama.py b/tests/test_llama.py index 336d6a6122..70fce12d8e 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -1,4 +1,5 @@ import ctypes +import itertools import multiprocessing import numpy as np @@ -64,6 +65,14 @@ def llama_cpp_model_path(): return model_path +@pytest.fixture +def llama_cpp_transformer_model_path(): + repo_id = "ggml-org/models" + filename = "tinyllamas/stories15M-q4_0.gguf" + model_path = hf_hub_download(repo_id, filename) + return model_path + + @pytest.fixture def llama_cpp_embedding_model_path(): repo_id = "CompendiumLabs/bge-small-en-v1.5-gguf" @@ -339,6 +348,285 @@ def test_hybrid_model_prompt_cache_reset(llama_cpp_hybrid_model_path): ) +def _create_test_model(model_path): + return llama_cpp.Llama( + model_path, + n_ctx=64, + n_batch=64, + n_ubatch=64, + n_threads=multiprocessing.cpu_count(), + n_threads_batch=multiprocessing.cpu_count(), + logits_all=False, + verbose=False, + ) + + +def _generate_test_tokens(model, tokens, max_tokens=3): + return list( + itertools.islice( + model.generate( + tokens, + temp=0.0, + ), + max_tokens, + ) + ) + + +MODEL_CACHE_CASES = ( + ("llama_cpp_transformer_model_path", False, False), + ("llama_cpp_recurrent_model_path", True, False), + ("llama_cpp_hybrid_model_path", False, True), +) + +RESTORED_CACHE_CASES = MODEL_CACHE_CASES + + +def _eval_alternate_same_length_prompt(model, tokens, expected_next_token): + replacement_tokens = ( + model.token_eos(), + model.token_nl(), + 0, + 1, + 2, + model.n_vocab() - 1, + ) + + for replacement_token in replacement_tokens: + alternate_tokens = list(tokens) + alternate_tokens[-1] = replacement_token + if alternate_tokens == tokens: + continue + + model.reset() + model.eval(alternate_tokens) + if model.sample(temp=0.0, idx=len(tokens) - 1) != expected_next_token: + return + + raise AssertionError("failed to find an alternate same-length prompt") + + +def _assert_exact_cached_prompt_reuse_matches_fresh( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + prompt = "The quick brown fox" + fresh = _create_test_model(model_path) + tokens = fresh.tokenize(prompt.encode(), add_bos=True, special=True) + + assert fresh._is_recurrent is is_recurrent + assert fresh._is_hybrid is is_hybrid + + expected_tokens = _generate_test_tokens(fresh, tokens) + + cached = _create_test_model(model_path) + assert cached._is_recurrent is is_recurrent + assert cached._is_hybrid is is_hybrid + + cached.eval(tokens) + assert cached.n_tokens == len(tokens) + assert cached.input_ids[: cached.n_tokens].tolist() == tokens + assert cached.sample(temp=0.0, idx=len(tokens) - 1) == expected_tokens[0] + + reset_calls = 0 + original_reset = cached.reset + + def reset_tracker(): + nonlocal reset_calls + reset_calls += 1 + original_reset() + + cached.reset = reset_tracker + + cached_tokens = _generate_test_tokens(cached, tokens) + assert reset_calls == 0 + assert cached_tokens == expected_tokens + assert cached.n_tokens == len(tokens) + len(cached_tokens) - 1 + + +def _assert_loaded_exact_cached_prompt_reuse_matches_fresh( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + prompt = "The quick brown fox" + fresh = _create_test_model(model_path) + tokens = fresh.tokenize(prompt.encode(), add_bos=True, special=True) + expected_tokens = _generate_test_tokens(fresh, tokens) + + source = _create_test_model(model_path) + assert source._is_recurrent is is_recurrent + assert source._is_hybrid is is_hybrid + + source.eval(tokens) + state = source.save_state() + + loaded = _create_test_model(model_path) + assert loaded._is_recurrent is is_recurrent + assert loaded._is_hybrid is is_hybrid + + _eval_alternate_same_length_prompt( + loaded, + tokens, + expected_tokens[0], + ) + loaded.load_state(state) + + assert loaded.n_tokens == len(tokens) + assert loaded.input_ids[: loaded.n_tokens].tolist() == tokens + + loaded_tokens = _generate_test_tokens(loaded, tokens) + assert loaded_tokens == expected_tokens + assert loaded.n_tokens == len(tokens) + len(loaded_tokens) - 1 + + +def _assert_ram_cache_exact_prompt_hit_matches_fresh( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + prompt = "The quick brown fox" + fresh = _create_test_model(model_path) + tokens = fresh.tokenize(prompt.encode(), add_bos=True, special=True) + expected = fresh.create_completion( + tokens, + max_tokens=1, + temperature=0.0, + seed=1337, + ) + + cache = llama_cpp.LlamaRAMCache() + writer = _create_test_model(model_path) + writer.set_cache(cache) + writer.create_completion( + tokens, + max_tokens=1, + temperature=0.0, + seed=1337, + ) + + cached = _create_test_model(model_path) + assert cached._is_recurrent is is_recurrent + assert cached._is_hybrid is is_hybrid + cached.set_cache(cache) + + load_state_calls = 0 + original_load_state = cached.load_state + + def load_state_tracker(state): + nonlocal load_state_calls + load_state_calls += 1 + original_load_state(state) + + cached.load_state = load_state_tracker + + actual = cached.create_completion( + tokens, + max_tokens=1, + temperature=0.0, + seed=1337, + ) + + assert load_state_calls == 1 + assert actual["choices"][0]["text"] == expected["choices"][0]["text"] + assert ( + actual["usage"]["completion_tokens"] == expected["usage"]["completion_tokens"] + ) + + +def _assert_shorter_prompt_prefix_reuse_matches_fresh( + model_path, + *, + is_recurrent: bool, + is_hybrid: bool, +): + prompt = "The quick brown fox" + history = " jumps over the lazy dog" + fresh = _create_test_model(model_path) + tokens = fresh.tokenize(prompt.encode(), add_bos=True, special=True) + history_tokens = fresh.tokenize(history.encode(), add_bos=False, special=True) + expected_tokens = _generate_test_tokens(fresh, tokens) + + cached = _create_test_model(model_path) + assert cached._is_recurrent is is_recurrent + assert cached._is_hybrid is is_hybrid + + cached.eval(tokens + history_tokens) + assert cached.n_tokens > len(tokens) + assert cached.input_ids[: len(tokens)].tolist() == tokens + + cached_tokens = _generate_test_tokens(cached, tokens) + assert cached_tokens == expected_tokens + + +@pytest.mark.parametrize( + ("model_path_fixture", "is_recurrent", "is_hybrid"), MODEL_CACHE_CASES +) +def test_exact_cached_prompt_reuse_matches_fresh( + request, + model_path_fixture, + is_recurrent, + is_hybrid, +): + _assert_exact_cached_prompt_reuse_matches_fresh( + request.getfixturevalue(model_path_fixture), + is_recurrent=is_recurrent, + is_hybrid=is_hybrid, + ) + + +@pytest.mark.parametrize( + ("model_path_fixture", "is_recurrent", "is_hybrid"), RESTORED_CACHE_CASES +) +def test_loaded_exact_cached_prompt_reuse_matches_fresh( + request, + model_path_fixture, + is_recurrent, + is_hybrid, +): + _assert_loaded_exact_cached_prompt_reuse_matches_fresh( + request.getfixturevalue(model_path_fixture), + is_recurrent=is_recurrent, + is_hybrid=is_hybrid, + ) + + +@pytest.mark.parametrize( + ("model_path_fixture", "is_recurrent", "is_hybrid"), RESTORED_CACHE_CASES +) +def test_ram_cache_exact_prompt_hit_matches_fresh( + request, + model_path_fixture, + is_recurrent, + is_hybrid, +): + _assert_ram_cache_exact_prompt_hit_matches_fresh( + request.getfixturevalue(model_path_fixture), + is_recurrent=is_recurrent, + is_hybrid=is_hybrid, + ) + + +@pytest.mark.parametrize( + ("model_path_fixture", "is_recurrent", "is_hybrid"), MODEL_CACHE_CASES +) +def test_shorter_prompt_prefix_reuse_matches_fresh( + request, + model_path_fixture, + is_recurrent, + is_hybrid, +): + _assert_shorter_prompt_prefix_reuse_matches_fresh( + request.getfixturevalue(model_path_fixture), + is_recurrent=is_recurrent, + is_hybrid=is_hybrid, + ) + + def test_real_llama_embeddings(llama_cpp_embedding_model_path): model = llama_cpp.Llama( llama_cpp_embedding_model_path,