2 changes: 2 additions & 0 deletions src/runtime/vm/attn_utils.h
@@ -180,6 +180,8 @@ struct Sequence {
* in the KV cache even when sliding window is enabled.
*/
int last_block_attn_sink_size = 0;
/*! \brief The LoRA adapter id associated with the sequence. */
int32_t lora_adapter_id = 0;

/*! \brief Whether the current appended tokens form a chain (not a tree). */
bool is_chain = true;
6 changes: 6 additions & 0 deletions src/runtime/vm/kv_state.cc
@@ -65,6 +65,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
&AttentionKVCacheObj::EnableSlidingWindowForSeq)
.def_method("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes",
&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes)
.def_method("vm.builtin.attention_kv_cache_set_sequence_lora_adapter",
&AttentionKVCacheObj::SetSequenceLoraAdapter)
.def_method("vm.builtin.attention_kv_cache_get_sequence_lora_adapter",
&AttentionKVCacheObj::GetSequenceLoraAdapter)
.def_method("vm.builtin.attention_kv_cache_get_current_lora_adapter_ids",
&AttentionKVCacheObj::GetCurrentLoraAdapterIds)
.def_method("vm.builtin.attention_kv_cache_empty", &AttentionKVCacheObj::Empty)
.def_method("vm.builtin.attention_kv_cache_get_num_available_pages",
&AttentionKVCacheObj::GetNumAvailablePages)
20 changes: 20 additions & 0 deletions src/runtime/vm/kv_state.h
@@ -158,6 +158,26 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
const IntTuple& leaf_indices) = 0;

/*!
* \brief Set the LoRA adapter id for the specified sequence.
* \param seq_id The id of the sequence to update.
* \param lora_adapter_id The LoRA adapter id to associate with the sequence.
*/
virtual void SetSequenceLoraAdapter(int64_t seq_id, int64_t lora_adapter_id) = 0;

/*!
* \brief Get the LoRA adapter id for the specified sequence.
* \param seq_id The id of the sequence to query.
* \return The LoRA adapter id associated with the sequence.
*/
virtual int64_t GetSequenceLoraAdapter(int64_t seq_id) = 0;

/*!
* \brief Get the LoRA adapter ids of the current batch specified in BeginForward.
* \return The LoRA adapter ids, in the same order as the current batch sequence ids.
*/
virtual ffi::Shape GetCurrentLoraAdapterIds() = 0;

/*! \brief Prepare for the disaggregation KV data receive for the specified sequence and length.*/
virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0;

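For orientation, a minimal Python sketch of how the three new methods are reached through the builtins registered in kv_state.cc above. It assumes kv_cache is an already-constructed paged attention KV cache in which sequence 0 has been added (construction and add_sequence are omitted); the adapter ids are arbitrary illustration values, and the full end-to-end flow appears in the test at the bottom of this diff.

import tvm

# The three builtins added by this PR (registered in kv_state.cc above).
fset_lora = tvm.get_global_func("vm.builtin.attention_kv_cache_set_sequence_lora_adapter")
fget_lora = tvm.get_global_func("vm.builtin.attention_kv_cache_get_sequence_lora_adapter")
fget_batch_lora = tvm.get_global_func("vm.builtin.attention_kv_cache_get_current_lora_adapter_ids")

# kv_cache: an existing paged attention KV cache with sequence 0 already added.
assert fget_lora(kv_cache, 0) == 0  # sequences default to adapter id 0
fset_lora(kv_cache, 0, 5)           # associate sequence 0 with LoRA adapter 5
assert fget_lora(kv_cache, 0) == 5
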
36 changes: 35 additions & 1 deletion src/runtime/vm/paged_kv_cache.cc
@@ -206,6 +206,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
std::vector<HostMemoryVector> k_rope_pos_offset_on_depths_host_;
std::vector<HostMemoryVector> k_rope_pos_offset_sliding_window_on_depths_host_;
HostMemoryVector k_ragged_rope_pos_offset_host_;
HostMemoryVector current_lora_adapter_ids_host_;
HostMemoryVector q_rope_position_map_host_;
HostMemoryVector append_position_map_host_;
HostMemoryVector cur_append_lengths_indptr_host_;
@@ -414,6 +415,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
k_ragged_rope_pos_offset_host_ =
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device);
current_lora_adapter_ids_host_ =
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device);
q_rope_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
append_position_map_host_ =
@@ -685,7 +688,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
break;
}
// Create the child sequence with the child block.
seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
auto [child_it, inserted] =
seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
TVM_FFI_ICHECK(inserted);
child_it->second.lora_adapter_id = parent_it->second.lora_adapter_id;
dirty_aux_data_device_ = true;
}

@@ -874,11 +880,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
sequences.reserve(cur_batch_size_);
last_block_length_before_append.reserve(cur_batch_size_);
k_ragged_rope_pos_offset_host_.clear();
current_lora_adapter_ids_host_.clear();
for (int i = 0; i < cur_batch_size_; ++i) {
auto it = seq_map_.find(seq_ids[i]);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache.";
sequences.push_back(&it->second);
current_lora_adapter_ids_host_.push_back(it->second.lora_adapter_id);
last_block_length_before_append.push_back(
global_block_pool_[it->second.last_block_idx].seq_length);
int k_rope_offset = it->second.seq_length;
@@ -1195,6 +1203,32 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

void SetSequenceLoraAdapter(int64_t seq_id, int64_t lora_adapter_id) final {
TVM_FFI_ICHECK(lora_adapter_id >= 0) << "LoRA adapter id must be non-negative.";
TVM_FFI_ICHECK(lora_adapter_id <= std::numeric_limits<int32_t>::max())
<< "LoRA adapter id exceeds int32 range.";
auto it = seq_map_.find(seq_id);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
it->second.lora_adapter_id = static_cast<int32_t>(lora_adapter_id);
}

int64_t GetSequenceLoraAdapter(int64_t seq_id) final {
auto it = seq_map_.find(seq_id);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
return it->second.lora_adapter_id;
}

ffi::Shape GetCurrentLoraAdapterIds() final {
std::vector<ffi::Shape::index_type> adapter_ids;
adapter_ids.reserve(current_lora_adapter_ids_host_.size());
for (size_t i = 0; i < current_lora_adapter_ids_host_.size(); ++i) {
adapter_ids.push_back(current_lora_adapter_ids_host_[i]);
}
return ffi::Shape(std::move(adapter_ids));
}

ffi::Shape DisaggPrepareRecv(int64_t seq_id, int append_length) final {
// No CPU to GPU copy is needed.
// Essentially we
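To illustrate why the ids come back in batch order, a hedged sketch (not part of this change) of how a Python-side caller could pair them with the sequences passed to begin_forward. fbegin_forward, fend_forward, fget_current_lora_adapter_ids, and kv_cache are assumed to be set up as in the test below; the per-adapter dispatch is a purely hypothetical consumer.

from tvm.runtime import ShapeTuple

seq_ids = [2, 3, 1]
append_lengths = [1, 2, 1]
fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths))

# GetCurrentLoraAdapterIds returns one id per sequence, in the same order as
# seq_ids above, so position i of the batch uses adapter_ids[i].
adapter_ids = list(fget_current_lora_adapter_ids(kv_cache))
for seq_id, adapter_id in zip(seq_ids, adapter_ids):
    print(f"sequence {seq_id} -> LoRA adapter {adapter_id}")

fend_forward(kv_cache)
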
@@ -65,6 +65,9 @@
fbegin_forward = None
fend_forward = None
fcommit_accepted_token_tree_nodes = None
fset_sequence_lora_adapter = None
fget_sequence_lora_adapter = None
fget_current_lora_adapter_ids = None
fattention_with_fuse_qkv = None
fis_empty = None
fdebug_get_kv = None
@@ -88,6 +91,7 @@
def set_global_func(head_dim, dtype):
global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes
global fset_sequence_lora_adapter, fget_sequence_lora_adapter, fget_current_lora_adapter_ids
global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode
global \
@@ -110,6 +114,15 @@ def set_global_func(head_dim, dtype):
fcommit_accepted_token_tree_nodes = tvm.get_global_func(
"vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes"
)
fset_sequence_lora_adapter = tvm.get_global_func(
"vm.builtin.attention_kv_cache_set_sequence_lora_adapter"
)
fget_sequence_lora_adapter = tvm.get_global_func(
"vm.builtin.attention_kv_cache_get_sequence_lora_adapter"
)
fget_current_lora_adapter_ids = tvm.get_global_func(
"vm.builtin.attention_kv_cache_get_current_lora_adapter_ids"
)
fattention_with_fuse_qkv = tvm.get_global_func(
"vm.builtin.attention_kv_cache_attention_with_fused_qkv"
)
@@ -254,6 +267,27 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v):
tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3)


def test_lora_adapter_metadata(kv_cache_and_config):
kv_cache, _, _ = kv_cache_and_config

fadd_sequence(kv_cache, 1)
fadd_sequence(kv_cache, 2)
assert fget_sequence_lora_adapter(kv_cache, 1) == 0
assert fget_sequence_lora_adapter(kv_cache, 2) == 0

fset_sequence_lora_adapter(kv_cache, 1, 7)
fset_sequence_lora_adapter(kv_cache, 2, 11)
assert fget_sequence_lora_adapter(kv_cache, 1) == 7
assert fget_sequence_lora_adapter(kv_cache, 2) == 11

ffork_sequence(kv_cache, 1, 3, -1)
assert fget_sequence_lora_adapter(kv_cache, 3) == 7

fbegin_forward(kv_cache, ShapeTuple([2, 3, 1]), ShapeTuple([1, 2, 1]))
assert list(fget_current_lora_adapter_ids(kv_cache)) == [11, 7, 7]
fend_forward(kv_cache)


def f_apply_rotary(x, offset, scale, theta, offset_list: list[int] | None = None):
# x: (N, H, D)
assert len(x.shape) == 3