56 changes: 52 additions & 4 deletions csrc/backends/attention_backends.hpp
@@ -1,25 +1,73 @@
#pragma once

#include <iostream>
#include <stdexcept>
#include <string>

namespace infinilm::backends {

/*
https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/registry.py

class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
    """Enumeration of all supported attention backends.

    The enum value is the default class path, but this can be overridden
    at runtime using register_backend().

    To get the actual backend class (respecting overrides), use:
        backend.get_class()
    """

    FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"

    pass
*/
enum class AttentionBackend {
    STATIC_ATTN,
    PAGED_ATTN,
    FLASH_ATTN,
    FLASHINFER,
    Default = STATIC_ATTN
};

inline std::ostream &operator<<(std::ostream &os, AttentionBackend backend) {
    switch (backend) {
    case AttentionBackend::STATIC_ATTN: // Default and STATIC_ATTN share the underlying value 0
        return os << "AttentionBackend::STATIC_ATTN";
    case AttentionBackend::PAGED_ATTN:
        return os << "AttentionBackend::PAGED_ATTN";
    case AttentionBackend::FLASH_ATTN:
        return os << "AttentionBackend::FLASH_ATTN";
    case AttentionBackend::FLASHINFER:
        return os << "AttentionBackend::FLASHINFER";
    default:
        throw std::invalid_argument("Invalid attention backend: " + std::to_string(static_cast<int>(backend)));
    }
}

inline AttentionBackend parse_attention_backend(const std::string &backend) {
    if (backend == "default") {
        return AttentionBackend::Default;
    }
    if (backend == "static-attn") {
        return AttentionBackend::STATIC_ATTN;
    }
    if (backend == "paged-attn") {
        return AttentionBackend::PAGED_ATTN;
    }
    if (backend == "flash-attn") {
        return AttentionBackend::FLASH_ATTN;
    }
    if (backend == "flashinfer") {
        return AttentionBackend::FLASHINFER;
    }

    throw std::invalid_argument(
        "Invalid attention_backend: " + backend + ". Valid options are: default, static-attn, paged-attn, flash-attn, flashinfer");
}

} // namespace infinilm::backends
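
For context, a minimal harness showing how the parser and stream operator compose; the main() and include path are illustrative, not part of this PR:

#include <iostream>
#include "backends/attention_backends.hpp" // assumed include path

int main() {
    using infinilm::backends::parse_attention_backend;
    // "default" parses to Default, which aliases STATIC_ATTN (shared value 0),
    // so the first two lines both print AttentionBackend::STATIC_ATTN.
    std::cout << parse_attention_backend("default") << "\n";
    std::cout << parse_attention_backend("static-attn") << "\n";
    std::cout << parse_attention_backend("flashinfer") << "\n"; // AttentionBackend::FLASHINFER
    return 0;
}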
73 changes: 73 additions & 0 deletions csrc/cache/kv_cache.cpp
@@ -76,6 +76,43 @@ StaticKVCache::StaticKVCache(
        rank_info.device);
}

std::tuple<infinicore::Tensor, infinicore::Tensor> StaticKVCache::create_layer_kv_cache(
    const infinicore::Size k_dim,
    const infinicore::Size v_dim,
    const infinicore::Size num_k_heads,
    const infinicore::Size num_v_heads,
    const infinicore::Size max_positional_embedding,
    const infinicore::DataType dtype,
    const StaticKVCacheConfig &config,
    const engine::distributed::RankInfo &rank_info) {

    size_t rank_batch_size = config.max_batch_size();
    // Heads are sharded evenly across tensor-parallel ranks.
    size_t num_rank_k_heads = num_k_heads / rank_info.tp_size;
    size_t num_rank_v_heads = num_v_heads / rank_info.tp_size;

    // Fall back to the model's maximum position embedding when no explicit
    // cache length is configured (0 or the Size sentinel mean "unset").
    size_t cache_len = (config.max_cache_len() == std::numeric_limits<infinicore::Size>::max()
                        || config.max_cache_len() == 0)
                           ? max_positional_embedding
                           : config.max_cache_len();

    // Allocate K cache: [rank_batch_size, num_rank_k_heads, cache_len, k_dim]
    infinicore::Tensor k_caches = infinicore::Tensor::empty(
        {rank_batch_size, num_rank_k_heads, cache_len, k_dim},
        dtype,
        rank_info.device);

    // Allocate V cache: [rank_batch_size, num_rank_v_heads, cache_len, v_dim]
    infinicore::Tensor v_caches = infinicore::Tensor::empty(
        {rank_batch_size, num_rank_v_heads, cache_len, v_dim},
        dtype,
        rank_info.device);

    return {k_caches, v_caches};
}

std::tuple<infinicore::Tensor, infinicore::Tensor>
StaticKVCache::update(size_t layer_idx,
                      const infinicore::Tensor &k,
@@ -182,6 +219,42 @@ PagedKVCache::PagedKVCache(
        rank_info.device);
}

std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::create_layer_kv_cache(
    infinicore::Size k_dim,
    infinicore::Size v_dim,
    infinicore::Size num_k_heads,
    infinicore::Size num_v_heads,
    infinicore::DataType dtype,
    const PagedKVCacheConfig &config,
    const engine::distributed::RankInfo &rank_info) {

    // Heads are sharded evenly across tensor-parallel ranks.
    size_t num_rank_k_heads = num_k_heads / rank_info.tp_size;
    size_t num_rank_v_heads = num_v_heads / rank_info.tp_size;

    size_t num_blocks_per_layer = config.num_blocks();
    size_t block_size = config.block_size();

    // Allocate K cache: [num_blocks, num_rank_k_heads, block_size, k_dim]
    infinicore::Tensor k_caches = infinicore::Tensor::empty(
        {num_blocks_per_layer, num_rank_k_heads, block_size, k_dim},
        dtype,
        rank_info.device);

    // Allocate V cache: [num_blocks, num_rank_v_heads, block_size, v_dim]
    infinicore::Tensor v_caches = infinicore::Tensor::empty(
        {num_blocks_per_layer, num_rank_v_heads, block_size, v_dim},
        dtype,
        rank_info.device);

    return {k_caches, v_caches};
}

std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
    size_t layer_idx,
    const infinicore::Tensor &k,
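
To make the two layouts concrete, a small arithmetic sketch of per-layer K-cache sizes; all numbers below are hypothetical, not from this PR:

#include <cstddef>
#include <iostream>

int main() {
    // Hypothetical Llama-7B-like settings with tp_size = 2.
    std::size_t k_dim = 128, num_k_heads = 32, tp_size = 2;
    std::size_t num_rank_k_heads = num_k_heads / tp_size; // 16 heads per rank

    // Static layout: [max_batch_size, num_rank_k_heads, cache_len, k_dim]
    std::size_t batch = 8, cache_len = 4096; // 8 * 4096 = 32768 token slots
    std::size_t static_elems = batch * num_rank_k_heads * cache_len * k_dim;

    // Paged layout: [num_blocks, num_rank_k_heads, block_size, k_dim]
    std::size_t num_blocks = 2048, block_size = 16; // also 32768 token slots
    std::size_t paged_elems = num_blocks * num_rank_k_heads * block_size * k_dim;

    // Both print 67108864: same total slots, but the paged cache carves them
    // into blocks that can be assigned to sequences on demand.
    std::cout << "static K elems/layer: " << static_elems << "\n"
              << "paged  K elems/layer: " << paged_elems << "\n";
    return 0;
}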
20 changes: 18 additions & 2 deletions csrc/cache/kv_cache.hpp
@@ -34,7 +34,6 @@ class StaticKVCacheConfig final : public CacheConfig {
class StaticKVCache final : public Cache {
public:
    StaticKVCache(
        infinicore::Size k_dim,
        infinicore::Size v_dim,
        infinicore::Size num_k_heads,
@@ -45,6 +44,15 @@ class StaticKVCache final : public Cache {
        const StaticKVCacheConfig &config,
        const engine::distributed::RankInfo &rank_info);

    static std::tuple<infinicore::Tensor, infinicore::Tensor> create_layer_kv_cache(
        const infinicore::Size k_dim,
        const infinicore::Size v_dim,
        const infinicore::Size num_k_heads,
        const infinicore::Size num_v_heads,
        const infinicore::Size max_positional_embedding,
        const infinicore::DataType dtype,
        const StaticKVCacheConfig &config,
        const engine::distributed::RankInfo &rank_info);

    /**
     * @brief Update KV cache at a given layer and cache position.
     *
@@ -100,7 +108,6 @@ class PagedKVCacheConfig final : public CacheConfig {
class PagedKVCache final : public Cache {
public:
    PagedKVCache(
        infinicore::Size k_dim,
        infinicore::Size v_dim,
        infinicore::Size num_k_heads,
@@ -110,6 +117,15 @@ class PagedKVCache final : public Cache {
        const PagedKVCacheConfig &config,
        const engine::distributed::RankInfo &rank_info);

    static std::tuple<infinicore::Tensor, infinicore::Tensor> create_layer_kv_cache(
        infinicore::Size k_dim,
        infinicore::Size v_dim,
        infinicore::Size num_k_heads,
        infinicore::Size num_v_heads,
        infinicore::DataType dtype,
        const PagedKVCacheConfig &config,
        const engine::distributed::RankInfo &rank_info);

    /**
     * @brief Update Paged KV cache at a given layer given slot info for each token.
     *
26 changes: 26 additions & 0 deletions csrc/config/config_factory.cpp
@@ -0,0 +1,26 @@
#include "config_factory.hpp"
#include "../models/models_registry.hpp"
#include <stdexcept>

namespace infinilm {

std::shared_ptr<infinilm::config::ModelConfig> InfinilmConfigFactory::createConfig(const std::string &model_path) {
    auto model_config = std::make_shared<infinilm::config::ModelConfig>(model_path + "/config.json");
    if (nullptr == model_config) {
        throw std::runtime_error("InfinilmConfigFactory::createConfig: model_config is not initialized");
    }

    const std::string model_type = model_config->get<std::string>("model_type");
    const auto &config_map = models::get_model_config_map();
    auto it = config_map.find(model_type);
    if (it != config_map.end()) {
        it->second(model_config);
    } else {
        // Unknown model types fall through to the generic ModelConfig.
        // throw std::invalid_argument("InfinilmConfigFactory::createConfig: Unsupported model config type: " + model_type);
    }

    return model_config;
}

} // namespace infinilm
15 changes: 15 additions & 0 deletions csrc/config/config_factory.hpp
@@ -0,0 +1,15 @@
#pragma once

#include "model_config.hpp"
#include <memory>
#include <string>

namespace infinilm {

class InfinilmConfigFactory {
public:
    static std::shared_ptr<infinilm::config::ModelConfig> createConfig(const std::string &model_path);
};

} // namespace infinilm
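
A minimal usage sketch; the model directory path is hypothetical and assumed to contain a HuggingFace-style config.json:

#include "config_factory.hpp" // assumed include path
#include <iostream>

int main() {
    // "/models/llama-7b" is an illustrative path, not part of this PR.
    auto config = infinilm::InfinilmConfigFactory::createConfig("/models/llama-7b");
    // Registered model types have had their specialization hook applied;
    // unknown types fall back to the generic ModelConfig.
    std::cout << config->get<std::string>("model_type") << "\n";
    return 0;
}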
19 changes: 19 additions & 0 deletions csrc/config/infinilm_config.cpp
@@ -0,0 +1,19 @@
#include "infinilm_config.hpp"

namespace infinilm::config {

namespace {

thread_local InfinilmConfig _current_infinilm_config;

} // namespace

void set_current_infinilm_config(const InfinilmConfig &config) {
    _current_infinilm_config = config;
}

const InfinilmConfig &get_current_infinilm_config() {
    return _current_infinilm_config;
}

} // namespace infinilm::config
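
Because the storage is thread_local, each thread observes its own config. A small sketch of that behavior (include path assumed):

#include "infinilm_config.hpp" // assumed include path
#include <thread>

int main() {
    infinilm::config::InfinilmConfig cfg;
    cfg.attention_backend = infinilm::backends::AttentionBackend::FLASH_ATTN;
    infinilm::config::set_current_infinilm_config(cfg);

    std::thread([] {
        // A worker thread sees its own default-constructed config,
        // not the one set on the main thread above.
        const auto &worker_cfg = infinilm::config::get_current_infinilm_config();
        (void)worker_cfg;
    }).join();
    return 0;
}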
79 changes: 79 additions & 0 deletions csrc/config/infinilm_config.hpp
@@ -0,0 +1,79 @@
#pragma once

#include "../backends/attention_backends.hpp"
#include "../cache/cache.hpp"
#include "model_config.hpp"
#include <memory>

namespace infinilm::config {
/*
https://github.com/vllm-project/vllm/blob/main/vllm/config/vllm.py

@config(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
    """Dataclass which contains all vllm-related configuration. This
    simplifies passing around the distinct configurations in the codebase.
    """

    # TODO: use default_factory once default constructing ModelConfig doesn't try to download a model
    model_config: ModelConfig = Field(default=None)
    """Model configuration."""

    cache_config: CacheConfig = Field(default_factory=CacheConfig)
    """Cache configuration."""

    parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)

    pass
*/
struct InfinilmConfig {
public:
    InfinilmConfig() = default;
    InfinilmConfig(const infinilm::backends::AttentionBackend &backend,
                   const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
                   const cache::CacheConfig *cache_config)
        : attention_backend(backend),
          model_config(model_config),
          cache_config(cache_config) {}

public:
    // Default member initializers so a default-constructed config does not
    // carry an indeterminate backend or a dangling cache pointer.
    infinilm::backends::AttentionBackend attention_backend = infinilm::backends::AttentionBackend::Default;
    std::shared_ptr<infinilm::config::ModelConfig> model_config;
    const cache::CacheConfig *cache_config = nullptr;
};

/*
@contextmanager
def set_current_vllm_config(
    vllm_config: VllmConfig, check_compile=False, prefix: str | None = None
):
    """
    Temporarily set the current vLLM config.
    Used during model initialization.
    We save the current vLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the vLLM config to determine how to dispatch.
    """
    global _current_vllm_config, _current_prefix
    old_vllm_config = _current_vllm_config

*/
void set_current_infinilm_config(const InfinilmConfig &config);

/*
def get_current_vllm_config() -> VllmConfig:
    if _current_vllm_config is None:
        raise AssertionError(
            "Current vLLM config is not set. This typically means "
            "get_current_vllm_config() was called outside of a "
            "set_current_vllm_config() context, or a CustomOp was instantiated "
            "at module import time or model forward time when config is not set. "
            "For tests that directly test custom ops/modules, use the "
            "'default_vllm_config' pytest fixture from tests/conftest.py."
        )
    return _current_vllm_config

*/
const InfinilmConfig &get_current_infinilm_config();

} // namespace infinilm::config
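
The quoted vLLM helper is a context manager that restores the previous config on exit, while the C++ pair above only sets and gets. A minimal RAII sketch of the same save/restore pattern; ScopedInfinilmConfig is an invented name, not part of this PR:

namespace infinilm::config {

// Hypothetical guard mirroring vLLM's set_current_vllm_config() context manager.
class ScopedInfinilmConfig {
public:
    explicit ScopedInfinilmConfig(const InfinilmConfig &config)
        : previous_(get_current_infinilm_config()) {
        set_current_infinilm_config(config);
    }
    ~ScopedInfinilmConfig() {
        // Restore the previous thread-local config on scope exit.
        set_current_infinilm_config(previous_);
    }
    ScopedInfinilmConfig(const ScopedInfinilmConfig &) = delete;
    ScopedInfinilmConfig &operator=(const ScopedInfinilmConfig &) = delete;

private:
    InfinilmConfig previous_;
};

} // namespace infinilm::config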