From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Tue, 15 Apr 2025 14:27:40 -0400
Subject: [PATCH] ensure KV cache is fully defragmented

Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding will not
be able to find a KV cache slot. This is particularly difficult for
the caller to handle if it happens in between ubatches. To avoid
this, we should immediately trigger a defrag.

In addition, a heavily fragmented cache can require more than
max_moves to defragment. Currently, we stop when we hit the limit,
but this can leave a cache that still does not have adequate space
even after defragmentation is triggered. Instead, we should process
in multiple batches until everything is complete.
---
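The move batching in llama_kv_cache_unified::update() is sized by the
graph node budget: each move copies one contiguous block of cells in
every layer and costs six graph nodes per layer (source view,
destination view, and copy, for both K and V), with roughly 2*n_layer
nodes reserved as overhead, hence
max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer). A quick sanity
check of that arithmetic, using assumed example values (an 8192-node
graph and a 32-layer model; these numbers are not taken from this
patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_max_nodes = 8192; // assumed graph node budget
        const uint32_t n_layer     = 32;   // assumed layer count

        // 6 nodes per move per layer: K src view + K dst view + K copy,
        // and the same again for V
        const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
        printf("max moves per defrag graph: %u\n", max_moves); // prints 42

        // a defrag needing e.g. 100 moves now runs as ceil(100/42) = 3
        // graphs instead of being truncated after the first 42 moves
        return 0;
    }
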
 src/llama-context.cpp  |  18 ++++---
 src/llama-context.h    |   1 +
 src/llama-kv-cache.cpp | 107 ++++++++++++++---------------------------
 src/llama-kv-cache.h   |  12 ++++-
 4 files changed, 59 insertions(+), 79 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index c22687e4..c5948e8f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-            return 1;
+            kv_self->defrag_sched(-1.0f);
+            kv_self->update(*this);
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+                return 1;
+            }
         }
 
         ggml_backend_sched_reset(sched.get());
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
 
             // TODO: not sure if this is needed
             if (!kv_self->find_slot(ubatch)) {
-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-                GGML_ABORT("TODO: handle this error");
+                kv_self->defrag_sched(-1.0f);
+                kv_self->update(*this);
+                if (!kv_self->find_slot(ubatch)) {
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+                    GGML_ABORT("TODO: handle this error");
+                }
             }
 
             auto * gf = graph_init();
diff --git a/src/llama-context.h b/src/llama-context.h
index c0ceacb1..0264e937 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -5,6 +5,7 @@
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
+#include "llama-kv-cache.h"
 
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 3dcad65b..60e67b03 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -364,8 +364,6 @@ void llama_kv_cache_unified::commit() {
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
-    bool need_reserve = false;
-
     auto * sched = lctx.get_sched();
 
     if (has_shift) {
@@ -388,8 +386,6 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
         res->set_inputs(nullptr);
         lctx.graph_compute(gf, false);
-
-        need_reserve = true;
     }
 
     {
@@ -403,27 +399,36 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
     if (do_defrag) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
+        const uint32_t n_max_nodes = lctx.graph_max_nodes();
+        const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer);
+        if (!defrag_prepare(n_max_nodes)) {
+            LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__);
+            return false;
+        }
+
+        for (std::size_t i = 0; i < defrag_info.moves.size(); i += max_moves) {
+            std::vector<llama_kv_defrag_move> chunk;
+            auto end = std::min(i + max_moves, defrag_info.moves.size());
+            chunk.assign(defrag_info.moves.begin() + i, defrag_info.moves.begin() + end);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
             ggml_backend_sched_reset(sched);
 
             auto * gf = lctx.graph_init();
 
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf, chunk);
 
             ggml_backend_sched_alloc_graph(sched, gf);
 
             res->set_inputs(nullptr);
 
             lctx.graph_compute(gf, false);
-
-            need_reserve = true;
         }
 
         do_defrag = false;
     }
 
-    return need_reserve;
+    // we never need to reserve a worst case graph
+    return false;
 }
 
 void llama_kv_cache_unified::defrag_sched(float thold) {
@@ -707,11 +712,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         const llama_cparams & cparams,
                ggml_context * ctx,
-                ggml_cgraph * gf) const {
+                ggml_cgraph * gf,
+    const std::vector<llama_kv_defrag_move> & moves) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
-
 #if 0
     // CPU defrag
     //
@@ -783,32 +787,20 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
     }
 #else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
+    for (const auto & move : moves) {
         for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
-                    n_embd_k_gqa, nm,
+                    n_embd_k_gqa, move.len,
                     ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.src));
 
             ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
-                    n_embd_k_gqa, nm,
+                    n_embd_k_gqa, move.len,
                     ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.dst));
 
             ggml_tensor * view_v_src;
             ggml_tensor * view_v_dst;
@@ -816,31 +808,29 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
             if (cparams.flash_attn) {
                 // NOTE: the V cache is not transposed when using flash attention
                 view_v_src = ggml_view_2d(ctx, v_l[il],
-                        n_embd_v_gqa, nm,
+                        n_embd_v_gqa, move.len,
                         ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.src));
 
                 view_v_dst = ggml_view_2d(ctx, v_l[il],
-                        n_embd_v_gqa, nm,
+                        n_embd_v_gqa, move.len,
                         ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.dst));
             } else {
                 view_v_src = ggml_view_2d(ctx, v_l[il],
-                        nm, n_embd_v_gqa,
+                        move.len, n_embd_v_gqa,
                         ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, i));
+                        ggml_row_size(v_l[il]->type, move.src));
 
                 view_v_dst = ggml_view_2d(ctx, v_l[il],
-                        nm, n_embd_v_gqa,
+                        move.len, n_embd_v_gqa,
                         ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, id));
+                        ggml_row_size(v_l[il]->type, move.dst));
             }
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
         }
-
-        i += nm - 1;
     }
 
     //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -857,17 +847,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 
     assert(n_used <= n_kv);
 
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+    defrag_info.moves.clear();
 
     // determine which KV cells to move where
     //
@@ -875,10 +855,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     //
     //   if ids[i] == i || ids[i] == n_kv, then cell i is not moved
     //
-    auto & ids = defrag_info.ids;
-
-    ids.clear();
-    ids.resize(n_kv, n_kv);
+    std::vector<uint32_t> ids(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
         const auto & cell0 = cells[i0];
@@ -927,19 +904,11 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         // are we moving a continuous block of memory?
         bool cont = false;
 
-        // should we stop searching for the next move?
-        bool stop = false;
-
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
             auto & cell1 = cells[i1];
 
             if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
                 cont = false;
                 continue;
             }
@@ -955,8 +924,10 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             head = n_used;
 
             if (!cont) {
-                n_moves++;
+                defrag_info.moves.push_back({i1, i0 + nf, 1});
                 cont = true;
+            } else {
+                defrag_info.moves.back().len++;
             }
 
             nf++;
@@ -966,22 +937,16 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             }
         }
 
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
         //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
 
         i0 += nh - 1;
     }
 
-    if (n_moves == 0) {
+    if (defrag_info.moves.size() == 0) {
         return false;
     }
 
-    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
-
-    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
+    // LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
 
     return true;
 }
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index bf3b4b6a..928b9712 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -82,6 +82,13 @@ struct llama_kv_cache_guard {
 private:
     llama_kv_cache * kv;
 };
+
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+    uint32_t src;
+    uint32_t dst;
+    uint32_t len;
+};
 
 //
 // llama_kv_cache_unified
@@ -207,7 +214,7 @@ private:
 
     // defrag
     struct {
-        std::vector<uint32_t> ids;
+        std::vector<llama_kv_defrag_move> moves;
     } defrag_info;
 
     // return true if cells have been moved
@@ -249,7 +256,8 @@ private:
     llm_graph_result_ptr build_graph_defrag(
             const llama_cparams & cparams,
                    ggml_context * ctx,
-                   ggml_cgraph * gf) const;
+                   ggml_cgraph * gf,
+        const std::vector<llama_kv_defrag_move> & moves) const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
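
A note on defrag_prepare(): it now records moves as {src, dst, len}
runs instead of a per-cell ids array, coalescing consecutive cells
into a single move so each run costs 6*n_layer graph nodes rather
than 6*n_layer per cell. The sketch below is a simplified, standalone
reconstruction of that coalescing, not the patch's code (the real
implementation builds the moves during its hole-filling scan over the
cells; coalesce_moves and the example ids are hypothetical):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct llama_kv_defrag_move {
        uint32_t src;
        uint32_t dst;
        uint32_t len;
    };

    // ids[i] == i or ids[i] == ids.size() means cell i does not move
    static std::vector<llama_kv_defrag_move> coalesce_moves(const std::vector<uint32_t> & ids) {
        const uint32_t n_kv = (uint32_t) ids.size();
        std::vector<llama_kv_defrag_move> moves;
        for (uint32_t i = 0; i < n_kv; ++i) {
            const uint32_t id = ids[i];
            if (id == i || id == n_kv) {
                continue; // cell stays in place
            }
            // extend the previous move if this cell continues the same run
            if (!moves.empty() &&
                moves.back().src + moves.back().len == i &&
                moves.back().dst + moves.back().len == id) {
                moves.back().len++;
            } else {
                moves.push_back({i, id, 1});
            }
        }
        return moves;
    }

    int main() {
        // cells 4..6 move to holes 0..2: coalesced into {src=4, dst=0, len=3}
        const std::vector<uint32_t> ids = {7, 7, 7, 7, 0, 1, 2};
        for (const auto & m : coalesce_moves(ids)) {
            printf("move [%u, %u) -> [%u, %u)\n", m.src, m.src + m.len, m.dst, m.dst + m.len);
        }
        return 0;
    }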