Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions inference_perf/apis/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class ChatMessage(BaseModel):
class ChatCompletionAPIData(InferenceAPIData):
messages: List[ChatMessage]
max_tokens: int = 0
model_response: str = "" # Store the assistant response for multi-turn chat

def get_api_type(self) -> APIType:
return APIType.Chat
Expand Down Expand Up @@ -62,6 +63,7 @@ async def process_response(
prompt_text = "".join([msg.content for msg in self.messages if msg.content])
prompt_len = tokenizer.count_tokens(prompt_text)
output_len = tokenizer.count_tokens(output_text)
self.model_response = output_text # Store response for multi-turn chat
return InferenceInfo(
input_tokens=prompt_len,
output_tokens=output_len,
Expand All @@ -76,6 +78,7 @@ async def process_response(
if len(choices) == 0:
return InferenceInfo(input_tokens=prompt_len, lora_adapter=lora_adapter)
output_text = "".join([choice.get("message", {}).get("content", "") for choice in choices])
self.model_response = output_text # Store response for multi-turn chat
output_len = tokenizer.count_tokens(output_text)
return InferenceInfo(
input_tokens=prompt_len,
Expand Down
14 changes: 12 additions & 2 deletions inference_perf/datagen/shared_prefix_datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np

from inference_perf.apis.base import InferenceAPIData, LazyLoadInferenceAPIData
from inference_perf.apis.chat import ChatCompletionAPIData, ChatMessage
from inference_perf.apis.completion import CompletionAPIData
from inference_perf.apis.user_session import LocalUserSession, UserSessionCompletionAPIData
from inference_perf.config import APIConfig, APIType, DataConfig, Distribution
Expand Down Expand Up @@ -96,12 +97,13 @@ def __init__(self, api_config: APIConfig, config: DataConfig, tokenizer: Optiona
self.output_len_list_per_group.append(output_lens.tolist())

self.prompts: List[str] = []
self.prompt_pairs: List[tuple[str, str]] = [] # (shared_prefix, question) pairs for Chat API
self.user_sessions: List[LocalUserSession] = []
self.flat_output_lens: List[int] = []
self._generate_prompts()

def get_supported_apis(self) -> List[APIType]:
return [APIType.Completion]
return [APIType.Completion, APIType.Chat]

def is_io_distribution_supported(self) -> bool:
return True
Expand All @@ -125,6 +127,10 @@ def load_lazy_data(self, data: LazyLoadInferenceAPIData) -> InferenceAPIData:
user_session_id=self.user_sessions[user_id].user_session_id,
target_round=round,
)
elif self.api_config.type == APIType.Chat:
shared_prefix, question = self.prompt_pairs[i]
messages = [ChatMessage(role="system", content=shared_prefix), ChatMessage(role="user", content=question)]
return ChatCompletionAPIData(messages=messages, max_tokens=output_len)
else:
return CompletionAPIData(prompt=self.prompts[i], max_tokens=output_len)

Expand Down Expand Up @@ -181,7 +187,9 @@ def _generate_prompts(self) -> None:
)
)
else:
# Single turn chat, Combine shared prefix and question
# Single turn: store (shared_prefix, question) pair for Chat API
self.prompt_pairs.append((shared_prefix_text, question_text))
# Combine shared prefix and question for Completion API
question_text = shared_prefix_text + " " + question_text

self.prompts.append(question_text)
Expand All @@ -197,3 +205,5 @@ def _generate_prompts(self) -> None:
self.flat_output_lens = [self.flat_output_lens[i] for i in indices]
if self.enable_multi_turn_chat:
self.user_sessions = [self.user_sessions[i] for i in indices]
else:
self.prompt_pairs = [self.prompt_pairs[i] for i in indices]
300 changes: 300 additions & 0 deletions tests/datagen/test_shared_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
import pytest
from unittest.mock import Mock

from inference_perf.datagen.shared_prefix_datagen import SharedPrefixDataGenerator
from inference_perf.apis.completion import CompletionAPIData
from inference_perf.apis.chat import ChatCompletionAPIData
from inference_perf.apis.user_session import UserSessionCompletionAPIData
from inference_perf.apis.base import LazyLoadInferenceAPIData
from inference_perf.config import APIConfig, APIType, DataConfig


def create_mock_tokenizer() -> Mock:
    """Build a Mock that stands in for the project's tokenizer wrapper.

    The wrapped HF-style tokenizer exposes a fixed vocab_size and decode /
    batch_decode stubs whose output encodes only the token count, which keeps
    generated prompt text deterministic for the data-generator tests.
    """
    hf_tokenizer = Mock()
    hf_tokenizer.vocab_size = 32000
    # Decoded text depends only on how many ids were passed in.
    hf_tokenizer.decode = Mock(side_effect=lambda ids, **kwargs: f"text_{len(ids)}")
    hf_tokenizer.batch_decode = Mock(
        side_effect=lambda ids_list, **kwargs: [f"text_{len(ids)}" for ids in ids_list]
    )
    wrapper = Mock()
    wrapper.get_tokenizer.return_value = hf_tokenizer
    return wrapper


def create_api_config(api_type: APIType) -> APIConfig:
    """Wrap *api_type* in a minimal APIConfig for the tests."""
    config = APIConfig(type=api_type)
    return config


def create_data_config(
    num_groups: int = 2,
    num_prompts_per_group: int = 3,
    system_prompt_len: int = 10,
    question_len: int = 5,
    output_len: int = 20,
    enable_multi_turn_chat: bool = False,
) -> DataConfig:
    """Build a DataConfig whose shared_prefix section is a Mock.

    Using a Mock avoids depending on the real SharedPrefix config model while
    still exposing every attribute the generator reads. The seed is pinned so
    prompt generation is reproducible across test runs.
    """
    config = DataConfig()
    config.shared_prefix = Mock(
        num_groups=num_groups,
        num_prompts_per_group=num_prompts_per_group,
        system_prompt_len=system_prompt_len,
        question_len=question_len,
        output_len=output_len,
        enable_multi_turn_chat=enable_multi_turn_chat,
        question_distribution=None,
        output_distribution=None,
        seed=42,
    )
    return config


class TestSharedPrefixDataGeneratorBasic:
    """Smoke tests covering construction and advertised capabilities."""

    def test_get_supported_apis(self) -> None:
        """Both Completion and Chat must be advertised as supported."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(),
            create_mock_tokenizer(),
        )
        apis = gen.get_supported_apis()
        assert APIType.Completion in apis
        assert APIType.Chat in apis

    def test_prompts_count(self) -> None:
        """num_groups * num_prompts_per_group prompts (and pairs) are produced."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(num_groups=2, num_prompts_per_group=3),
            create_mock_tokenizer(),
        )
        assert len(gen.prompts) == 6  # 2 groups * 3 prompts
        assert len(gen.prompt_pairs) == 6

    def test_prompt_pairs_structure(self) -> None:
        """Each prompt_pairs entry is a (shared_prefix, question) pair of strings."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Chat),
            create_data_config(),
            create_mock_tokenizer(),
        )
        for prefix, question in gen.prompt_pairs:
            assert isinstance(prefix, str)
            assert isinstance(question, str)


class TestSharedPrefixCompletionAPI:
    """Behaviour of the generator when driving the Completion API."""

    def test_load_lazy_data_returns_completion_api_data(self) -> None:
        """Resolving lazy data yields CompletionAPIData with the expected max_tokens."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        item = gen.load_lazy_data(LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0))
        assert isinstance(item, CompletionAPIData)
        assert item.max_tokens == gen.flat_output_lens[0]

    @pytest.mark.asyncio
    async def test_completion_api_to_payload(self) -> None:
        """The request payload carries model name, prompt, max_tokens and stream flag."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(output_len=50, enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        item = gen.load_lazy_data(LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0))

        payload = await item.to_payload("test-model", 100, False, True)

        assert payload["model"] == "test-model"
        assert "prompt" in payload
        assert payload["max_tokens"] == 50
        assert payload["stream"] is True

    def test_completion_api_prompt_content(self) -> None:
        """The resolved prompt equals the matching entry of generator.prompts."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        item = gen.load_lazy_data(LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0))
        assert isinstance(item, CompletionAPIData)
        assert item.prompt == gen.prompts[0]

    def test_get_data_yields_lazy_load_data_for_completion(self) -> None:
        """get_data yields lazy records that resolve to CompletionAPIData."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        lazy = next(gen.get_data())
        assert isinstance(lazy, LazyLoadInferenceAPIData)
        assert isinstance(gen.load_lazy_data(lazy), CompletionAPIData)


class TestSharedPrefixChatAPI:
    """Behaviour of the generator when driving the Chat API."""

    def test_load_lazy_data_returns_chat_api_data(self) -> None:
        """Resolving lazy data yields ChatCompletionAPIData with the expected max_tokens."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Chat),
            create_data_config(enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        item = gen.load_lazy_data(LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0))
        assert isinstance(item, ChatCompletionAPIData)
        assert item.max_tokens == gen.flat_output_lens[0]

    def test_chat_api_messages_structure(self) -> None:
        """A chat request holds exactly a system message followed by a user message."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Chat),
            create_data_config(enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        item = gen.load_lazy_data(LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0))
        assert isinstance(item, ChatCompletionAPIData)
        assert len(item.messages) == 2
        assert item.messages[0].role == "system"
        assert item.messages[1].role == "user"

    def test_chat_api_messages_content_from_prompt_pairs(self) -> None:
        """Message contents come straight from the stored (prefix, question) pair."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Chat),
            create_data_config(enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        item = gen.load_lazy_data(LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0))
        prefix, question = gen.prompt_pairs[0]

        assert isinstance(item, ChatCompletionAPIData)
        assert item.messages[0].content == prefix
        assert item.messages[1].content == question

    @pytest.mark.asyncio
    async def test_chat_api_to_payload(self) -> None:
        """The payload carries the model, both messages, max_tokens and stream flag."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Chat),
            create_data_config(output_len=50, enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        item = gen.load_lazy_data(LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0))

        payload = await item.to_payload("test-model", 100, False, False)

        assert payload["model"] == "test-model"
        assert "messages" in payload
        assert len(payload["messages"]) == 2
        assert payload["messages"][0]["role"] == "system"
        assert payload["messages"][1]["role"] == "user"
        assert payload["max_tokens"] == 50
        assert payload["stream"] is False

    def test_get_data_yields_lazy_load_data_for_chat(self) -> None:
        """get_data yields lazy records that resolve to two-message ChatCompletionAPIData."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Chat),
            create_data_config(enable_multi_turn_chat=False),
            create_mock_tokenizer(),
        )
        lazy = next(gen.get_data())
        assert isinstance(lazy, LazyLoadInferenceAPIData)

        item = gen.load_lazy_data(lazy)
        assert isinstance(item, ChatCompletionAPIData)
        assert len(item.messages) == 2


class TestSharedPrefixMultiTurn:
    """Behaviour when enable_multi_turn_chat is switched on."""

    def test_multi_turn_creates_user_sessions(self) -> None:
        """One user session is created per prompt (num_groups * num_prompts_per_group)."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(num_groups=2, num_prompts_per_group=3, enable_multi_turn_chat=True),
            create_mock_tokenizer(),
        )
        assert len(gen.user_sessions) == 6  # 2 groups * 3 prompts

    def test_multi_turn_load_lazy_data_returns_user_session_data(self) -> None:
        """Resolving lazy data yields UserSessionCompletionAPIData in multi-turn mode."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(enable_multi_turn_chat=True),
            create_mock_tokenizer(),
        )
        item = gen.load_lazy_data(LazyLoadInferenceAPIData(data_index=0, prefered_worker_id=0))
        assert isinstance(item, UserSessionCompletionAPIData)

    def test_multi_turn_get_data_yields_lazy_load_data(self) -> None:
        """get_data still yields LazyLoadInferenceAPIData records in multi-turn mode."""
        gen = SharedPrefixDataGenerator(
            create_api_config(APIType.Completion),
            create_data_config(enable_multi_turn_chat=True),
            create_mock_tokenizer(),
        )
        assert isinstance(next(gen.get_data()), LazyLoadInferenceAPIData)


class TestSharedPrefixValidation:
    """Constructor validation and error handling."""

    def test_requires_tokenizer(self) -> None:
        """Constructing without a tokenizer must fail."""
        with pytest.raises((ValueError, AttributeError)):
            SharedPrefixDataGenerator(
                create_api_config(APIType.Completion), create_data_config(), None
            )

    def test_requires_shared_prefix_config(self) -> None:
        """A DataConfig without shared_prefix settings must be rejected."""
        config = DataConfig()
        config.shared_prefix = None
        with pytest.raises(ValueError, match="Shared Prefix config is required"):
            SharedPrefixDataGenerator(
                create_api_config(APIType.Completion), config, create_mock_tokenizer()
            )
Loading