4 changes: 2 additions & 2 deletions scripts/context_length_test/README.md
@@ -124,7 +124,7 @@ Below are empirical results from running this script on various Qwen3 models acr

### A100 80GB

#### Vallina Settings (Baseline)
#### Vanilla Settings (Baseline)

| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
| ---- | -- | ---------- | ---------- | -------- | -------- | --------- |
@@ -177,7 +177,7 @@ Below are empirical results from running this script on various Qwen3 models acr
### H20 96GB (Higher VRAM, Slower Bandwidth)


#### Vallina Settings
#### Vanilla Settings


| #GPU | SP | Qwen3-0.6B | Qwen3-1.7B | Qwen3-4B | Qwen3-8B | Qwen3-14B |
53 changes: 47 additions & 6 deletions tests/common/vllm_test.py
@@ -2,6 +2,7 @@
import os
import unittest

import ray
import torch
from openai import BadRequestError
from parameterized import parameterized_class
@@ -13,12 +14,14 @@
get_model_path,
get_template_config,
)
from trinity.common.config import Config
from trinity.common.models import create_inference_models
from trinity.common.models.model import ModelWrapper
from trinity.common.models.utils import (
tokenize_and_mask_messages_default,
tokenize_and_mask_messages_hf,
)
from trinity.manager.synchronizer import Synchronizer

DEBUG = False

@@ -669,21 +672,32 @@ async def test_logprobs_api(self):


class TestAsyncAPIServer(RayUnittestBaseAsync):
def setUp(self):
engine_type: str = "vllm"
model_path: str = get_model_path()

async def asyncSetUp(self):
self.config = get_template_config()
self._update_config()
await self._setup_engines()

def _update_config(self):
self.config.mode = "explore"
self.config.model.model_path = get_model_path()
self.config.model.model_path = self.model_path
self.config.explorer.rollout_model.engine_type = "vllm"
self.config.explorer.rollout_model.engine_num = 1
self.config.explorer.rollout_model.tensor_parallel_size = 1
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True

self.config.check_and_update()

async def _setup_engines(self):
self.engines, self.auxiliary_engines = create_inference_models(self.config)
self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
self.model_wrapper = ModelWrapper(
self.engines[0], engine_type=self.engine_type, enable_history=True
)
self.model_wrapper_no_history = ModelWrapper(
self.engines[0], engine_type="vllm", enable_history=False
self.engines[0], engine_type=self.engine_type, enable_history=False
)

async def test_api_async(self):
@@ -695,7 +709,7 @@ async def test_api_async(self):
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is your name?"},
]
model_id = (await openai_client.models.list()).data[0].id
model_id = openai_client.model_path
response = await openai_client.chat.completions.create(
model=model_id, messages=messages, n=1
)
@@ -713,7 +727,8 @@ async def test_api_async(self):
self.assertTrue(response.choices[0].logprobs is not None)
self.assertEqual(0, len(response.choices[0].logprobs.content[2].top_logprobs))
# here we check the 3rd token logprob, because the first two tokens (`<think>`, `\n`) usually have zero logprob
self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
if "Instruct" not in self.model_path:
self.assertTrue(response.choices[0].logprobs.content[2].logprob < 0)
self.assertTrue(hasattr(response, "prompt_token_ids"))
self.assertTrue(len(response.prompt_token_ids) > 0)
self.assertTrue(hasattr(response.choices[0], "token_ids"))
@@ -765,6 +780,32 @@ async def test_api_async(self):
self.assertEqual(len(self.model_wrapper_no_history.history), 0)


@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
class TestTinkerAsyncAPIServer(TestAsyncAPIServer):
engine_type: str = "tinker"
model_path: str = "Qwen/Qwen3-4B-Instruct-2507"
# Llama models in Tinker do not support chat templates

def _update_config(self):
self.config.model.tinker.enable = True
self.config.algorithm.algorithm_type = "grpo"
super()._update_config()

async def _setup_engines(self):
@ray.remote
class FakeTrainer:
def __init__(self, config: Config):
self.config = config
self.synchronizer = Synchronizer.get_actor(config)

fake_trainer = FakeTrainer.remote(self.config)
await fake_trainer.__ray_ready__.remote()
await super()._setup_engines()

async def test_api_async(self):
await super().test_api_async()


class TestTokenizer(unittest.TestCase):
def test_action_mask(self):
messages = [
2 changes: 1 addition & 1 deletion tests/trainer/trainer_test.py
@@ -1402,7 +1402,7 @@ def tearDown(self):


class TestTinkerTrainer(BaseTrainerCase):
@unittest.skip("Require tinker API key")
@unittest.skipIf("TINKER_API_KEY" not in os.environ, "TINKER_API_KEY is not set")
def test_trainer(self):
"""Test GSM8K on tinker."""
# test both mode
5 changes: 5 additions & 0 deletions trinity/common/config.py
@@ -1218,6 +1218,11 @@ def _check_tinker(self) -> None:
self.explorer.rollout_model.engine_type = "tinker"
logger.warning("Rollout model engine type is set to `tinker`.")

for aux_model_config in self.explorer.auxiliary_models:
if aux_model_config.engine_type != "tinker":
aux_model_config.engine_type = "tinker"
logger.warning("Auxiliary model engine type is set to `tinker`.")

if self.trainer.trainer_type != "tinker":
self.trainer.trainer_type = "tinker"
logger.warning("Trainer type is set to `tinker`.")
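The hunk above extends `_check_tinker` so that auxiliary models are coerced to the `tinker` engine type, with a warning, just like the rollout model and trainer. A minimal sketch of the expected effect, assuming the `get_template_config` test helper used in `tests/common/vllm_test.py` and a config that already lists at least one auxiliary model (both assumptions, not shown in this diff):

```python
# Hedged sketch: with tinker enabled, check_and_update() should coerce every
# auxiliary model's engine_type to "tinker" (and log a warning for each change).
# `get_template_config` and the pre-existing auxiliary model entry are assumptions.
config = get_template_config()
config.model.tinker.enable = True
config.algorithm.algorithm_type = "grpo"
config.explorer.auxiliary_models[0].engine_type = "vllm"  # assumed pre-existing entry
config.check_and_update()

assert config.explorer.rollout_model.engine_type == "tinker"
assert all(m.engine_type == "tinker" for m in config.explorer.auxiliary_models)
assert config.trainer.trainer_type == "tinker"
```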
20 changes: 11 additions & 9 deletions trinity/common/models/__init__.py
@@ -71,16 +71,18 @@ def create_inference_models(
for i in range(engine_num)
]
auxiliary_engines = [
ray.remote(engine_cls)
.options(
name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
namespace=namespace,
)
.remote(
config=config.explorer.auxiliary_models[i],
)
[
ray.remote(engine_cls)
.options(
name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
namespace=namespace,
)
.remote(
config=config.explorer.auxiliary_models[i],
)
for j in range(model_config.engine_num)
]
for i, model_config in enumerate(config.explorer.auxiliary_models)
for j in range(model_config.engine_num)
]
return rollout_engines, auxiliary_engines
else:
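After this change `create_inference_models` groups auxiliary engines per auxiliary model instead of returning one flat list: `auxiliary_engines[i][j]` is the j-th engine actor of the i-th auxiliary model. A rough sketch of the resulting shape (illustrative only; the elements are Ray actor handles in the real code):

```python
# Hedged sketch of the new return shape: a nested list indexed by
# [auxiliary_model_index][engine_index] rather than a single flat list.
rollout_engines, auxiliary_engines = create_inference_models(config)

assert len(auxiliary_engines) == len(config.explorer.auxiliary_models)
for i, model_config in enumerate(config.explorer.auxiliary_models):
    # each inner list holds `engine_num` actor handles for that auxiliary model
    assert len(auxiliary_engines[i]) == model_config.engine_num
```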
91 changes: 80 additions & 11 deletions trinity/common/models/model.py
@@ -71,10 +71,18 @@ def get_api_server_url(self) -> Optional[str]:
"""Get the API server URL if available."""
return None

def get_api_key(self) -> str:
"""Get the API key."""
return "EMPTY"

def get_model_config(self) -> InferenceModelConfig:
"""Get the model configuration."""
return self.config

def get_model_path(self) -> Optional[str]:
"""Get the model path"""
return self.config.model_path


def _history_recorder(func):
"""Decorator to record history of the model calls."""
@@ -118,10 +126,11 @@ def __init__(
engine_type.startswith("vllm") or engine_type == "tinker"
), "Only vLLM and tinker model is supported for now."
self.model = model
self.engine_type = engine_type
self.config: InferenceModelConfig = None # init during prepare
self._model_name: str = None
self._model_path: str = None
self.api_address: str = None
self._api_key: str = None
self.openai_client: openai.OpenAI = None
self.openai_async_client: openai.AsyncOpenAI = None
self.logger = get_logger(__name__)
@@ -138,7 +147,7 @@ async def prepare(self) -> None:
"""Prepare the model wrapper."""
self.config = await self.model.get_model_config.remote()
self._model_name = self.config.name
self._model_path = self.config.model_path
self._api_key = await self.model.get_api_key.remote()
self._generate_kwargs = {
"temperature": self.config.temperature,
"top_p": self.config.top_p,
@@ -152,6 +161,8 @@
if self.api_address is None:
self.logger.info("API server is not enabled for inference model.")
return
if self.engine_type == "tinker":
return
max_retries = 30
interval = 2 # seconds
for i in range(max_retries):
@@ -285,6 +296,11 @@ async def convert_messages_to_experience_async(
messages, tools=tools, temperature=temperature
)

@property
def api_key(self) -> str:
"""Get the API key."""
return self._api_key

@property
def model_version(self) -> int:
"""Get the version of the model."""
@@ -297,8 +313,23 @@ async def model_version_async(self) -> int:

@property
def model_path(self) -> str:
"""Get the model path."""
return self._model_path
"""
Returns the path to the model files based on the current engine type.

- For 'vllm' engine: returns the model path from the configuration (`config.model_path`)
- For 'tinker' engine: returns the path to the most recent sampler weights
"""
return ray.get(self.model.get_model_path.remote())

@property
async def model_path_async(self) -> str:
"""
Returns the path to the model files based on the current engine type.

- For 'vllm' engine: returns the model path from the configuration (`config.model_path`)
- For 'tinker' engine: returns the path to the most recent sampler weights
"""
return await self.model.get_model_path.remote()

@property
def model_name(self) -> Optional[str]:
@@ -332,16 +363,36 @@ def get_openai_client(self) -> openai.OpenAI:
openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
"""
if self.openai_client is not None:
setattr(self.openai_client, "model_path", self.model_path)
return self.openai_client
if not self.api_address:
raise ValueError(
"API server is not enabled for this model. OpenAI client is unavailable."
)
self.openai_client = openai.OpenAI(
base_url=f"{self.api_address}/v1",
api_key="EMPTY",
api_key=self._api_key,
)
if self.enable_history:
if self.engine_type == "tinker":
# ! TODO: because tinker's OpenAI API interface is in beta,
# we need to use the original API in tinker instead.
def chat_completions(*args, **kwargs):
messages = kwargs.pop("messages")
chat_response = ray.get(
self.model.chat.remote(
messages=messages,
with_chat_completion=True,
return_token_ids=self.enable_history,
**kwargs,
)
)
response = chat_response.pop()
if self.enable_history:
self.history.extend(chat_response)
return response

self.openai_client.chat.completions.create = chat_completions
elif self.enable_history:
# add a decorator to the openai client to record history

ori_create = self.openai_client.chat.completions.create
@@ -359,7 +410,7 @@ def record_chat_completions(*args, **kwargs):
return response

self.openai_client.chat.completions.create = record_chat_completions
setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id)
setattr(self.openai_client, "model_path", self.model_path)
return self.openai_client

def get_openai_async_client(self) -> openai.AsyncOpenAI:
Expand All @@ -369,6 +420,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
openai.AsyncOpenAI: The async openai client. And `model_path` is added to the client which refers to the model path.
"""
if self.openai_async_client is not None:
setattr(self.openai_async_client, "model_path", self.model_path)
return self.openai_async_client
if not self.api_address:
raise ValueError(
@@ -377,9 +429,27 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
# first make sure that we have the sync openai client
self.openai_async_client = openai.AsyncOpenAI(
base_url=f"{self.api_address}/v1",
api_key="EMPTY",
api_key=self._api_key,
)
if self.enable_history:

if self.engine_type == "tinker":
# ! TODO: because tinker's OpenAI API interface is in beta,
# we need to use the original API in tinker instead.
async def chat_completions(*args, **kwargs):
messages = kwargs.pop("messages")
chat_response = await self.model.chat.remote(
messages=messages,
with_chat_completion=True,
return_token_ids=self.enable_history,
**kwargs,
)
response = chat_response.pop()
if self.enable_history:
self.history.extend(chat_response)
return response

self.openai_async_client.chat.completions.create = chat_completions
elif self.enable_history:
# add a decorator to the openai client to record history

ori_create = self.openai_async_client.chat.completions.create
@@ -400,8 +470,7 @@ async def record_chat_completions(*args, **kwargs):

self.openai_async_client.chat.completions.create = record_chat_completions
# get model_path from the sync openai client to avoid async call here
openai_client = self.get_openai_client()
setattr(self.openai_async_client, "model_path", openai_client.models.list().data[0].id)
setattr(self.openai_async_client, "model_path", self.model_path)
return self.openai_async_client

async def get_current_load(self) -> int:
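One consequence of the client changes above: the `model_path` attribute on both OpenAI clients is now set from the engine actor via `wrapper.model_path` rather than from `client.models.list()`, so it also works for the tinker engine, whose OpenAI-compatible interface is still in beta. A hedged usage sketch, assuming `wrapper` is a `ModelWrapper` that has already run `prepare()` and has an API address configured:

```python
# Hedged sketch: both clients carry the same model_path as the wrapper itself,
# and requests can simply pass client.model_path as the model name.
client = wrapper.get_openai_client()
async_client = wrapper.get_openai_async_client()

assert client.model_path == wrapper.model_path        # set from the engine actor
assert async_client.model_path == wrapper.model_path  # no models.list() round trip

response = client.chat.completions.create(
    model=client.model_path,
    messages=[{"role": "user", "content": "What is your name?"}],
)
print(response.choices[0].message.content)
```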