Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- fix: preserve recurrent/hybrid model state when the full prompt is already cached by @allthatido and @abetlen in #2306

## [0.3.31]

- feat: update llama.cpp to ggml-org/llama.cpp@f449e0553
Expand Down
49 changes: 33 additions & 16 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ def free_lora_adapter():
self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab)

self.n_tokens = 0
# Restored or truncated state must decode before sampling.
self._requires_eval = True
self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
self.scores: npt.NDArray[np.single] = np.ndarray(
(n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
Expand Down Expand Up @@ -647,6 +649,7 @@ def set_seed(self, seed: int):
def reset(self):
"""Reset the model state."""
self.n_tokens = 0
self._requires_eval = True

if self._is_recurrent or self._is_hybrid:
mem = llama_cpp.llama_get_memory(self._ctx.ctx)
Expand Down Expand Up @@ -689,6 +692,7 @@ def eval(self, tokens: Sequence[int]):
pass
# Update n_tokens
self.n_tokens += n_tokens
self._requires_eval = False

def _init_sampler(
self,
Expand Down Expand Up @@ -900,41 +904,53 @@ def generate(
grammar=grammar,
)

tokens = list(tokens)

# Check for kv cache prefix match
if reset and self.n_tokens > 0:
longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]):
for a, b in zip(self._input_ids, tokens):
if a == b:
longest_prefix += 1
else:
break

# Recurrent and hybrid models cannot rewind state; reset if needed
if (
self._is_recurrent or self._is_hybrid
) and longest_prefix < self.n_tokens:
longest_prefix = 0
reset = True
prompt_consumed = longest_prefix == len(tokens)
exact_prompt_cached = self.n_tokens == len(tokens) and prompt_consumed

# Exact cache hits can sample immediately only when the current
# logits were produced by a live decode, not restored state.
if exact_prompt_cached and not self._requires_eval:
reset = False
tokens = []
reuse_prefix = 0
if self.verbose:
print(
"Llama.generate: recurrent/hybrid model requires full state reset",
"Llama.generate: full prompt already cached, skipping reset",
file=sys.stderr,
)

if longest_prefix > 0:
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
else:
# If there is no suffix to decode, replay one token to refresh
# logits after truncating to a valid prefix.
reuse_prefix = longest_prefix - 1 if prompt_consumed else longest_prefix

# Prefix hits can reuse memory because the suffix decode refreshes
# logits before sampling.
if reuse_prefix > 0:
if self._ctx.kv_cache_seq_rm(-1, reuse_prefix, -1):
reset = False
tokens = tokens[longest_prefix:]
self.n_tokens = longest_prefix
tokens = tokens[reuse_prefix:]
self.n_tokens = reuse_prefix
self._requires_eval = True
if self.verbose:
print(
f"Llama.generate: {longest_prefix} prefix-match hit, "
f"Llama.generate: {reuse_prefix} prefix-match hit, "
f"remaining {len(tokens)} prompt tokens to eval",
file=sys.stderr,
)
elif self.verbose:
print(
f"Llama.generate: {longest_prefix} prefix-match found "
f"Llama.generate: {reuse_prefix} prefix-match found "
f"but partial kv removal not supported, re-evaluating full prompt",
file=sys.stderr,
)
Expand All @@ -948,7 +964,6 @@ def generate(
# grammar.reset()

sample_idx = self.n_tokens + len(tokens) - 1
tokens = list(tokens)

# Eval and sample
while True:
Expand Down Expand Up @@ -988,6 +1003,7 @@ def generate(
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
self.n_tokens = sample_idx
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
self._requires_eval = True
break

if self.draft_model is not None:
Expand Down Expand Up @@ -2217,6 +2233,7 @@ def load_state(self, state: LlamaState) -> None:
rest[rest > 0] = 0.0
self.input_ids = state.input_ids.copy()
self.n_tokens = state.n_tokens
self._requires_eval = True
self._seed = state.seed
state_size = state.llama_state_size
LLamaStateArrayType = ctypes.c_uint8 * state_size
Expand Down
Loading
Loading