You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ai_model/llama/patches/0008-ensure-KV-cache-is-ful...

353 lines
13 KiB

From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Tue, 15 Apr 2025 14:27:40 -0400
Subject: [PATCH] ensure KV cache is fully defragmented
Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding
will not be able to find a KV cache slot. This is particularly
difficult for the caller to handle if it happens in between
ubatches. To avoid this, we should immediately trigger a defrag.
In addition, a heavily fragmented cache can require more than
max_moves to defragment. Currently, we stop when we hit the limit
but this can leave a cache that still does not have adequate space
even after defragmentation is triggered. Instead, we should do
multiple batches of processing until everything is complete.
---
src/llama-context.cpp | 18 ++++---
src/llama-context.h | 1 +
src/llama-kv-cache.cpp | 107 ++++++++++++++---------------------------
src/llama-kv-cache.h | 12 ++++-
4 files changed, 59 insertions(+), 79 deletions(-)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index c22687e4..c5948e8f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
// find KV slot
if (!kv_self->find_slot(ubatch)) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
- return 1;
+ kv_self->defrag_sched(-1.0f);
+ kv_self->update(*this);
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+ return 1;
+ }
}
ggml_backend_sched_reset(sched.get());
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
// TODO: not sure if this is needed
if (!kv_self->find_slot(ubatch)) {
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
- GGML_ABORT("TODO: handle this error");
+ kv_self->defrag_sched(-1.0f);
+ kv_self->update(*this);
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+ GGML_ABORT("TODO: handle this error");
+ }
}
auto * gf = graph_init();
diff --git a/src/llama-context.h b/src/llama-context.h
index c0ceacb1..0264e937 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -5,6 +5,7 @@
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-adapter.h"
+#include "llama-kv-cache.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 3dcad65b..60e67b03 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -364,8 +364,6 @@ void llama_kv_cache_unified::commit() {
}
bool llama_kv_cache_unified::update(llama_context & lctx) {
- bool need_reserve = false;
-
auto * sched = lctx.get_sched();
if (has_shift) {
@@ -388,8 +386,6 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
res->set_inputs(nullptr);
lctx.graph_compute(gf, false);
-
- need_reserve = true;
}
{
@@ -403,27 +399,36 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
if (do_defrag) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+ const uint32_t n_max_nodes = lctx.graph_max_nodes();
+ const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer);
+ if (!defrag_prepare(n_max_nodes)) {
+ LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__);
+ do_defrag = false; // clear the request on this path too, otherwise every later update() retries the defrag
+ return false;
+ }
+ for (std::size_t i = 0; i < defrag_info.moves.size(); i += max_moves) {
+ std::vector<struct llama_kv_defrag_move> chunk;
+ auto end = std::min(i + max_moves, defrag_info.moves.size());
+ chunk.assign(defrag_info.moves.begin() + i, defrag_info.moves.begin() + end);
- if (defrag_prepare(lctx.graph_max_nodes())) {
ggml_backend_sched_reset(sched);
auto * gf = lctx.graph_init();
- auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+ auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf, chunk);
ggml_backend_sched_alloc_graph(sched, gf);
res->set_inputs(nullptr);
lctx.graph_compute(gf, false);
-
- need_reserve = true;
}
do_defrag = false;
}
- return need_reserve;
+ // we never need to reserve a worst case graph
+ return false;
}
void llama_kv_cache_unified::defrag_sched(float thold) {
@@ -707,11 +712,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
- ggml_cgraph * gf) const {
+ ggml_cgraph * gf,
+ const std::vector<struct llama_kv_defrag_move> & moves) const {
auto res = std::make_unique<llm_graph_result>();
- const auto & ids = defrag_info.ids;
-
#if 0
// CPU defrag
//
@@ -783,32 +787,20 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
}
#else
- for (uint32_t i = 0; i < ids.size(); ++i) {
- const uint32_t id = ids[i];
-
- if (i == id || id == ids.size()) {
- continue;
- }
-
- uint32_t nm = 1;
-
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
- nm++;
- }
-
+ for (const auto & move : moves) {
for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
- n_embd_k_gqa, nm,
+ n_embd_k_gqa, move.len,
ggml_row_size(k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.src));
ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
- n_embd_k_gqa, nm,
+ n_embd_k_gqa, move.len,
ggml_row_size(k_l[il]->type, n_embd_k_gqa),
- ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
+ ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.dst));
ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;
@@ -816,31 +808,29 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
if (cparams.flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx, v_l[il],
- n_embd_v_gqa, nm,
+ n_embd_v_gqa, move.len,
ggml_row_size(v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.src));
view_v_dst = ggml_view_2d(ctx, v_l[il],
- n_embd_v_gqa, nm,
+ n_embd_v_gqa, move.len,
ggml_row_size(v_l[il]->type, n_embd_v_gqa),
- ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
+ ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.dst));
} else {
view_v_src = ggml_view_2d(ctx, v_l[il],
- nm, n_embd_v_gqa,
+ move.len, n_embd_v_gqa,
ggml_row_size(v_l[il]->type, size),
- ggml_row_size(v_l[il]->type, i));
+ ggml_row_size(v_l[il]->type, move.src));
view_v_dst = ggml_view_2d(ctx, v_l[il],
- nm, n_embd_v_gqa,
+ move.len, n_embd_v_gqa,
ggml_row_size(v_l[il]->type, size),
- ggml_row_size(v_l[il]->type, id));
+ ggml_row_size(v_l[il]->type, move.dst));
}
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
}
-
- i += nm - 1;
}
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -857,17 +847,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
assert(n_used <= n_kv);
- //const int64_t t_start = ggml_time_us();
-
- // number of cells moved
- uint32_t n_moves = 0;
-
- // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
- // - source view, destination view, copy operation
- // - x2 for keys and values
- //const uint32_t max_moves = max_nodes()/(6*n_layer);
- // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
- const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+ defrag_info.moves.clear();
// determine which KV cells to move where
//
@@ -875,10 +855,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
//
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
//
- auto & ids = defrag_info.ids;
-
- ids.clear();
- ids.resize(n_kv, n_kv);
+ std::vector<uint32_t> ids(n_kv, n_kv);
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
const auto & cell0 = cells[i0];
@@ -927,19 +904,11 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
// are we moving a continuous block of memory?
bool cont = false;
- // should we stop searching for the next move?
- bool stop = false;
-
// go back and move the nf cells to the hole
for (; i1 < n_kv; ++i1) {
auto & cell1 = cells[i1];
if (cell1.is_empty() || ids[i1] != n_kv) {
- if (n_moves == max_moves) {
- stop = true;
- break;
- }
-
cont = false;
continue;
}
@@ -955,8 +924,10 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
head = n_used;
if (!cont) {
- n_moves++;
+ defrag_info.moves.push_back({i1, i0 + nf, 1});
cont = true;
+ } else {
+ defrag_info.moves.back().len++;
}
nf++;
@@ -966,22 +937,16 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
}
}
- if (stop || n_moves == max_moves) {
- break;
- }
-
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
i0 += nh - 1;
}
- if (n_moves == 0) {
+ if (defrag_info.moves.size() == 0) {
return false;
}
- LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
-
- LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
+ // LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
return true;
}
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index bf3b4b6a..928b9712 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -82,6 +82,13 @@ struct llama_kv_cache_guard {
private:
llama_kv_cache * kv;
};
+
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+ uint32_t src; // index of the first cell of the block being moved
+ uint32_t dst; // index of the first cell the block is copied to
+ uint32_t len; // number of consecutive cells in the block
+};
//
// llama_kv_cache_unified
@@ -207,7 +214,7 @@ private:
// defrag
struct {
- std::vector<uint32_t> ids;
+ std::vector<llama_kv_defrag_move> moves;
} defrag_info;
// return true if cells have been moved
@@ -249,7 +256,8 @@ private:
llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
- ggml_cgraph * gf) const;
+ ggml_cgraph * gf,
+ const std::vector<llama_kv_defrag_move> & moves) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;